Some bit-allocation tuning
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50 #include "arch.h"
51
52 int log2_frac(ec_uint32 val, int frac)
53 {
54    int i;
55    /* EC_ILOG() actually returns log2()+1, go figure */
56    int L = EC_ILOG(val)-1;
57    /*printf ("in: %d %d ", val, L);*/
58    if (L>14)
59       val >>= L-14;
60    else if (L<14)
61       val <<= 14-L;
62    L <<= frac;
63    /*printf ("%d\n", val);*/
64    for (i=0;i<frac;i++)
65 {
66       val = (val*val) >> 15;
67       /*printf ("%d\n", val);*/
68       if (val > 16384)
69          L |= (1<<(frac-i-1));
70       else   
71          val <<= 1;
72 }
73    return L;
74 }
75
76 int fits_in32(int _n, int _m)
77 {
78    static const celt_int16_t maxN[15] = {
79       255, 255, 255, 255, 255, 109,  60,  40,
80        29,  24,  20,  18,  16,  14,  13};
81    static const celt_int16_t maxM[15] = {
82       255, 255, 255, 255, 255, 238,  95,  53,
83        36,  27,  22,  18,  16,  15,  13};
84    if (_n>=14)
85    {
86       if (_m>=14)
87          return 0;
88       else
89          return _n <= maxN[_m];
90    } else {
91       return _m <= maxM[_n];
92    }   
93 }
94
95 #define MASK32 (0xFFFFFFFF)
96
97 /*INV_TABLE[i] holds the multiplicative inverse of (2*i-1) mod 2**32.*/
98 static const unsigned INV_TABLE[128]={
99   0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7,
100   0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF,
101   0xF0F0F0F1,0x286BCA1B,0x3CF3CF3D,0xE9BD37A7,
102   0xC28F5C29,0x684BDA13,0x4F72C235,0xBDEF7BDF,
103   0x3E0F83E1,0x8AF8AF8B,0x914C1BAD,0x96F96F97,
104   0xC18F9C19,0x2FA0BE83,0xA4FA4FA5,0x677D46CF,
105   0x1A1F58D1,0xFAFAFAFB,0x8C13521D,0x586FB587,
106   0xB823EE09,0xA08AD8F3,0xC10C9715,0xBEFBEFBF,
107   0xC0FC0FC1,0x07A44C6B,0xA33F128D,0xE327A977,
108   0xC7E3F1F9,0x962FC963,0x3F2B3885,0x613716AF,
109   0x781948B1,0x2B2E43DB,0xFCFCFCFD,0x6FD0EB67,
110   0xFA3F47E9,0xD2FD2FD3,0x3F4FD3F5,0xD4E25B9F,
111   0x5F02A3A1,0xBF5A814B,0x7C32B16D,0xD3431B57,
112   0xD8FD8FD9,0x8D28AC43,0xDA6C0965,0xDB195E8F,
113   0x0FDBC091,0x61F2A4BB,0xDCFDCFDD,0x46FDD947,
114   0x56BE69C9,0xEB2FDEB3,0x26E978D5,0xEFDFBF7F,
115   0x0FE03F81,0xC9484E2B,0xE133F84D,0xE1A8C537,
116   0x077975B9,0x70586723,0xCD29C245,0xFAA11E6F,
117   0x0FE3C071,0x08B51D9B,0x8CE2CABD,0xBF937F27,
118   0xA8FE53A9,0x592FE593,0x2C0685B5,0x2EB11B5F,
119   0xFCD1E361,0x451AB30B,0x72CFE72D,0xDB35A717,
120   0xFB74A399,0xE80BFA03,0x0D516325,0x1BCB564F,
121   0xE02E4851,0xD962AE7B,0x10F8ED9D,0x95AEDD07,
122   0xE9DC0589,0xA18A4473,0xEA53FA95,0xEE936F3F,
123   0x90948F41,0xEAFEAFEB,0x3D137E0D,0xEF46C0F7,
124   0x028C1979,0x791064E3,0xC04FEC05,0xE115062F,
125   0x32385831,0x6E68575B,0xA10D387D,0x6FECF2E7,
126   0x3FB47F69,0xED4BFB53,0x74FED775,0xDB43BB1F,
127   0x87654321,0x9BA144CB,0x478BBCED,0xBFB912D7,
128   0x1FDCD759,0x14B2A7C3,0xCB125CE5,0x437B2E0F,
129   0x10FEF011,0xD2B3183B,0x386CAB5D,0xEF6AC0C7,
130   0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF
131 };
132
133 /*Computes (_a*_b-_c)/(2*_d-1) when the quotient is known to be exact.
134   _a, _b, _c, and _d may be arbitrary so long as the arbitrary precision result
135    fits in 32 bits, but currently the table for multiplicative inverses is only
136    valid for _d<128.*/
137 static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b,
138  celt_uint32_t _c,celt_uint32_t _d){
139   return (_a*_b-_c)*INV_TABLE[_d]&MASK32;
140 }
141
142 /*Computes (_a*_b-_c)/_d when the quotient is known to be exact.
143   _d does not actually have to be even, but imusdiv32odd will be faster when
144    it's odd, so you should use that instead.
145   _a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the
146    table for multiplicative inverses is only valid for _d<256).
147   _b and _c may be arbitrary so long as the arbitrary precision reuslt fits in
148    32 bits.*/
149 static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b,
150  celt_uint32_t _c,celt_uint32_t _d){
151   unsigned inv;
152   int      mask;
153   int      shift;
154   int      one;
155   shift=EC_ILOG(_d^_d-1);
156   inv=INV_TABLE[_d-1>>shift];
157   shift--;
158   one=1<<shift;
159   mask=one-1;
160   return (_a*(_b>>shift)-(_c>>shift)+
161    (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32;
162 }
163
164 /*Computes the next row/column of any recurrence that obeys the relation
165    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
166   _ui0 is the base case for the new row/column.*/
167 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
168   celt_uint32_t ui1;
169   int           j;
170   /* doing a do-while would overrun the array if we had less than 2 samples */
171   j=1; do {
172     ui1=UADD32(UADD32(_ui[j],_ui[j-1]),_ui0);
173     _ui[j-1]=_ui0;
174     _ui0=ui1;
175   } while (++j<_len);
176   _ui[j-1]=_ui0;
177 }
178
179 /*Computes the previous row/column of any recurrence that obeys the relation
180    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
181   _ui0 is the base case for the new row/column.*/
182 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
183   celt_uint32_t ui1;
184   int           j;
185   /* doing a do-while would overrun the array if we had less than 2 samples */
186   j=1; do {
187     ui1=USUB32(USUB32(_ui[j],_ui[j-1]),_ui0);
188     _ui[j-1]=_ui0;
189     _ui0=ui1;
190   } while (++j<_n);
191   _ui[j-1]=_ui0;
192 }
193
194 /*Returns the number of ways of choosing _m elements from a set of size _n with
195    replacement when a sign bit is needed for each unique element.
196   On input, _u should be initialized to column (_m-1) of U(n,m).
197   On exit, _u will be initialized to column _m of U(n,m).*/
198 celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
199   celt_uint32_t ret;
200   celt_uint32_t ui0;
201   celt_uint32_t ui1;
202   int           j;
203   ret=ui0=2;
204   celt_assert(_n>=2);
205   j=1; do {
206     ui1=_ui[j]+_ui[j-1]+ui0;
207     _ui[j-1]=ui0;
208     ui0=ui1;
209     ret+=ui0;
210   } while (++j<_n);
211   _ui[j-1]=ui0;
212   return ret;
213 }
214
215 /*Returns the number of ways of choosing _m elements from a set of size _n with
216    replacement when a sign bit is needed for each unique element.
217   _u: On exit, _u[i] contains U(i+1,_m).*/
218 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
219   celt_uint32_t ret;
220   celt_uint32_t um2;
221   int           k;
222   /*If _m==0, _u[] should be set to zero and the return should be 1.*/
223   celt_assert(_m>0);
224   /*We'll overflow our buffer unless _n>=2.*/
225   celt_assert(_n>=2);
226   um2=_u[0]=1;
227   if(_m<=6){
228     if(_m<2){
229       k=1;
230       do _u[k]=1;
231       while(++k<_n);
232     }
233     else{
234       k=1;
235       do _u[k]=(k<<1)+1;
236       while(++k<_n);
237       for(k=2;k<_m;k++)unext32(_u,_n,1);
238     }
239   }
240   else{
241     celt_uint32_t um1;
242     celt_uint32_t n2m1;
243     _u[1]=n2m1=um1=(_m<<1)-1;
244     for(k=2;k<_n;k++){
245       /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/
246       _u[k]=um2=imusdiv32even(n2m1,um1,um2,k)+um2;
247       if(++k>=_n)break;
248       _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k>>1)+um1;
249     }
250   }
251   ret=1;
252   k=1;
253   do ret+=_u[k];
254   while(++k<_n);
255   return ret<<1;
256 }
257
258
259 /*Returns the _i'th combination of _m elements chosen from a set of size _n
260    with associated sign bits.
261   _y: Returns the vector of pulses.
262   _u: Must contain entries [1..._n] of column _m of U() on input.
263       Its contents will be destructively modified.*/
264 void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y,
265  celt_uint32_t *_u){
266   celt_uint32_t p;
267   celt_uint32_t q;
268   int           j;
269   int           k;
270   celt_assert(_n>0);
271   p=_nc;
272   q=0;
273   j=0;
274   k=_m;
275   do{
276     int s;
277     int yj;
278     p-=q;
279     q=_u[_n-j-1];
280     p-=q;
281     s=_i>=p;
282     if(s)_i-=p;
283     yj=k;
284     while(q>_i){
285       uprev32(_u,_n-j,--k>0);
286       p=q;
287       q=_u[_n-j-1];
288     }
289     _i-=q;
290     yj-=k;
291     _y[j]=yj-(yj<<1&-s);
292   }
293   while(++j<_n);
294 }
295
296 /*Returns the index of the given combination of _m elements chosen from a set
297    of size _n with associated sign bits.
298   _y:  The vector of pulses, whose sum of absolute values must be _m.
299   _nc: Returns V(_n,_m).*/
300 celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
301  celt_uint32_t *_u){
302   celt_uint32_t nc;
303   celt_uint32_t i;
304   int           j;
305   int           k;
306   /*We can't unroll the first two iterations of the loop unless _n>=2.*/
307   celt_assert(_n>=2);
308   nc=1;
309   i=_y[_n-1]<0;
310   _u[0]=0;
311   for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
312   k=abs(_y[_n-1]);
313   j=_n-2;
314   nc+=_u[_m];
315   i+=_u[k];
316   k+=abs(_y[j]);
317   if(_y[j]<0)i+=_u[k+1];
318   while(j-->0){
319     unext32(_u,_m+2,0);
320     nc+=_u[_m];
321     i+=_u[k];
322     k+=abs(_y[j]);
323     if(_y[j]<0)i+=_u[k+1];
324   }
325   /*If _m==0, nc should not be doubled.*/
326   celt_assert(_m>0);
327   *_nc=nc<<1;
328   return i;
329 }
330
331 static inline void encode_pulse32(int _n,int _m,const int *_y,ec_enc *_enc){
332   VARDECL(celt_uint32_t,u);
333   celt_uint32_t nc;
334   celt_uint32_t i;
335   SAVE_STACK;
336   ALLOC(u,_m+2,celt_uint32_t);
337   i=icwrs32(_n,_m,&nc,_y,u);
338   ec_enc_uint(_enc,i,nc);
339   RESTORE_STACK;
340 }
341
342 int get_required_bits(int N, int K, int frac)
343 {
344    int nbits = 0;
345    if(fits_in32(N,K))
346    {
347       VARDECL(celt_uint32_t,u);
348       SAVE_STACK;
349       ALLOC(u,N,celt_uint32_t);
350       nbits = log2_frac(ncwrs_u32(N,K,u), frac);
351       RESTORE_STACK;
352    } else {
353       nbits = log2_frac(N, frac);
354       nbits += get_required_bits(N/2+1, (K+1)/2, frac);
355       nbits += get_required_bits(N/2+1, K/2, frac);
356    }
357    return nbits;
358 }
359
360
361 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
362 {
363    if (K==0) {
364    } else if (N==1)
365    {
366       ec_enc_bits(enc, _y[0]<0, 1);
367    } else if(fits_in32(N,K))
368    {
369       encode_pulse32(N, K, _y, enc);
370    } else {
371      int i;
372      int count=0;
373      int split;
374      split = (N+1)/2;
375      for (i=0;i<split;i++)
376         count += abs(_y[i]);
377      ec_enc_uint(enc,count,K+1);
378      encode_pulses(_y, split, count, enc);
379      encode_pulses(_y+split, N-split, K-count, enc);
380    }
381 }
382
383 static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){
384   VARDECL(celt_uint32_t,u);
385   celt_uint32_t nc;
386   SAVE_STACK;
387   ALLOC(u,_n,celt_uint32_t);
388   nc=ncwrs_u32(_n,_m,u);
389   cwrsi32(_n,_m,ec_dec_uint(_dec,nc),nc,_y,u);
390   RESTORE_STACK;
391 }
392
393 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
394 {
395    if (K==0) {
396       int i;
397       for (i=0;i<N;i++)
398          _y[i] = 0;
399    } else if (N==1)
400    {
401       int s = ec_dec_bits(dec, 1);
402       if (s==0)
403          _y[0] = K;
404       else
405          _y[0] = -K;
406    } else if(fits_in32(N,K))
407    {
408       decode_pulse32(N, K, _y, dec);
409    } else {
410      int split;
411      int count = ec_dec_uint(dec,K+1);
412      split = (N+1)/2;
413      decode_pulses(_y, split, count, dec);
414      decode_pulses(_y+split, N-split, K-count, dec);
415    }
416 }