a39c1a1dc6a22ad07bb581489fe6d297fc5c0f0e
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50 #include "arch.h"
51
52 int log2_frac(ec_uint32 val, int frac)
53 {
54    int i;
55    /* EC_ILOG() actually returns log2()+1, go figure */
56    int L = EC_ILOG(val)-1;
57    int pow2;
58    pow2=!(val&val-1);
59    /*printf ("in: %d %d ", val, L);*/
60    if (L>14)
61       val >>= L-14;
62    else if (L<14)
63       val <<= 14-L;
64    L <<= frac;
65    /*printf ("%d\n", val);*/
66    for (i=0;i<frac;i++)
67 {
68       val = (val*val) >> 15;
69       /*printf ("%d\n", val);*/
70       if (val > 16384)
71          L |= (1<<(frac-i-1));
72       else
73          val <<= 1;
74 }
75    /*The previous loop returns a conservatively low estimate.
76      If val wasn't a power of two, we should round up.*/
77    L += !pow2;
78    return L;
79 }
80
81 int fits_in32(int _n, int _m)
82 {
83    static const celt_int16_t maxN[15] = {
84       255, 255, 255, 255, 255, 109,  60,  40,
85        29,  24,  20,  18,  16,  14,  13};
86    static const celt_int16_t maxM[15] = {
87       255, 255, 255, 255, 255, 238,  95,  53,
88        36,  27,  22,  18,  16,  15,  13};
89    if (_n>=14)
90    {
91       if (_m>=14)
92          return 0;
93       else
94          return _n <= maxN[_m];
95    } else {
96       return _m <= maxM[_n];
97    }   
98 }
99
100 #define MASK32 (0xFFFFFFFF)
101
102 /*INV_TABLE[i] holds the multiplicative inverse of (2*i-1) mod 2**32.*/
103 static const celt_uint32_t INV_TABLE[128]={
104   0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7,
105   0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF,
106   0xF0F0F0F1,0x286BCA1B,0x3CF3CF3D,0xE9BD37A7,
107   0xC28F5C29,0x684BDA13,0x4F72C235,0xBDEF7BDF,
108   0x3E0F83E1,0x8AF8AF8B,0x914C1BAD,0x96F96F97,
109   0xC18F9C19,0x2FA0BE83,0xA4FA4FA5,0x677D46CF,
110   0x1A1F58D1,0xFAFAFAFB,0x8C13521D,0x586FB587,
111   0xB823EE09,0xA08AD8F3,0xC10C9715,0xBEFBEFBF,
112   0xC0FC0FC1,0x07A44C6B,0xA33F128D,0xE327A977,
113   0xC7E3F1F9,0x962FC963,0x3F2B3885,0x613716AF,
114   0x781948B1,0x2B2E43DB,0xFCFCFCFD,0x6FD0EB67,
115   0xFA3F47E9,0xD2FD2FD3,0x3F4FD3F5,0xD4E25B9F,
116   0x5F02A3A1,0xBF5A814B,0x7C32B16D,0xD3431B57,
117   0xD8FD8FD9,0x8D28AC43,0xDA6C0965,0xDB195E8F,
118   0x0FDBC091,0x61F2A4BB,0xDCFDCFDD,0x46FDD947,
119   0x56BE69C9,0xEB2FDEB3,0x26E978D5,0xEFDFBF7F,
120   0x0FE03F81,0xC9484E2B,0xE133F84D,0xE1A8C537,
121   0x077975B9,0x70586723,0xCD29C245,0xFAA11E6F,
122   0x0FE3C071,0x08B51D9B,0x8CE2CABD,0xBF937F27,
123   0xA8FE53A9,0x592FE593,0x2C0685B5,0x2EB11B5F,
124   0xFCD1E361,0x451AB30B,0x72CFE72D,0xDB35A717,
125   0xFB74A399,0xE80BFA03,0x0D516325,0x1BCB564F,
126   0xE02E4851,0xD962AE7B,0x10F8ED9D,0x95AEDD07,
127   0xE9DC0589,0xA18A4473,0xEA53FA95,0xEE936F3F,
128   0x90948F41,0xEAFEAFEB,0x3D137E0D,0xEF46C0F7,
129   0x028C1979,0x791064E3,0xC04FEC05,0xE115062F,
130   0x32385831,0x6E68575B,0xA10D387D,0x6FECF2E7,
131   0x3FB47F69,0xED4BFB53,0x74FED775,0xDB43BB1F,
132   0x87654321,0x9BA144CB,0x478BBCED,0xBFB912D7,
133   0x1FDCD759,0x14B2A7C3,0xCB125CE5,0x437B2E0F,
134   0x10FEF011,0xD2B3183B,0x386CAB5D,0xEF6AC0C7,
135   0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF
136 };
137
138 /*Computes (_a*_b-_c)/(2*_d-1) when the quotient is known to be exact.
139   _a, _b, _c, and _d may be arbitrary so long as the arbitrary precision result
140    fits in 32 bits, but currently the table for multiplicative inverses is only
141    valid for _d<128.*/
142 static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b,
143  celt_uint32_t _c,celt_uint32_t _d){
144   return (_a*_b-_c)*INV_TABLE[_d]&MASK32;
145 }
146
147 /*Computes (_a*_b-_c)/_d when the quotient is known to be exact.
148   _d does not actually have to be even, but imusdiv32odd will be faster when
149    it's odd, so you should use that instead.
150   _a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the
151    table for multiplicative inverses is only valid for _d<256).
152   _b and _c may be arbitrary so long as the arbitrary precision reuslt fits in
153    32 bits.*/
154 static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b,
155  celt_uint32_t _c,celt_uint32_t _d){
156   celt_uint32_t inv;
157   int      mask;
158   int      shift;
159   int      one;
160   shift=EC_ILOG(_d^_d-1);
161   inv=INV_TABLE[_d-1>>shift];
162   shift--;
163   one=1<<shift;
164   mask=one-1;
165   return (_a*(_b>>shift)-(_c>>shift)+
166    (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32;
167 }
168
169 /*Computes the next row/column of any recurrence that obeys the relation
170    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
171   _ui0 is the base case for the new row/column.*/
172 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
173   celt_uint32_t ui1;
174   int           j;
175   /* doing a do-while would overrun the array if we had less than 2 samples */
176   j=1; do {
177     ui1=UADD32(UADD32(_ui[j],_ui[j-1]),_ui0);
178     _ui[j-1]=_ui0;
179     _ui0=ui1;
180   } while (++j<_len);
181   _ui[j-1]=_ui0;
182 }
183
184 /*Computes the previous row/column of any recurrence that obeys the relation
185    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
186   _ui0 is the base case for the new row/column.*/
187 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
188   celt_uint32_t ui1;
189   int           j;
190   /* doing a do-while would overrun the array if we had less than 2 samples */
191   j=1; do {
192     ui1=USUB32(USUB32(_ui[j],_ui[j-1]),_ui0);
193     _ui[j-1]=_ui0;
194     _ui0=ui1;
195   } while (++j<_n);
196   _ui[j-1]=_ui0;
197 }
198
199 /*Returns the number of ways of choosing _m elements from a set of size _n with
200    replacement when a sign bit is needed for each unique element.
201   _u: On exit, _u[i] contains U(_n,i) for i in [0..._m+1].*/
202 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
203   celt_uint32_t um2;
204   int           k;
205   int           len;
206   len=_m+2;
207   _u[0]=0;
208   _u[1]=um2=1;
209   if(_n<=6){
210     /*If _n==0, _u[0] should be 1 and the rest should be 0.*/
211     /*If _n==1, _u[i] should be 1 for i>1.*/
212     celt_assert(_n>=2);
213     /*If _m==0, the following do-while loop will overflow the buffer.*/
214     celt_assert(_m>0);
215     k=2;
216     do _u[k]=(k<<1)-1;
217     while(++k<len);
218     for(k=2;k<_n;k++)unext32(_u+2,_m,(k<<1)+1);
219   }
220   else{
221     celt_uint32_t um1;
222     celt_uint32_t n2m1;
223     _u[2]=n2m1=um1=(_n<<1)-1;
224     for(k=3;k<len;k++){
225       /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/
226       _u[k]=um2=imusdiv32even(n2m1,um1,um2,k-1)+um2;
227       if(++k>=len)break;
228       _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k-1>>1)+um1;
229     }
230   }
231   return _u[_m]+_u[_m+1];
232 }
233
234
235
236 /*Returns the _i'th combination of _m elements chosen from a set of size _n
237    with associated sign bits.
238   _y: Returns the vector of pulses.
239   _u: Must contain entries [0..._m+1] of row _n of U() on input.
240       Its contents will be destructively modified.*/
241 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u){
242   int j;
243   int k;
244   celt_assert(_n>0);
245   j=0;
246   k=_m;
247   do{
248     celt_uint32_t p;
249     int           s;
250     int           yj;
251     p=_u[k+1];
252     s=_i>=p;
253     if(s)_i-=p;
254     yj=k;
255     p=_u[k];
256     while(p>_i)p=_u[--k];
257     _i-=p;
258     yj-=k;
259     _y[j]=yj-(yj<<1&-s);
260     uprev32(_u,k+2,0);
261   }
262   while(++j<_n);
263 }
264
265
266 /*Returns the index of the given combination of _m elements chosen from a set
267    of size _n with associated sign bits.
268   _y:  The vector of pulses, whose sum of absolute values must be _m.
269   _nc: Returns V(_n,_m).*/
270 celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
271  celt_uint32_t *_u){
272   celt_uint32_t i;
273   int           j;
274   int           k;
275   /*We can't unroll the first two iterations of the loop unless _n>=2.*/
276   celt_assert(_n>=2);
277   i=_y[_n-1]<0;
278   _u[0]=0;
279   for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
280   k=abs(_y[_n-1]);
281   j=_n-2;
282   i+=_u[k];
283   k+=abs(_y[j]);
284   if(_y[j]<0)i+=_u[k+1];
285   while(j-->0){
286     unext32(_u,_m+2,0);
287     i+=_u[k];
288     k+=abs(_y[j]);
289     if(_y[j]<0)i+=_u[k+1];
290   }
291   *_nc=_u[_m]+_u[_m+1];
292   return i;
293 }
294
295 static inline void encode_pulse32(int _n,int _m,const int *_y,ec_enc *_enc){
296   VARDECL(celt_uint32_t,u);
297   celt_uint32_t nc;
298   celt_uint32_t i;
299   SAVE_STACK;
300   ALLOC(u,_m+2,celt_uint32_t);
301   i=icwrs32(_n,_m,&nc,_y,u);
302   ec_enc_uint(_enc,i,nc);
303   RESTORE_STACK;
304 }
305
306 int get_required_bits32(int N, int K, int frac)
307 {
308    int nbits;
309    VARDECL(celt_uint32_t,u);
310    SAVE_STACK;
311    ALLOC(u,K+2,celt_uint32_t);
312    nbits = log2_frac(ncwrs_u32(N,K,u), frac);
313    RESTORE_STACK;
314    return nbits;
315 }
316
317 void get_required_bits(celt_int16_t *bits,int N, int MAXK, int frac)
318 {
319    int k;
320    /*We special case k==0 below, since fits_in32 could reject it for large N.*/
321    celt_assert(MAXK>0);
322    if(fits_in32(N,MAXK-1)){
323       bits[0]=0;
324       /*This could be sped up one heck of a lot if we didn't recompute u in
325          ncwrs_u32 every time.*/
326       for(k=1;k<MAXK;k++)bits[k]=get_required_bits32(N,k,frac);
327    }
328    else{
329       VARDECL(celt_int16_t,n1bits);
330       VARDECL(celt_int16_t,_n2bits);
331       celt_int16_t *n2bits;
332       SAVE_STACK;
333       ALLOC(n1bits,MAXK,celt_int16_t);
334       ALLOC(_n2bits,MAXK,celt_int16_t);
335       get_required_bits(n1bits,(N+1)/2,MAXK,frac);
336       if(N&1){
337         n2bits=_n2bits;
338         get_required_bits(n2bits,N/2,MAXK,frac);
339       }else{
340         n2bits=n1bits;
341       }
342       bits[0]=0;
343       for(k=1;k<MAXK;k++){
344          if(fits_in32(N,k))bits[k]=get_required_bits32(N,k,frac);
345          else{
346             int worst_bits;
347             int i;
348             worst_bits=0;
349             for(i=0;i<=k;i++){
350                int split_bits;
351                split_bits=n1bits[i]+n2bits[k-i];
352                if(split_bits>worst_bits)worst_bits=split_bits;
353             }
354             bits[k]=log2_frac(k+1,frac)+worst_bits;
355          }
356       }
357    RESTORE_STACK;
358    }
359 }
360
361
362 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
363 {
364    if (K==0) {
365    } else if (N==1)
366    {
367       ec_enc_bits(enc, _y[0]<0, 1);
368    } else if(fits_in32(N,K))
369    {
370       encode_pulse32(N, K, _y, enc);
371    } else {
372      int i;
373      int count=0;
374      int split;
375      split = (N+1)/2;
376      for (i=0;i<split;i++)
377         count += abs(_y[i]);
378      ec_enc_uint(enc,count,K+1);
379      encode_pulses(_y, split, count, enc);
380      encode_pulses(_y+split, N-split, K-count, enc);
381    }
382 }
383
384 static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){
385   VARDECL(celt_uint32_t,u);
386   SAVE_STACK;
387   ALLOC(u,_m+2,celt_uint32_t);
388   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_y,u);
389   RESTORE_STACK;
390 }
391
392 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
393 {
394    if (K==0) {
395       int i;
396       for (i=0;i<N;i++)
397          _y[i] = 0;
398    } else if (N==1)
399    {
400       int s = ec_dec_bits(dec, 1);
401       if (s==0)
402          _y[0] = K;
403       else
404          _y[0] = -K;
405    } else if(fits_in32(N,K))
406    {
407       decode_pulse32(N, K, _y, dec);
408    } else {
409      int split;
410      int count = ec_dec_uint(dec,K+1);
411      split = (N+1)/2;
412      decode_pulses(_y, split, count, dec);
413      decode_pulses(_y+split, N-split, K-count, dec);
414    }
415 }