Fix (unexploitable) buffer overrun when _m=1 during the cwrs table init, as
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50 #include "arch.h"
51
52 /*Guaranteed to return a conservatively large estimate of the binary logarithm
53    with frac bits of fractional precision.
54   Tested for all possible 32-bit inputs with frac=4, where the maximum
55    overestimation is 0.06254243 bits.*/
56 int log2_frac(ec_uint32 val, int frac)
57 {
58   int l;
59   l=EC_ILOG(val);
60   if(val&val-1){
61     /*This is (val>>l-16), but guaranteed to round up, even if adding a bias
62        before the shift would cause overflow (e.g., for 0xFFFFxxxx).*/
63     if(l>16)val=(val>>l-16)+((val&(1<<l-16)-1)+(1<<l-16)-1>>l-16);
64     else val<<=16-l;
65     l=l-1<<frac;
66     /*Note that we always need one iteration, since the rounding up above means
67        that we might need to adjust the integer part of the logarithm.*/
68     do{
69       int b;
70       b=(int)(val>>16);
71       l+=b<<frac;
72       val=val+b>>b;
73       val=val*val+0x7FFF>>15;
74     }
75     while(frac-->0);
76     /*If val is not exactly 0x8000, then we have to round up the remainder.*/
77     return l+(val>0x8000);
78   }
79   /*Exact powers of two require no rounding.*/
80   else return l-1<<frac;
81 }
82
83 int fits_in32(int _n, int _m)
84 {
85    static const celt_int16_t maxN[15] = {
86       255, 255, 255, 255, 255, 109,  60,  40,
87        29,  24,  20,  18,  16,  14,  13};
88    static const celt_int16_t maxM[15] = {
89       255, 255, 255, 255, 255, 238,  95,  53,
90        36,  27,  22,  18,  16,  15,  13};
91    if (_n>=14)
92    {
93       if (_m>=14)
94          return 0;
95       else
96          return _n <= maxN[_m];
97    } else {
98       return _m <= maxM[_n];
99    }   
100 }
101
102 #define MASK32 (0xFFFFFFFF)
103
104 /*INV_TABLE[i] holds the multiplicative inverse of (2*i-1) mod 2**32.*/
105 static const celt_uint32_t INV_TABLE[128]={
106   0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7,
107   0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF,
108   0xF0F0F0F1,0x286BCA1B,0x3CF3CF3D,0xE9BD37A7,
109   0xC28F5C29,0x684BDA13,0x4F72C235,0xBDEF7BDF,
110   0x3E0F83E1,0x8AF8AF8B,0x914C1BAD,0x96F96F97,
111   0xC18F9C19,0x2FA0BE83,0xA4FA4FA5,0x677D46CF,
112   0x1A1F58D1,0xFAFAFAFB,0x8C13521D,0x586FB587,
113   0xB823EE09,0xA08AD8F3,0xC10C9715,0xBEFBEFBF,
114   0xC0FC0FC1,0x07A44C6B,0xA33F128D,0xE327A977,
115   0xC7E3F1F9,0x962FC963,0x3F2B3885,0x613716AF,
116   0x781948B1,0x2B2E43DB,0xFCFCFCFD,0x6FD0EB67,
117   0xFA3F47E9,0xD2FD2FD3,0x3F4FD3F5,0xD4E25B9F,
118   0x5F02A3A1,0xBF5A814B,0x7C32B16D,0xD3431B57,
119   0xD8FD8FD9,0x8D28AC43,0xDA6C0965,0xDB195E8F,
120   0x0FDBC091,0x61F2A4BB,0xDCFDCFDD,0x46FDD947,
121   0x56BE69C9,0xEB2FDEB3,0x26E978D5,0xEFDFBF7F,
122   0x0FE03F81,0xC9484E2B,0xE133F84D,0xE1A8C537,
123   0x077975B9,0x70586723,0xCD29C245,0xFAA11E6F,
124   0x0FE3C071,0x08B51D9B,0x8CE2CABD,0xBF937F27,
125   0xA8FE53A9,0x592FE593,0x2C0685B5,0x2EB11B5F,
126   0xFCD1E361,0x451AB30B,0x72CFE72D,0xDB35A717,
127   0xFB74A399,0xE80BFA03,0x0D516325,0x1BCB564F,
128   0xE02E4851,0xD962AE7B,0x10F8ED9D,0x95AEDD07,
129   0xE9DC0589,0xA18A4473,0xEA53FA95,0xEE936F3F,
130   0x90948F41,0xEAFEAFEB,0x3D137E0D,0xEF46C0F7,
131   0x028C1979,0x791064E3,0xC04FEC05,0xE115062F,
132   0x32385831,0x6E68575B,0xA10D387D,0x6FECF2E7,
133   0x3FB47F69,0xED4BFB53,0x74FED775,0xDB43BB1F,
134   0x87654321,0x9BA144CB,0x478BBCED,0xBFB912D7,
135   0x1FDCD759,0x14B2A7C3,0xCB125CE5,0x437B2E0F,
136   0x10FEF011,0xD2B3183B,0x386CAB5D,0xEF6AC0C7,
137   0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF
138 };
139
140 /*Computes (_a*_b-_c)/(2*_d-1) when the quotient is known to be exact.
141   _a, _b, _c, and _d may be arbitrary so long as the arbitrary precision result
142    fits in 32 bits, but currently the table for multiplicative inverses is only
143    valid for _d<128.*/
144 static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b,
145  celt_uint32_t _c,celt_uint32_t _d){
146   return (_a*_b-_c)*INV_TABLE[_d]&MASK32;
147 }
148
149 /*Computes (_a*_b-_c)/_d when the quotient is known to be exact.
150   _d does not actually have to be even, but imusdiv32odd will be faster when
151    it's odd, so you should use that instead.
152   _a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the
153    table for multiplicative inverses is only valid for _d<256).
154   _b and _c may be arbitrary so long as the arbitrary precision reuslt fits in
155    32 bits.*/
156 static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b,
157  celt_uint32_t _c,celt_uint32_t _d){
158   celt_uint32_t inv;
159   int      mask;
160   int      shift;
161   int      one;
162   shift=EC_ILOG(_d^_d-1);
163   inv=INV_TABLE[_d-1>>shift];
164   shift--;
165   one=1<<shift;
166   mask=one-1;
167   return (_a*(_b>>shift)-(_c>>shift)+
168    (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32;
169 }
170
171 /*Computes the next row/column of any recurrence that obeys the relation
172    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
173   _ui0 is the base case for the new row/column.*/
174 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
175   celt_uint32_t ui1;
176   int           j;
177   /* doing a do-while would overrun the array if we had less than 2 samples */
178   j=1; do {
179     ui1=UADD32(UADD32(_ui[j],_ui[j-1]),_ui0);
180     _ui[j-1]=_ui0;
181     _ui0=ui1;
182   } while (++j<_len);
183   _ui[j-1]=_ui0;
184 }
185
186 /*Computes the previous row/column of any recurrence that obeys the relation
187    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
188   _ui0 is the base case for the new row/column.*/
189 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
190   celt_uint32_t ui1;
191   int           j;
192   /* doing a do-while would overrun the array if we had less than 2 samples */
193   j=1; do {
194     ui1=USUB32(USUB32(_ui[j],_ui[j-1]),_ui0);
195     _ui[j-1]=_ui0;
196     _ui0=ui1;
197   } while (++j<_n);
198   _ui[j-1]=_ui0;
199 }
200
201 /*Returns the number of ways of choosing _m elements from a set of size _n with
202    replacement when a sign bit is needed for each unique element.
203   _u: On exit, _u[i] contains U(_n,i) for i in [0..._m+1].*/
204 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
205   celt_uint32_t um2;
206   int           k;
207   int           len;
208   len=_m+2;
209   _u[0]=0;
210   _u[1]=um2=1;
211   if(_n<=6){
212     /*If _n==0, _u[0] should be 1 and the rest should be 0.*/
213     /*If _n==1, _u[i] should be 1 for i>1.*/
214     celt_assert(_n>=2);
215     /*If _m==0, the following do-while loop will overflow the buffer.*/
216     celt_assert(_m>0);
217     k=2;
218     do _u[k]=(k<<1)-1;
219     while(++k<len);
220     for(k=2;k<_n;k++)
221       unext32(_u+1,_m+1,1);
222   }
223   else{
224     celt_uint32_t um1;
225     celt_uint32_t n2m1;
226     _u[2]=n2m1=um1=(_n<<1)-1;
227     for(k=3;k<len;k++){
228       /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/
229       _u[k]=um2=imusdiv32even(n2m1,um1,um2,k-1)+um2;
230       if(++k>=len)break;
231       _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k-1>>1)+um1;
232     }
233   }
234   return _u[_m]+_u[_m+1];
235 }
236
237
238
239 /*Returns the _i'th combination of _m elements chosen from a set of size _n
240    with associated sign bits.
241   _y: Returns the vector of pulses.
242   _u: Must contain entries [0..._m+1] of row _n of U() on input.
243       Its contents will be destructively modified.*/
244 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u){
245   int j;
246   int k;
247   celt_assert(_n>0);
248   j=0;
249   k=_m;
250   do{
251     celt_uint32_t p;
252     int           s;
253     int           yj;
254     p=_u[k+1];
255     s=_i>=p;
256     if(s)_i-=p;
257     yj=k;
258     p=_u[k];
259     while(p>_i)p=_u[--k];
260     _i-=p;
261     yj-=k;
262     _y[j]=yj-(yj<<1&-s);
263     uprev32(_u,k+2,0);
264   }
265   while(++j<_n);
266 }
267
268
269 /*Returns the index of the given combination of _m elements chosen from a set
270    of size _n with associated sign bits.
271   _y:  The vector of pulses, whose sum of absolute values must be _m.
272   _nc: Returns V(_n,_m).*/
273 celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
274  celt_uint32_t *_u){
275   celt_uint32_t i;
276   int           j;
277   int           k;
278   /*We can't unroll the first two iterations of the loop unless _n>=2.*/
279   celt_assert(_n>=2);
280   i=_y[_n-1]<0;
281   _u[0]=0;
282   for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
283   k=abs(_y[_n-1]);
284   j=_n-2;
285   i+=_u[k];
286   k+=abs(_y[j]);
287   if(_y[j]<0)i+=_u[k+1];
288   while(j-->0){
289     unext32(_u,_m+2,0);
290     i+=_u[k];
291     k+=abs(_y[j]);
292     if(_y[j]<0)i+=_u[k+1];
293   }
294   *_nc=_u[_m]+_u[_m+1];
295   return i;
296 }
297
298 static inline void encode_pulse32(int _n,int _m,const int *_y,ec_enc *_enc){
299   VARDECL(celt_uint32_t,u);
300   celt_uint32_t nc;
301   celt_uint32_t i;
302   SAVE_STACK;
303   ALLOC(u,_m+2,celt_uint32_t);
304   i=icwrs32(_n,_m,&nc,_y,u);
305   ec_enc_uint(_enc,i,nc);
306   RESTORE_STACK;
307 }
308
309 int get_required_bits32(int N, int K, int frac)
310 {
311    int nbits;
312    VARDECL(celt_uint32_t,u);
313    SAVE_STACK;
314    ALLOC(u,K+2,celt_uint32_t);
315    nbits = log2_frac(ncwrs_u32(N,K,u), frac);
316    RESTORE_STACK;
317    return nbits;
318 }
319
320 void get_required_bits(celt_int16_t *bits,int N, int MAXK, int frac)
321 {
322    int k;
323    /*We special case k==0 below, since fits_in32 could reject it for large N.*/
324    celt_assert(MAXK>0);
325    if(fits_in32(N,MAXK-1)){
326       bits[0]=0;
327       /*This could be sped up one heck of a lot if we didn't recompute u in
328          ncwrs_u32 every time.*/
329       for(k=1;k<MAXK;k++)bits[k]=get_required_bits32(N,k,frac);
330    }
331    else{
332       VARDECL(celt_int16_t,n1bits);
333       VARDECL(celt_int16_t,_n2bits);
334       celt_int16_t *n2bits;
335       SAVE_STACK;
336       ALLOC(n1bits,MAXK,celt_int16_t);
337       ALLOC(_n2bits,MAXK,celt_int16_t);
338       get_required_bits(n1bits,(N+1)/2,MAXK,frac);
339       if(N&1){
340         n2bits=_n2bits;
341         get_required_bits(n2bits,N/2,MAXK,frac);
342       }else{
343         n2bits=n1bits;
344       }
345       bits[0]=0;
346       for(k=1;k<MAXK;k++){
347          if(fits_in32(N,k))bits[k]=get_required_bits32(N,k,frac);
348          else{
349             int worst_bits;
350             int i;
351             worst_bits=0;
352             for(i=0;i<=k;i++){
353                int split_bits;
354                split_bits=n1bits[i]+n2bits[k-i];
355                if(split_bits>worst_bits)worst_bits=split_bits;
356             }
357             bits[k]=log2_frac(k+1,frac)+worst_bits;
358          }
359       }
360       RESTORE_STACK;
361    }
362 }
363
364
365 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
366 {
367    if (K==0) {
368    } else if (N==1)
369    {
370       ec_enc_bits(enc, _y[0]<0, 1);
371    } else if(fits_in32(N,K))
372    {
373       encode_pulse32(N, K, _y, enc);
374    } else {
375      int i;
376      int count=0;
377      int split;
378      split = (N+1)/2;
379      for (i=0;i<split;i++)
380         count += abs(_y[i]);
381      ec_enc_uint(enc,count,K+1);
382      encode_pulses(_y, split, count, enc);
383      encode_pulses(_y+split, N-split, K-count, enc);
384    }
385 }
386
387 static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){
388   VARDECL(celt_uint32_t,u);
389   SAVE_STACK;
390   ALLOC(u,_m+2,celt_uint32_t);
391   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_y,u);
392   RESTORE_STACK;
393 }
394
395 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
396 {
397    if (K==0) {
398       int i;
399       for (i=0;i<N;i++)
400          _y[i] = 0;
401    } else if (N==1)
402    {
403       int s = ec_dec_bits(dec, 1);
404       if (s==0)
405          _y[0] = K;
406       else
407          _y[0] = -K;
408    } else if(fits_in32(N,K))
409    {
410       decode_pulse32(N, K, _y, dec);
411    } else {
412      int split;
413      int count = ec_dec_uint(dec,K+1);
414      split = (N+1)/2;
415      decode_pulses(_y, split, count, dec);
416      decode_pulses(_y+split, N-split, K-count, dec);
417    }
418 }