Change cwrsi() to operate on rows of U instead of columns.
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50
51 #if 0
52 int log2_frac(ec_uint32 val, int frac)
53 {
54    int i;
55    /* EC_ILOG() actually returns log2()+1, go figure */
56    int L = EC_ILOG(val)-1;
57    /*printf ("in: %d %d ", val, L);*/
58    if (L>14)
59       val >>= L-14;
60    else if (L<14)
61       val <<= 14-L;
62    L <<= frac;
63    /*printf ("%d\n", val);*/
64    for (i=0;i<frac;i++)
65 {
66       val = (val*val) >> 15;
67       /*printf ("%d\n", val);*/
68       if (val > 16384)
69          L |= (1<<(frac-i-1));
70       else   
71          val <<= 1;
72 }
73    return L;
74 }
75 #endif
76
77 int log2_frac64(ec_uint64 val, int frac)
78 {
79    int i;
80    /* EC_ILOG64() actually returns log2()+1, go figure */
81    int L = EC_ILOG64(val)-1;
82    /*printf ("in: %d %d ", val, L);*/
83    if (L>14)
84       val >>= L-14;
85    else if (L<14)
86       val <<= 14-L;
87    L <<= frac;
88    /*printf ("%d\n", val);*/
89    for (i=0;i<frac;i++)
90    {
91       val = (val*val) >> 15;
92       /*printf ("%d\n", val);*/
93       if (val > 16384)
94          L |= (1<<(frac-i-1));
95       else   
96          val <<= 1;
97    }
98    return L;
99 }
100
101 int fits_in32(int _n, int _m)
102 {
103    static const celt_int16_t maxN[15] = {
104       255, 255, 255, 255, 255, 109,  60,  40,
105        29,  24,  20,  18,  16,  14,  13};
106    static const celt_int16_t maxM[15] = {
107       255, 255, 255, 255, 255, 238,  95,  53,
108        36,  27,  22,  18,  16,  15,  13};
109    if (_n>=14)
110    {
111       if (_m>=14)
112          return 0;
113       else
114          return _n <= maxN[_m];
115    } else {
116       return _m <= maxM[_n];
117    }   
118 }
119 int fits_in64(int _n, int _m)
120 {
121    static const celt_int16_t maxN[28] = {
122       255, 255, 255, 255, 255, 255, 255, 255,
123       255, 255, 178, 129, 100,  81,  68,  58,
124        51,  46,  42,  38,  36,  33,  31,  30,
125        28, 27, 26, 25};
126    static const celt_int16_t maxM[28] = {
127       255, 255, 255, 255, 255, 255, 255, 255, 
128       255, 255, 245, 166, 122,  94,  77,  64, 
129        56,  49,  44,  40,  37,  34,  32,  30,
130        29,  27,  26,  25};
131    if (_n>=27)
132    {
133       if (_m>=27)
134          return 0;
135       else
136          return _n <= maxN[_m];
137    } else {
138       return _m <= maxM[_n];
139    }
140 }
141
142 #define MASK32 (0xFFFFFFFF)
143
144 /*INV_TABLE[i] holds the multiplicative inverse of (2*i-1) mod 2**32.*/
145 static const unsigned INV_TABLE[128]={
146   0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7,
147   0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF,
148   0xF0F0F0F1,0x286BCA1B,0x3CF3CF3D,0xE9BD37A7,
149   0xC28F5C29,0x684BDA13,0x4F72C235,0xBDEF7BDF,
150   0x3E0F83E1,0x8AF8AF8B,0x914C1BAD,0x96F96F97,
151   0xC18F9C19,0x2FA0BE83,0xA4FA4FA5,0x677D46CF,
152   0x1A1F58D1,0xFAFAFAFB,0x8C13521D,0x586FB587,
153   0xB823EE09,0xA08AD8F3,0xC10C9715,0xBEFBEFBF,
154   0xC0FC0FC1,0x07A44C6B,0xA33F128D,0xE327A977,
155   0xC7E3F1F9,0x962FC963,0x3F2B3885,0x613716AF,
156   0x781948B1,0x2B2E43DB,0xFCFCFCFD,0x6FD0EB67,
157   0xFA3F47E9,0xD2FD2FD3,0x3F4FD3F5,0xD4E25B9F,
158   0x5F02A3A1,0xBF5A814B,0x7C32B16D,0xD3431B57,
159   0xD8FD8FD9,0x8D28AC43,0xDA6C0965,0xDB195E8F,
160   0x0FDBC091,0x61F2A4BB,0xDCFDCFDD,0x46FDD947,
161   0x56BE69C9,0xEB2FDEB3,0x26E978D5,0xEFDFBF7F,
162   0x0FE03F81,0xC9484E2B,0xE133F84D,0xE1A8C537,
163   0x077975B9,0x70586723,0xCD29C245,0xFAA11E6F,
164   0x0FE3C071,0x08B51D9B,0x8CE2CABD,0xBF937F27,
165   0xA8FE53A9,0x592FE593,0x2C0685B5,0x2EB11B5F,
166   0xFCD1E361,0x451AB30B,0x72CFE72D,0xDB35A717,
167   0xFB74A399,0xE80BFA03,0x0D516325,0x1BCB564F,
168   0xE02E4851,0xD962AE7B,0x10F8ED9D,0x95AEDD07,
169   0xE9DC0589,0xA18A4473,0xEA53FA95,0xEE936F3F,
170   0x90948F41,0xEAFEAFEB,0x3D137E0D,0xEF46C0F7,
171   0x028C1979,0x791064E3,0xC04FEC05,0xE115062F,
172   0x32385831,0x6E68575B,0xA10D387D,0x6FECF2E7,
173   0x3FB47F69,0xED4BFB53,0x74FED775,0xDB43BB1F,
174   0x87654321,0x9BA144CB,0x478BBCED,0xBFB912D7,
175   0x1FDCD759,0x14B2A7C3,0xCB125CE5,0x437B2E0F,
176   0x10FEF011,0xD2B3183B,0x386CAB5D,0xEF6AC0C7,
177   0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF
178 };
179
180 /*Computes (_a*_b-_c)/(2*_d-1) when the quotient is known to be exact.
181   _a, _b, _c, and _d may be arbitrary so long as the arbitrary precision result
182    fits in 32 bits, but currently the table for multiplicative inverses is only
183    valid for _d<128.*/
184 static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b,
185  celt_uint32_t _c,celt_uint32_t _d){
186   return (_a*_b-_c)*INV_TABLE[_d]&MASK32;
187 }
188
189 /*Computes (_a*_b-_c)/_d when the quotient is known to be exact.
190   _d does not actually have to be even, but imusdiv32odd will be faster when
191    it's odd, so you should use that instead.
192   _a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the
193    table for multiplicative inverses is only valid for _d<256).
194   _b and _c may be arbitrary so long as the arbitrary precision reuslt fits in
195    32 bits.*/
196 static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b,
197  celt_uint32_t _c,celt_uint32_t _d){
198   unsigned inv;
199   int      mask;
200   int      shift;
201   int      one;
202   shift=EC_ILOG(_d^_d-1);
203   inv=INV_TABLE[_d-1>>shift];
204   shift--;
205   one=1<<shift;
206   mask=one-1;
207   return (_a*(_b>>shift)-(_c>>shift)+
208    (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32;
209 }
210
211 /*Computes the next row/column of any recurrence that obeys the relation
212    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
213   _ui0 is the base case for the new row/column.*/
214 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
215   celt_uint32_t ui1;
216   int           j;
217   /* doing a do-while would overrun the array if we had less than 2 samples */
218   j=1; do {
219     ui1=_ui[j]+_ui[j-1]+_ui0;
220     _ui[j-1]=_ui0;
221     _ui0=ui1;
222   } while (++j<_len);
223   _ui[j-1]=_ui0;
224 }
225
226 static inline void unext64(celt_uint64_t *_ui,int _len,celt_uint64_t _ui0){
227   celt_uint64_t ui1;
228   int           j;
229   /* doing a do-while would overrun the array if we had less than 2 samples */
230   j=1; do {
231     ui1=_ui[j]+_ui[j-1]+_ui0;
232     _ui[j-1]=_ui0;
233     _ui0=ui1;
234   } while (++j<_len);
235   _ui[j-1]=_ui0;
236 }
237
238 /*Computes the previous row/column of any recurrence that obeys the relation
239    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
240   _ui0 is the base case for the new row/column.*/
241 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
242   celt_uint32_t ui1;
243   int           j;
244   /* doing a do-while would overrun the array if we had less than 2 samples */
245   j=1; do {
246     ui1=_ui[j]-_ui[j-1]-_ui0;
247     _ui[j-1]=_ui0;
248     _ui0=ui1;
249   } while (++j<_n);
250   _ui[j-1]=_ui0;
251 }
252
253 static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){
254   celt_uint64_t ui1;
255   int           j;
256   /* doing a do-while would overrun the array if we had less than 2 samples */
257   j=1; do {
258     ui1=_ui[j]-_ui[j-1]-_ui0;
259     _ui[j-1]=_ui0;
260     _ui0=ui1;
261   } while (++j<_n);
262   _ui[j-1]=_ui0;
263 }
264
265 /*Returns the number of ways of choosing _m elements from a set of size _n with
266    replacement when a sign bit is needed for each unique element.
267   _u: On exit, _u[i] contains U(_n,i) for i in [0..._m+1].*/
268 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
269   celt_uint32_t um2;
270   int           k;
271   int           len;
272   len=_m+2;
273   _u[0]=0;
274   _u[1]=um2=1;
275   if(_n<=6){
276     /*If _n==0, _u[0] should be 1 and the rest should be 0.*/
277     /*If _n==1, _u[i] should be 1 for i>1.*/
278     celt_assert(_n>=2);
279     /*If _m==0, the following do-while loop will overflow the buffer.*/
280     celt_assert(_m>0);
281     k=2;
282     do _u[k]=(k<<1)-1;
283     while(++k<len);
284     for(k=2;k<_n;k++)unext32(_u+2,_m,(k<<1)+1);
285   }
286   else{
287     celt_uint32_t um1;
288     celt_uint32_t n2m1;
289     _u[2]=n2m1=um1=(_n<<1)-1;
290     for(k=3;k<len;k++){
291       /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/
292       _u[k]=um2=imusdiv32even(n2m1,um1,um2,k-1)+um2;
293       if(++k>=len)break;
294       _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k-1>>1)+um1;
295     }
296   }
297   return _u[_m]+_u[_m+1];
298 }
299
300 celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){
301   int k;
302   int len;
303   len=_m+2;
304   _u[0]=0;
305   /*If _n==0, _u[0] should be 1 and the rest should be 0.*/
306   /*If _n==1, _u[i] should be 1 for i>1.*/
307   celt_assert(_n>=2);
308   k=1;
309   do _u[k]=(k<<1)-1;
310   while(++k<len);
311   for(k=2;k<_n;k++)unext64(_u+2,_m,(k<<1)+1);
312   /*TODO: For large _n, an imusdiv64 could make this O(_m) instead of
313      O(_n*_m), but would require an INV_TABLE twice as large, as well as lots
314      of 64x64->64 bit multiplies.*/
315   return _u[_m]+_u[_m+1];
316 }
317
318
319 /*Returns the _i'th combination of _m elements chosen from a set of size _n
320    with associated sign bits.
321   _y: Returns the vector of pulses.
322   _u: Must contain entries [0..._m+1] of row _n of U() on input.
323       Its contents will be destructively modified.*/
324 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u){
325   int j;
326   int k;
327   celt_assert(_n>0);
328   j=0;
329   k=_m;
330   do{
331     celt_uint32_t p;
332     int           s;
333     int           yj;
334     p=_u[k+1];
335     s=_i>=p;
336     if(s)_i-=p;
337     yj=k;
338     p=_u[k];
339     while(p>_i)p=_u[--k];
340     _i-=p;
341     yj-=k;
342     _y[j]=yj-(yj<<1&-s);
343     uprev32(_u,k+2,0);
344   }
345   while(++j<_n);
346 }
347
348 /*Returns the _i'th combination of _m elements chosen from a set of size _n
349    with associated sign bits.
350   _y: Returns the vector of pulses.
351   _u: Must contain entries [0..._m+1] of row _n of U() on input.
352       Its contents will be destructively modified.*/
353 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_y,celt_uint64_t *_u){
354   int j;
355   int k;
356   celt_assert(_n>0);
357   j=0;
358   k=_m;
359   do{
360     celt_uint64_t p;
361     int           s;
362     int           yj;
363     p=_u[k+1];
364     s=_i>=p;
365     if(s)_i-=p;
366     yj=k;
367     p=_u[k];
368     while(p>_i)p=_u[--k];
369     _i-=p;
370     yj-=k;
371     _y[j]=yj-(yj<<1&-s);
372     uprev64(_u,k+2,0);
373   }
374   while(++j<_n);
375 }
376
377 /*Returns the index of the given combination of _m elements chosen from a set
378    of size _n with associated sign bits.
379   _y:  The vector of pulses, whose sum of absolute values must be _m.
380   _nc: Returns V(_n,_m).*/
381 celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
382  celt_uint32_t *_u){
383   celt_uint32_t i;
384   int           j;
385   int           k;
386   /*We can't unroll the first two iterations of the loop unless _n>=2.*/
387   celt_assert(_n>=2);
388   i=_y[_n-1]<0;
389   _u[0]=0;
390   for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
391   k=abs(_y[_n-1]);
392   j=_n-2;
393   i+=_u[k];
394   k+=abs(_y[j]);
395   if(_y[j]<0)i+=_u[k+1];
396   while(j-->0){
397     unext32(_u,_m+2,0);
398     i+=_u[k];
399     k+=abs(_y[j]);
400     if(_y[j]<0)i+=_u[k+1];
401   }
402   *_nc=_u[_m]+_u[_m+1];
403   return i;
404 }
405
406 /*Returns the index of the given combination of _m elements chosen from a set
407    of size _n with associated sign bits.
408   _y:  The vector of pulses, whose sum of absolute values must be _m.
409   _nc: Returns V(_n,_m).*/
410 celt_uint64_t icwrs64(int _n,int _m,celt_uint64_t *_nc,const int *_y,
411  celt_uint64_t *_u){
412   celt_uint64_t i;
413   int           j;
414   int           k;
415   /*We can't unroll the first two iterations of the loop unless _n>=2.*/
416   celt_assert(_n>=2);
417   i=_y[_n-1]<0;
418   _u[0]=0;
419   for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
420   k=abs(_y[_n-1]);
421   j=_n-2;
422   i+=_u[k];
423   k+=abs(_y[j]);
424   if(_y[j]<0)i+=_u[k+1];
425   while(j-->0){
426     unext64(_u,_m+2,0);
427     i+=_u[k];
428     k+=abs(_y[j]);
429     if(_y[j]<0)i+=_u[k+1];
430   }
431   *_nc=_u[_m]+_u[_m+1];
432   return i;
433 }
434
435 static inline void encode_pulse32(int _n,int _m,const int *_y,ec_enc *_enc){
436   VARDECL(celt_uint32_t,u);
437   celt_uint32_t nc;
438   celt_uint32_t i;
439   SAVE_STACK;
440   ALLOC(u,_m+2,celt_uint32_t);
441   i=icwrs32(_n,_m,&nc,_y,u);
442   ec_enc_uint(_enc,i,nc);
443   RESTORE_STACK;
444 }
445
446 static inline void encode_pulse64(int _n,int _m,const int *_y,ec_enc *_enc){
447   VARDECL(celt_uint64_t,u);
448   celt_uint64_t nc;
449   celt_uint64_t i;
450   SAVE_STACK;
451   ALLOC(u,_m+2,celt_uint64_t);
452   i=icwrs64(_n,_m,&nc,_y,u);
453   ec_enc_uint64(_enc,i,nc);
454   RESTORE_STACK;
455 }
456
457 int get_required_bits(int N, int K, int frac)
458 {
459    int nbits = 0;
460    if(fits_in64(N,K))
461    {
462       VARDECL(celt_uint64_t,u);
463       SAVE_STACK;
464       ALLOC(u,K+2,celt_uint64_t);
465       nbits = log2_frac64(ncwrs_u64(N,K,u), frac);
466       RESTORE_STACK;
467    } else {
468       nbits = log2_frac64(N, frac);
469       nbits += get_required_bits(N/2+1, (K+1)/2, frac);
470       nbits += get_required_bits(N/2+1, K/2, frac);
471    }
472    return nbits;
473 }
474
475
476 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
477 {
478    if (K==0) {
479    } else if (N==1)
480    {
481       ec_enc_bits(enc, _y[0]<0, 1);
482    } else if(fits_in32(N,K))
483    {
484       encode_pulse32(N, K, _y, enc);
485    } else if(fits_in64(N,K)) {
486       encode_pulse64(N, K, _y, enc);
487    } else {
488      int i;
489      int count=0;
490      int split;
491      split = (N+1)/2;
492      for (i=0;i<split;i++)
493         count += abs(_y[i]);
494      ec_enc_uint(enc,count,K+1);
495      encode_pulses(_y, split, count, enc);
496      encode_pulses(_y+split, N-split, K-count, enc);
497    }
498 }
499
500 static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){
501   VARDECL(celt_uint32_t,u);
502   SAVE_STACK;
503   ALLOC(u,_m+2,celt_uint32_t);
504   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_y,u);
505   RESTORE_STACK;
506 }
507
508 static inline void decode_pulse64(int _n,int _m,int *_y,ec_dec *_dec){
509   VARDECL(celt_uint64_t,u);
510   SAVE_STACK;
511   ALLOC(u,_m+2,celt_uint64_t);
512   cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_y,u);
513   RESTORE_STACK;
514 }
515
516 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
517 {
518    if (K==0) {
519       int i;
520       for (i=0;i<N;i++)
521          _y[i] = 0;
522    } else if (N==1)
523    {
524       int s = ec_dec_bits(dec, 1);
525       if (s==0)
526          _y[0] = K;
527       else
528          _y[0] = -K;
529    } else if(fits_in32(N,K))
530    {
531       decode_pulse32(N, K, _y, dec);
532    } else if(fits_in64(N,K)) {
533       decode_pulse64(N, K, _y, dec);
534    } else {
535      int split;
536      int count = ec_dec_uint(dec,K+1);
537      split = (N+1)/2;
538      decode_pulses(_y, split, count, dec);
539      decode_pulses(_y+split, N-split, K-count, dec);
540    }
541 }