This fixes a bunch of bit allocation bugs
[opus.git] / libcelt / rate.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include <math.h>
38 #include "modes.h"
39 #include "cwrs.h"
40 #include "arch.h"
41 #include "os_support.h"
42
43 #include "entcode.h"
44 #include "rate.h"
45
46
47 #ifndef STATIC_MODES
48
49 /*Determines if V(N,K) fits in a 32-bit unsigned integer.
50   N and K are themselves limited to 15 bits.*/
51 static int fits_in32(int _n, int _k)
52 {
53    static const celt_int16 maxN[15] = {
54       32767, 32767, 32767, 1476, 283, 109,  60,  40,
55        29,  24,  20,  18,  16,  14,  13};
56    static const celt_int16 maxK[15] = {
57       32767, 32767, 32767, 32767, 1172, 238,  95,  53,
58        36,  27,  22,  18,  16,  15,  13};
59    if (_n>=14)
60    {
61       if (_k>=14)
62          return 0;
63       else
64          return _n <= maxN[_k];
65    } else {
66       return _k <= maxK[_n];
67    }
68 }
69
70 void compute_pulse_cache(CELTMode *m, int LM)
71 {
72    int i;
73    int curr=0;
74    int nbEntries=0;
75    int entryN[100], entryK[100], entryI[100];
76    const celt_int16 *eBands = m->eBands;
77    PulseCache *cache = &m->cache;
78    celt_int16 *cindex;
79    unsigned char *bits;
80
81    cindex = celt_alloc(sizeof(cache->index[0])*m->nbEBands*(LM+2));
82    cache->index = cindex;
83
84    /* Scan for all unique band sizes */
85    for (i=0;i<=LM+1;i++)
86    {
87       int j;
88       for (j=0;j<m->nbEBands;j++)
89       {
90          int k;
91          int N = (eBands[j+1]-eBands[j])<<i>>1;
92          cindex[i*m->nbEBands+j] = -1;
93          /* Find other bands that have the same size */
94          for (k=0;k<=i;k++)
95          {
96             int n;
97             for (n=0;n<m->nbEBands && (k!=i || n<j);n++)
98             {
99                if (N == (eBands[n+1]-eBands[n])<<k>>1)
100                {
101                   cindex[i*m->nbEBands+j] = cindex[k*m->nbEBands+n];
102                   break;
103                }
104             }
105          }
106          if (cache->index[i*m->nbEBands+j] == -1 && N!=0)
107          {
108             int K;
109             entryN[nbEntries] = N;
110             K = 0;
111             while (fits_in32(N,get_pulses(K+1)) && K<MAX_PSEUDO)
112                K++;
113             entryK[nbEntries] = K;
114             cindex[i*m->nbEBands+j] = curr;
115             entryI[nbEntries] = curr;
116
117             curr += K+1;
118             nbEntries++;
119          }
120       }
121    }
122    bits = celt_alloc(sizeof(unsigned char)*curr);
123    cache->bits = bits;
124    cache->size = curr;
125    /* Compute the cache for all unique sizes */
126    for (i=0;i<nbEntries;i++)
127    {
128       int j;
129       unsigned char *ptr = bits+entryI[i];
130       celt_int16 tmp[MAX_PULSES+1];
131       get_required_bits(tmp, entryN[i], get_pulses(entryK[i]), BITRES);
132       for (j=1;j<=entryK[i];j++)
133          ptr[j] = tmp[get_pulses(j)]-1;
134       ptr[0] = entryK[i];
135    }
136 }
137
138 #endif /* !STATIC_MODES */
139
140
141 #define ALLOC_STEPS 6
142
143 static inline int interp_bits2pulses(const CELTMode *m, int start, int end,
144       int *bits1, int *bits2, const int *thresh, int total, int *bits,
145       int *ebits, int *fine_priority, int len, int _C, int LM, int *skip, int prev)
146 {
147    int psum;
148    int lo, hi;
149    int i, j;
150    int logM;
151    const int C = CHANNELS(_C);
152    int codedBands=-1;
153    int alloc_floor;
154    SAVE_STACK;
155
156    alloc_floor = C<<BITRES;
157
158    logM = LM<<BITRES;
159    lo = 0;
160    hi = 1<<ALLOC_STEPS;
161    for (i=0;i<ALLOC_STEPS;i++)
162    {
163       int mid = (lo+hi)>>1;
164       psum = 0;
165       for (j=start;j<end;j++)
166       {
167          int tmp = bits1[j] + (mid*bits2[j]>>ALLOC_STEPS);
168          /* Don't allocate more than we can actually use */
169          if (tmp >= thresh[j])
170             psum += tmp;
171          else if (tmp >= alloc_floor + (1<<BITRES))
172             psum += alloc_floor + (1<<BITRES);
173       }
174       if (psum > (total<<BITRES))
175          hi = mid;
176       else
177          lo = mid;
178    }
179    psum = 0;
180    /*printf ("interp bisection gave %d\n", lo);*/
181    for (j=start;j<end;j++)
182    {
183       int tmp = bits1[j] + (lo*bits2[j]>>ALLOC_STEPS);
184       if (tmp < thresh[j])
185       {
186          if (tmp >= alloc_floor + (1<<BITRES))
187             tmp = alloc_floor + (1<<BITRES);
188          else
189             tmp = 0;
190       }
191       /* Don't allocate more than we can actually use */
192       tmp = IMIN(tmp, 64*C<<BITRES<<LM);
193       bits[j] = tmp;
194       psum += tmp;
195    }
196    for (j=start;j<end;j++)
197    {
198       if (bits[j] < thresh[j])
199          break;
200    }
201    codedBands = j;
202
203    if (*skip==-1)
204    {
205       *skip=0;
206       for (j=codedBands-1;j>=0;j--)
207       {
208          if ((bits[j] > (7*(m->eBands[j+1]-m->eBands[j])<<LM<<BITRES)>>4 && j<prev)
209                || (bits[j] > (9*(m->eBands[j+1]-m->eBands[j])<<LM<<BITRES)>>4))
210             break;
211          else
212             (*skip)++;
213       }
214       *skip = IMIN(*skip, codedBands-start-1);
215    }
216    for (i=0;i<*skip;i++)
217    {
218       /* We add (1<<BITRES) to account for the skip bit */
219       psum = psum - bits[codedBands-1] + (1<<BITRES);
220       if (bits[codedBands-1] >= alloc_floor + (1<<BITRES))
221       {
222          psum += alloc_floor;
223          bits[codedBands-1] = alloc_floor;
224       } else {
225          bits[codedBands-1] = 0;
226       }
227       codedBands--;
228    }
229    /* Allocate the remaining bits */
230    if (codedBands) {
231       int left, perband;
232       left = (total<<BITRES)-psum;
233       perband = left/(m->eBands[codedBands]-m->eBands[start]);
234       for (j=start;j<codedBands;j++)
235          bits[j] += perband*(m->eBands[j+1]-m->eBands[j]);
236       left = left-(m->eBands[codedBands]-m->eBands[start])*perband;
237       for (j=start;j<codedBands;j++)
238       {
239          int tmp = IMIN(left, m->eBands[j+1]-m->eBands[j]);
240          bits[j] += tmp;
241          left -= tmp;
242       }
243    }
244    /*for (j=0;j<end;j++)printf("%d ", bits[j]);printf("\n");*/
245    for (j=start;j<end;j++)
246    {
247       int N0, N, den;
248       int offset;
249       int NClogN;
250
251       celt_assert(bits[j] >= 0);
252       N0 = m->eBands[j+1]-m->eBands[j];
253       N=N0<<LM;
254       NClogN = N*C*(m->logN[j] + logM);
255
256       /* Compensate for the extra DoF in stereo */
257       den=(C*N+ ((C==2 && N>2) ? 1 : 0));
258
259       /* Offset for the number of fine bits by log2(N)/2 + FINE_OFFSET
260          compared to their "fair share" of total/N */
261       offset = (NClogN>>1)-N*C*FINE_OFFSET;
262
263       /* N=2 is the only point that doesn't match the curve */
264       if (N==2)
265          offset += N*C<<BITRES>>2;
266
267       /* Changing the offset for allocating the second and third fine energy bit */
268       if (bits[j] + offset < den*2<<BITRES)
269          offset += NClogN>>2;
270       else if (bits[j] + offset < den*3<<BITRES)
271          offset += NClogN>>3;
272
273       /* Divide with rounding */
274       ebits[j] = IMAX(0, (bits[j] + offset + (den<<(BITRES-1))) / (den<<BITRES));
275
276       /* If we rounded down, make it a candidate for final fine energy pass */
277       fine_priority[j] = ebits[j]*(den<<BITRES) >= bits[j]+offset;
278
279       /* For N=1, all bits go to fine energy except for a single sign bit */
280       if (N==1)
281       {
282          ebits[j] = IMAX(0,(bits[j]/C >> BITRES)-1);
283          fine_priority[j] = (ebits[j]+1)*C<<BITRES >= bits[j];
284       }
285       /* Make sure not to bust */
286       if (C*ebits[j] > (bits[j]>>BITRES))
287          ebits[j] = bits[j]/C >> BITRES;
288
289       /* More than that is useless because that's about as far as PVQ can go */
290       if (ebits[j]>7)
291          ebits[j]=7;
292
293       /* The other bits are assigned to PVQ */
294       bits[j] -= C*ebits[j]<<BITRES;
295       celt_assert(bits[j] >= 0);
296       celt_assert(ebits[j] >= 0);
297    }
298    RESTORE_STACK;
299    return codedBands;
300 }
301
302 int compute_allocation(const CELTMode *m, int start, int end, int *offsets, int alloc_trim,
303       int total, int *pulses, int *ebits, int *fine_priority, int _C, int LM, int *skip, int prev)
304 {
305    int lo, hi, len, j;
306    const int C = CHANNELS(_C);
307    int codedBands;
308    VARDECL(int, bits1);
309    VARDECL(int, bits2);
310    VARDECL(int, thresh);
311    VARDECL(int, trim_offset);
312    SAVE_STACK;
313    
314    total = IMAX(total, 0);
315    len = m->nbEBands;
316    ALLOC(bits1, len, int);
317    ALLOC(bits2, len, int);
318    ALLOC(thresh, len, int);
319    ALLOC(trim_offset, len, int);
320
321    /* Below this threshold, we're sure not to allocate any PVQ bits */
322    for (j=start;j<end;j++)
323       thresh[j] = IMAX((C)<<BITRES, (3*(m->eBands[j+1]-m->eBands[j])<<LM<<BITRES)>>4);
324    /* Tilt of the allocation curve */
325    for (j=start;j<end;j++)
326       trim_offset[j] = C*(m->eBands[j+1]-m->eBands[j])*(alloc_trim-5-LM)*(m->nbEBands-j-1)
327             <<(LM+BITRES)>>6;
328
329    lo = 0;
330    hi = m->nbAllocVectors - 1;
331    while (hi-lo != 1)
332    {
333       int psum = 0;
334       int mid = (lo+hi) >> 1;
335       for (j=start;j<end;j++)
336       {
337          int N = m->eBands[j+1]-m->eBands[j];
338          bits1[j] = C*N*m->allocVectors[mid*len+j]<<LM>>2;
339          if (bits1[j] > 0)
340             bits1[j] += trim_offset[j];
341          if (bits1[j] < 0)
342             bits1[j] = 0;
343          bits1[j] += offsets[j];
344          if (bits1[j] >= thresh[j])
345             psum += bits1[j];
346          else if (bits1[j] >= C<<BITRES)
347             psum += C<<BITRES;
348
349          /*printf ("%d ", bits[j]);*/
350       }
351       /*printf ("\n");*/
352       if (psum > (total<<BITRES))
353          hi = mid;
354       else
355          lo = mid;
356       /*printf ("lo = %d, hi = %d\n", lo, hi);*/
357    }
358    /*printf ("interp between %d and %d\n", lo, hi);*/
359    for (j=start;j<end;j++)
360    {
361       int N = m->eBands[j+1]-m->eBands[j];
362       bits1[j] = (C*N*m->allocVectors[lo*len+j]<<LM>>2);
363       bits2[j] = (C*N*m->allocVectors[hi*len+j]<<LM>>2) - bits1[j];
364       if (bits1[j] > 0)
365          bits1[j] += trim_offset[j];
366       if (bits1[j] < 0)
367          bits1[j] = 0;
368       bits1[j] += offsets[j];
369    }
370    codedBands = interp_bits2pulses(m, start, end, bits1, bits2, thresh,
371          total, pulses, ebits, fine_priority, len, C, LM, skip, prev);
372    RESTORE_STACK;
373    return codedBands;
374 }
375