More compute_allocation() fixes.
[opus.git] / libcelt / rate.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include <math.h>
38 #include "modes.h"
39 #include "cwrs.h"
40 #include "arch.h"
41 #include "os_support.h"
42
43 #include "entcode.h"
44 #include "rate.h"
45
46
47 #ifndef STATIC_MODES
48
49 /*Determines if V(N,K) fits in a 32-bit unsigned integer.
50   N and K are themselves limited to 15 bits.*/
51 static int fits_in32(int _n, int _k)
52 {
53    static const celt_int16 maxN[15] = {
54       32767, 32767, 32767, 1476, 283, 109,  60,  40,
55        29,  24,  20,  18,  16,  14,  13};
56    static const celt_int16 maxK[15] = {
57       32767, 32767, 32767, 32767, 1172, 238,  95,  53,
58        36,  27,  22,  18,  16,  15,  13};
59    if (_n>=14)
60    {
61       if (_k>=14)
62          return 0;
63       else
64          return _n <= maxN[_k];
65    } else {
66       return _k <= maxK[_n];
67    }
68 }
69
70 void compute_pulse_cache(CELTMode *m, int LM)
71 {
72    int i;
73    int curr=0;
74    int nbEntries=0;
75    int entryN[100], entryK[100], entryI[100];
76    const celt_int16 *eBands = m->eBands;
77    PulseCache *cache = &m->cache;
78    celt_int16 *cindex;
79    unsigned char *bits;
80
81    cindex = celt_alloc(sizeof(cache->index[0])*m->nbEBands*(LM+2));
82    cache->index = cindex;
83
84    /* Scan for all unique band sizes */
85    for (i=0;i<=LM+1;i++)
86    {
87       int j;
88       for (j=0;j<m->nbEBands;j++)
89       {
90          int k;
91          int N = (eBands[j+1]-eBands[j])<<i>>1;
92          cindex[i*m->nbEBands+j] = -1;
93          /* Find other bands that have the same size */
94          for (k=0;k<=i;k++)
95          {
96             int n;
97             for (n=0;n<m->nbEBands && (k!=i || n<j);n++)
98             {
99                if (N == (eBands[n+1]-eBands[n])<<k>>1)
100                {
101                   cindex[i*m->nbEBands+j] = cindex[k*m->nbEBands+n];
102                   break;
103                }
104             }
105          }
106          if (cache->index[i*m->nbEBands+j] == -1 && N!=0)
107          {
108             int K;
109             entryN[nbEntries] = N;
110             K = 0;
111             while (fits_in32(N,get_pulses(K+1)) && K<MAX_PSEUDO)
112                K++;
113             entryK[nbEntries] = K;
114             cindex[i*m->nbEBands+j] = curr;
115             entryI[nbEntries] = curr;
116
117             curr += K+1;
118             nbEntries++;
119          }
120       }
121    }
122    bits = celt_alloc(sizeof(unsigned char)*curr);
123    cache->bits = bits;
124    cache->size = curr;
125    /* Compute the cache for all unique sizes */
126    for (i=0;i<nbEntries;i++)
127    {
128       int j;
129       unsigned char *ptr = bits+entryI[i];
130       celt_int16 tmp[MAX_PULSES+1];
131       get_required_bits(tmp, entryN[i], get_pulses(entryK[i]), BITRES);
132       for (j=1;j<=entryK[i];j++)
133          ptr[j] = tmp[get_pulses(j)]-1;
134       ptr[0] = entryK[i];
135    }
136 }
137
138 #endif /* !STATIC_MODES */
139
140
141 #define ALLOC_STEPS 6
142
143 static inline int interp_bits2pulses(const CELTMode *m, int start, int end,
144       int *bits1, int *bits2, const int *thresh, int total, int *bits,
145       int *ebits, int *fine_priority, int len, int _C, int LM, void *ec, int encode, int prev)
146 {
147    int psum;
148    int lo, hi;
149    int i, j;
150    int logM;
151    const int C = CHANNELS(_C);
152    int codedBands=-1;
153    int alloc_floor;
154    int left, percoeff;
155    int done;
156    SAVE_STACK;
157
158    alloc_floor = C<<BITRES;
159
160    logM = LM<<BITRES;
161    lo = 0;
162    hi = 1<<ALLOC_STEPS;
163    for (i=0;i<ALLOC_STEPS;i++)
164    {
165       int mid = (lo+hi)>>1;
166       psum = 0;
167       done = 0;
168       for (j=start;j<end;j++)
169       {
170          int tmp = bits1[j] + (mid*bits2[j]>>ALLOC_STEPS);
171          /* Don't allocate more than we can actually use */
172          if (tmp >= thresh[j] && !done)
173          {
174             psum += tmp;
175          } else {
176             done = 1;
177             if (tmp >= alloc_floor)
178                psum += alloc_floor;
179          }
180       }
181       if (psum > total)
182          hi = mid;
183       else
184          lo = mid;
185    }
186    psum = 0;
187    /*printf ("interp bisection gave %d\n", lo);*/
188    done = 0;
189    for (j=start;j<end;j++)
190    {
191       int tmp = bits1[j] + (lo*bits2[j]>>ALLOC_STEPS);
192       if (tmp < thresh[j] || done)
193       {
194          done = 1;
195          if (tmp >= alloc_floor)
196             tmp = alloc_floor;
197          else
198             tmp = 0;
199       }
200       /* Don't allocate more than we can actually use */
201       tmp = IMIN(tmp, 64*C<<BITRES<<LM);
202       bits[j] = tmp;
203       psum += tmp;
204    }
205
206    /* Decide which bands to skip, working backwards from the end. */
207    for (codedBands=end;;codedBands--)
208    {
209       int band_width;
210       int band_bits;
211       int rem;
212       j = codedBands-1;
213       /*Figure out how many left-over bits we would be adding to this band.
214         This can include bits we've stolen back from higher, skipped bands.*/
215       left = total-psum;
216       percoeff = left/(m->eBands[codedBands]-m->eBands[start]);
217       left -= (m->eBands[codedBands]-m->eBands[start])*percoeff;
218       /* Never skip the first band: we'd be coding a bit to signal that we're
219           going to waste all the other bits.
220          This means we won't be using the extra bit we reserved to signal the
221           end of manual skipping, but that will get added back in by
222           quant_all_bands().*/
223       if (j<=start)
224          break;
225       rem = IMAX(left-(m->eBands[j]-m->eBands[start]),0);
226       band_width = m->eBands[codedBands]-m->eBands[j];
227       band_bits = bits[j] + percoeff*band_width + rem;
228       /*Only code a skip decision if we're above the threshold for this band.
229         Otherwise it is force-skipped.
230         This ensures that a) we have enough bits to code the skip flag and b)
231          there are actually some bits to redistribute.*/
232       if (band_bits >= IMAX(thresh[j], alloc_floor+(1<<BITRES)+1))
233       {
234          if (encode)
235          {
236             /*This if() block is the only part of the allocation function that
237                is not a mandatory part of the bitstream: any bands we choose to
238                skip here must be explicitly signaled.*/
239             /*Choose a threshold with some hysteresis to keep bands from
240                fluctuating in and out.*/
241             if (band_bits > ((j<prev?7:9)*band_width<<LM<<BITRES)>>4)
242             {
243                ec_enc_bit_prob((ec_enc *)ec, 1, 32768);
244                break;
245             }
246             ec_enc_bit_prob((ec_enc *)ec, 0, 32768);
247          } else if (ec_dec_bit_prob((ec_dec *)ec, 32768)) {
248             break;
249          }
250          /*We used a bit to skip this band.*/
251          psum += 1<<BITRES;
252          band_bits -= 1<<BITRES;
253       }
254       /*Reclaim the bits originally allocated to this band.*/
255       psum -= bits[j];
256       if (band_bits >= alloc_floor)
257       {
258          /*If we have enough for a fine energy bit per channel, use it.*/
259          psum += alloc_floor;
260          bits[j] = alloc_floor;
261       } else {
262          /*Otherwise this band gets nothing at all.*/
263          bits[j] = 0;
264       }
265    }
266
267    /* Allocate the remaining bits */
268    if (codedBands>start) {
269       for (j=start;j<codedBands;j++)
270          bits[j] += percoeff*(m->eBands[j+1]-m->eBands[j]);
271       for (j=start;j<codedBands;j++)
272       {
273          int tmp = IMIN(left, m->eBands[j+1]-m->eBands[j]);
274          bits[j] += tmp;
275          left -= tmp;
276       }
277    }
278    /*for (j=0;j<end;j++)printf("%d ", bits[j]);printf("\n");*/
279    for (j=start;j<codedBands;j++)
280    {
281       int N0, N, den;
282       int offset;
283       int NClogN;
284
285       celt_assert(bits[j] >= 0);
286       N0 = m->eBands[j+1]-m->eBands[j];
287       N=N0<<LM;
288       NClogN = N*C*(m->logN[j] + logM);
289
290       /* Compensate for the extra DoF in stereo */
291       den=(C*N+ ((C==2 && N>2) ? 1 : 0));
292
293       /* Offset for the number of fine bits by log2(N)/2 + FINE_OFFSET
294          compared to their "fair share" of total/N */
295       offset = (NClogN>>1)-N*C*FINE_OFFSET;
296
297       /* N=2 is the only point that doesn't match the curve */
298       if (N==2)
299          offset += N*C<<BITRES>>2;
300
301       /* Changing the offset for allocating the second and third fine energy bit */
302       if (bits[j] + offset < den*2<<BITRES)
303          offset += NClogN>>2;
304       else if (bits[j] + offset < den*3<<BITRES)
305          offset += NClogN>>3;
306
307       /* Divide with rounding */
308       ebits[j] = IMAX(0, (bits[j] + offset + (den<<(BITRES-1))) / (den<<BITRES));
309
310       /* If we rounded down, make it a candidate for final fine energy pass */
311       fine_priority[j] = ebits[j]*(den<<BITRES) >= bits[j]+offset;
312
313       /* For N=1, all bits go to fine energy except for a single sign bit */
314       if (N==1)
315       {
316          ebits[j] = IMAX(0,(bits[j]/C >> BITRES)-1);
317          fine_priority[j] = (ebits[j]+1)*C<<BITRES >= bits[j];
318       }
319       /* Make sure not to bust */
320       if (C*ebits[j] > (bits[j]>>BITRES))
321          ebits[j] = bits[j]/C >> BITRES;
322
323       /* More than that is useless because that's about as far as PVQ can go */
324       if (ebits[j]>7)
325          ebits[j]=7;
326
327       /* The other bits are assigned to PVQ */
328       bits[j] -= C*ebits[j]<<BITRES;
329       celt_assert(bits[j] >= 0);
330       celt_assert(ebits[j] >= 0);
331    }
332    /* The skipped bands use all their bits for fine energy. */
333    for (;j<end;j++)
334    {
335       ebits[j] = bits[j]/C >> BITRES;
336       celt_assert(C*ebits[j]<<BITRES == bits[j]);
337       bits[j] = 0;
338       fine_priority[j] = 0;
339    }
340    RESTORE_STACK;
341    return codedBands;
342 }
343
344 int compute_allocation(const CELTMode *m, int start, int end, int *offsets, int alloc_trim,
345       int total, int *pulses, int *ebits, int *fine_priority, int _C, int LM, void *ec, int encode, int prev)
346 {
347    int lo, hi, len, j;
348    const int C = CHANNELS(_C);
349    int codedBands;
350    VARDECL(int, bits1);
351    VARDECL(int, bits2);
352    VARDECL(int, thresh);
353    VARDECL(int, trim_offset);
354    SAVE_STACK;
355    
356    total = IMAX(total, 0);
357    len = m->nbEBands;
358    ALLOC(bits1, len, int);
359    ALLOC(bits2, len, int);
360    ALLOC(thresh, len, int);
361    ALLOC(trim_offset, len, int);
362
363    /* Below this threshold, we're sure not to allocate any PVQ bits */
364    for (j=start;j<end;j++)
365       thresh[j] = IMAX((C)<<BITRES, (3*(m->eBands[j+1]-m->eBands[j])<<LM<<BITRES)>>4);
366    /* Tilt of the allocation curve */
367    for (j=start;j<end;j++)
368       trim_offset[j] = C*(m->eBands[j+1]-m->eBands[j])*(alloc_trim-5-LM)*(m->nbEBands-j-1)
369             <<(LM+BITRES)>>6;
370
371    lo = 0;
372    hi = m->nbAllocVectors - 1;
373    while (hi-lo != 1)
374    {
375       int psum = 0;
376       int mid = (lo+hi) >> 1;
377       for (j=start;j<end;j++)
378       {
379          int N = m->eBands[j+1]-m->eBands[j];
380          bits1[j] = C*N*m->allocVectors[mid*len+j]<<LM>>2;
381          if (bits1[j] > 0)
382             bits1[j] += trim_offset[j];
383          if (bits1[j] < 0)
384             bits1[j] = 0;
385          bits1[j] += offsets[j];
386          if (bits1[j] >= thresh[j])
387             psum += bits1[j];
388          else if (bits1[j] >= C<<BITRES)
389             psum += C<<BITRES;
390
391          /*printf ("%d ", bits[j]);*/
392       }
393       /*printf ("\n");*/
394       if (psum > total)
395          hi = mid;
396       else
397          lo = mid;
398       /*printf ("lo = %d, hi = %d\n", lo, hi);*/
399    }
400    /*printf ("interp between %d and %d\n", lo, hi);*/
401    for (j=start;j<end;j++)
402    {
403       int N = m->eBands[j+1]-m->eBands[j];
404       bits1[j] = (C*N*m->allocVectors[lo*len+j]<<LM>>2);
405       bits2[j] = (C*N*m->allocVectors[hi*len+j]<<LM>>2) - bits1[j];
406       if (bits1[j] > 0)
407          bits1[j] += trim_offset[j];
408       if (bits1[j] < 0)
409          bits1[j] = 0;
410       bits1[j] += offsets[j];
411    }
412    codedBands = interp_bits2pulses(m, start, end, bits1, bits2, thresh,
413          total, pulses, ebits, fine_priority, len, C, LM, ec, encode, prev);
414    RESTORE_STACK;
415    return codedBands;
416 }
417