Making the "data" argument to celt_decode() const as pointed out by Bjoern
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50 #include "arch.h"
51
52 /*Guaranteed to return a conservatively large estimate of the binary logarithm
53    with frac bits of fractional precision.
54   Tested for all possible 32-bit inputs with frac=4, where the maximum
55    overestimation is 0.06254243 bits.*/
56 int log2_frac(ec_uint32 val, int frac)
57 {
58   int l;
59   l=EC_ILOG(val);
60   if(val&val-1){
61     /*This is (val>>l-16), but guaranteed to round up, even if adding a bias
62        before the shift would cause overflow (e.g., for 0xFFFFxxxx).*/
63     if(l>16)val=(val>>l-16)+((val&(1<<l-16)-1)+(1<<l-16)-1>>l-16);
64     else val<<=16-l;
65     l=l-1<<frac;
66     /*Note that we always need one iteration, since the rounding up above means
67        that we might need to adjust the integer part of the logarithm.*/
68     do{
69       int b;
70       b=(int)(val>>16);
71       l+=b<<frac;
72       val=val+b>>b;
73       val=val*val+0x7FFF>>15;
74     }
75     while(frac-->0);
76     /*If val is not exactly 0x8000, then we have to round up the remainder.*/
77     return l+(val>0x8000);
78   }
79   /*Exact powers of two require no rounding.*/
80   else return l-1<<frac;
81 }
82
83 int fits_in32(int _n, int _m)
84 {
85    static const celt_int16_t maxN[15] = {
86       255, 255, 255, 255, 255, 109,  60,  40,
87        29,  24,  20,  18,  16,  14,  13};
88    static const celt_int16_t maxM[15] = {
89       255, 255, 255, 255, 255, 238,  95,  53,
90        36,  27,  22,  18,  16,  15,  13};
91    if (_n>=14)
92    {
93       if (_m>=14)
94          return 0;
95       else
96          return _n <= maxN[_m];
97    } else {
98       return _m <= maxM[_n];
99    }   
100 }
101
102 #define MASK32 (0xFFFFFFFF)
103
104 /*INV_TABLE[i] holds the multiplicative inverse of (2*i-1) mod 2**32.*/
105 static const celt_uint32_t INV_TABLE[128]={
106   0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7,
107   0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF,
108   0xF0F0F0F1,0x286BCA1B,0x3CF3CF3D,0xE9BD37A7,
109   0xC28F5C29,0x684BDA13,0x4F72C235,0xBDEF7BDF,
110   0x3E0F83E1,0x8AF8AF8B,0x914C1BAD,0x96F96F97,
111   0xC18F9C19,0x2FA0BE83,0xA4FA4FA5,0x677D46CF,
112   0x1A1F58D1,0xFAFAFAFB,0x8C13521D,0x586FB587,
113   0xB823EE09,0xA08AD8F3,0xC10C9715,0xBEFBEFBF,
114   0xC0FC0FC1,0x07A44C6B,0xA33F128D,0xE327A977,
115   0xC7E3F1F9,0x962FC963,0x3F2B3885,0x613716AF,
116   0x781948B1,0x2B2E43DB,0xFCFCFCFD,0x6FD0EB67,
117   0xFA3F47E9,0xD2FD2FD3,0x3F4FD3F5,0xD4E25B9F,
118   0x5F02A3A1,0xBF5A814B,0x7C32B16D,0xD3431B57,
119   0xD8FD8FD9,0x8D28AC43,0xDA6C0965,0xDB195E8F,
120   0x0FDBC091,0x61F2A4BB,0xDCFDCFDD,0x46FDD947,
121   0x56BE69C9,0xEB2FDEB3,0x26E978D5,0xEFDFBF7F,
122   0x0FE03F81,0xC9484E2B,0xE133F84D,0xE1A8C537,
123   0x077975B9,0x70586723,0xCD29C245,0xFAA11E6F,
124   0x0FE3C071,0x08B51D9B,0x8CE2CABD,0xBF937F27,
125   0xA8FE53A9,0x592FE593,0x2C0685B5,0x2EB11B5F,
126   0xFCD1E361,0x451AB30B,0x72CFE72D,0xDB35A717,
127   0xFB74A399,0xE80BFA03,0x0D516325,0x1BCB564F,
128   0xE02E4851,0xD962AE7B,0x10F8ED9D,0x95AEDD07,
129   0xE9DC0589,0xA18A4473,0xEA53FA95,0xEE936F3F,
130   0x90948F41,0xEAFEAFEB,0x3D137E0D,0xEF46C0F7,
131   0x028C1979,0x791064E3,0xC04FEC05,0xE115062F,
132   0x32385831,0x6E68575B,0xA10D387D,0x6FECF2E7,
133   0x3FB47F69,0xED4BFB53,0x74FED775,0xDB43BB1F,
134   0x87654321,0x9BA144CB,0x478BBCED,0xBFB912D7,
135   0x1FDCD759,0x14B2A7C3,0xCB125CE5,0x437B2E0F,
136   0x10FEF011,0xD2B3183B,0x386CAB5D,0xEF6AC0C7,
137   0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF
138 };
139
140 /*Computes (_a*_b-_c)/(2*_d-1) when the quotient is known to be exact.
141   _a, _b, _c, and _d may be arbitrary so long as the arbitrary precision result
142    fits in 32 bits, but currently the table for multiplicative inverses is only
143    valid for _d<128.*/
144 static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b,
145  celt_uint32_t _c,celt_uint32_t _d){
146   return (_a*_b-_c)*INV_TABLE[_d]&MASK32;
147 }
148
149 /*Computes (_a*_b-_c)/_d when the quotient is known to be exact.
150   _d does not actually have to be even, but imusdiv32odd will be faster when
151    it's odd, so you should use that instead.
152   _a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the
153    table for multiplicative inverses is only valid for _d<256).
154   _b and _c may be arbitrary so long as the arbitrary precision reuslt fits in
155    32 bits.*/
156 static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b,
157  celt_uint32_t _c,celt_uint32_t _d){
158   celt_uint32_t inv;
159   int      mask;
160   int      shift;
161   int      one;
162   shift=EC_ILOG(_d^_d-1);
163   inv=INV_TABLE[_d-1>>shift];
164   shift--;
165   one=1<<shift;
166   mask=one-1;
167   return (_a*(_b>>shift)-(_c>>shift)+
168    (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32;
169 }
170
171 /*Computes the next row/column of any recurrence that obeys the relation
172    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
173   _ui0 is the base case for the new row/column.*/
174 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
175   celt_uint32_t ui1;
176   int           j;
177   /* doing a do-while would overrun the array if we had less than 2 samples */
178   j=1; do {
179     ui1=UADD32(UADD32(_ui[j],_ui[j-1]),_ui0);
180     _ui[j-1]=_ui0;
181     _ui0=ui1;
182   } while (++j<_len);
183   _ui[j-1]=_ui0;
184 }
185
186 /*Computes the previous row/column of any recurrence that obeys the relation
187    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
188   _ui0 is the base case for the new row/column.*/
189 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
190   celt_uint32_t ui1;
191   int           j;
192   /* doing a do-while would overrun the array if we had less than 2 samples */
193   j=1; do {
194     ui1=USUB32(USUB32(_ui[j],_ui[j-1]),_ui0);
195     _ui[j-1]=_ui0;
196     _ui0=ui1;
197   } while (++j<_n);
198   _ui[j-1]=_ui0;
199 }
200
201 /*Returns the number of ways of choosing _m elements from a set of size _n with
202    replacement when a sign bit is needed for each unique element.
203   _u: On exit, _u[i] contains U(_n,i) for i in [0..._m+1].*/
204 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
205   celt_uint32_t um2;
206   int           k;
207   int           len;
208   len=_m+2;
209   _u[0]=0;
210   _u[1]=um2=1;
211   if(_n<=6){
212     /*If _n==0, _u[0] should be 1 and the rest should be 0.*/
213     /*If _n==1, _u[i] should be 1 for i>1.*/
214     celt_assert(_n>=2);
215     /*If _m==0, the following do-while loop will overflow the buffer.*/
216     celt_assert(_m>0);
217     k=2;
218     do _u[k]=(k<<1)-1;
219     while(++k<len);
220     for(k=2;k<_n;k++)unext32(_u+2,_m,(k<<1)+1);
221   }
222   else{
223     celt_uint32_t um1;
224     celt_uint32_t n2m1;
225     _u[2]=n2m1=um1=(_n<<1)-1;
226     for(k=3;k<len;k++){
227       /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/
228       _u[k]=um2=imusdiv32even(n2m1,um1,um2,k-1)+um2;
229       if(++k>=len)break;
230       _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k-1>>1)+um1;
231     }
232   }
233   return _u[_m]+_u[_m+1];
234 }
235
236
237
238 /*Returns the _i'th combination of _m elements chosen from a set of size _n
239    with associated sign bits.
240   _y: Returns the vector of pulses.
241   _u: Must contain entries [0..._m+1] of row _n of U() on input.
242       Its contents will be destructively modified.*/
243 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u){
244   int j;
245   int k;
246   celt_assert(_n>0);
247   j=0;
248   k=_m;
249   do{
250     celt_uint32_t p;
251     int           s;
252     int           yj;
253     p=_u[k+1];
254     s=_i>=p;
255     if(s)_i-=p;
256     yj=k;
257     p=_u[k];
258     while(p>_i)p=_u[--k];
259     _i-=p;
260     yj-=k;
261     _y[j]=yj-(yj<<1&-s);
262     uprev32(_u,k+2,0);
263   }
264   while(++j<_n);
265 }
266
267
268 /*Returns the index of the given combination of _m elements chosen from a set
269    of size _n with associated sign bits.
270   _y:  The vector of pulses, whose sum of absolute values must be _m.
271   _nc: Returns V(_n,_m).*/
272 celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
273  celt_uint32_t *_u){
274   celt_uint32_t i;
275   int           j;
276   int           k;
277   /*We can't unroll the first two iterations of the loop unless _n>=2.*/
278   celt_assert(_n>=2);
279   i=_y[_n-1]<0;
280   _u[0]=0;
281   for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
282   k=abs(_y[_n-1]);
283   j=_n-2;
284   i+=_u[k];
285   k+=abs(_y[j]);
286   if(_y[j]<0)i+=_u[k+1];
287   while(j-->0){
288     unext32(_u,_m+2,0);
289     i+=_u[k];
290     k+=abs(_y[j]);
291     if(_y[j]<0)i+=_u[k+1];
292   }
293   *_nc=_u[_m]+_u[_m+1];
294   return i;
295 }
296
297 static inline void encode_pulse32(int _n,int _m,const int *_y,ec_enc *_enc){
298   VARDECL(celt_uint32_t,u);
299   celt_uint32_t nc;
300   celt_uint32_t i;
301   SAVE_STACK;
302   ALLOC(u,_m+2,celt_uint32_t);
303   i=icwrs32(_n,_m,&nc,_y,u);
304   ec_enc_uint(_enc,i,nc);
305   RESTORE_STACK;
306 }
307
308 int get_required_bits32(int N, int K, int frac)
309 {
310    int nbits;
311    VARDECL(celt_uint32_t,u);
312    SAVE_STACK;
313    ALLOC(u,K+2,celt_uint32_t);
314    nbits = log2_frac(ncwrs_u32(N,K,u), frac);
315    RESTORE_STACK;
316    return nbits;
317 }
318
319 void get_required_bits(celt_int16_t *bits,int N, int MAXK, int frac)
320 {
321    int k;
322    /*We special case k==0 below, since fits_in32 could reject it for large N.*/
323    celt_assert(MAXK>0);
324    if(fits_in32(N,MAXK-1)){
325       bits[0]=0;
326       /*This could be sped up one heck of a lot if we didn't recompute u in
327          ncwrs_u32 every time.*/
328       for(k=1;k<MAXK;k++)bits[k]=get_required_bits32(N,k,frac);
329    }
330    else{
331       VARDECL(celt_int16_t,n1bits);
332       VARDECL(celt_int16_t,_n2bits);
333       celt_int16_t *n2bits;
334       SAVE_STACK;
335       ALLOC(n1bits,MAXK,celt_int16_t);
336       ALLOC(_n2bits,MAXK,celt_int16_t);
337       get_required_bits(n1bits,(N+1)/2,MAXK,frac);
338       if(N&1){
339         n2bits=_n2bits;
340         get_required_bits(n2bits,N/2,MAXK,frac);
341       }else{
342         n2bits=n1bits;
343       }
344       bits[0]=0;
345       for(k=1;k<MAXK;k++){
346          if(fits_in32(N,k))bits[k]=get_required_bits32(N,k,frac);
347          else{
348             int worst_bits;
349             int i;
350             worst_bits=0;
351             for(i=0;i<=k;i++){
352                int split_bits;
353                split_bits=n1bits[i]+n2bits[k-i];
354                if(split_bits>worst_bits)worst_bits=split_bits;
355             }
356             bits[k]=log2_frac(k+1,frac)+worst_bits;
357          }
358       }
359       RESTORE_STACK;
360    }
361 }
362
363
364 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
365 {
366    if (K==0) {
367    } else if (N==1)
368    {
369       ec_enc_bits(enc, _y[0]<0, 1);
370    } else if(fits_in32(N,K))
371    {
372       encode_pulse32(N, K, _y, enc);
373    } else {
374      int i;
375      int count=0;
376      int split;
377      split = (N+1)/2;
378      for (i=0;i<split;i++)
379         count += abs(_y[i]);
380      ec_enc_uint(enc,count,K+1);
381      encode_pulses(_y, split, count, enc);
382      encode_pulses(_y+split, N-split, K-count, enc);
383    }
384 }
385
386 static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){
387   VARDECL(celt_uint32_t,u);
388   SAVE_STACK;
389   ALLOC(u,_m+2,celt_uint32_t);
390   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_y,u);
391   RESTORE_STACK;
392 }
393
394 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
395 {
396    if (K==0) {
397       int i;
398       for (i=0;i<N;i++)
399          _y[i] = 0;
400    } else if (N==1)
401    {
402       int s = ec_dec_bits(dec, 1);
403       if (s==0)
404          _y[0] = K;
405       else
406          _y[0] = -K;
407    } else if(fits_in32(N,K))
408    {
409       decode_pulse32(N, K, _y, dec);
410    } else {
411      int split;
412      int count = ec_dec_uint(dec,K+1);
413      split = (N+1)/2;
414      decode_pulses(_y, split, count, dec);
415      decode_pulses(_y+split, N-split, K-count, dec);
416    }
417 }