Merge branch 'cwrs_speedup'
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50 #include "arch.h"
51
52 int log2_frac(ec_uint32 val, int frac)
53 {
54    int i;
55    /* EC_ILOG() actually returns log2()+1, go figure */
56    int L = EC_ILOG(val)-1;
57    /*printf ("in: %d %d ", val, L);*/
58    if (L>14)
59       val >>= L-14;
60    else if (L<14)
61       val <<= 14-L;
62    L <<= frac;
63    /*printf ("%d\n", val);*/
64    for (i=0;i<frac;i++)
65 {
66       val = (val*val) >> 15;
67       /*printf ("%d\n", val);*/
68       if (val > 16384)
69          L |= (1<<(frac-i-1));
70       else   
71          val <<= 1;
72 }
73    return L;
74 }
75
76 int fits_in32(int _n, int _m)
77 {
78    static const celt_int16_t maxN[15] = {
79       255, 255, 255, 255, 255, 109,  60,  40,
80        29,  24,  20,  18,  16,  14,  13};
81    static const celt_int16_t maxM[15] = {
82       255, 255, 255, 255, 255, 238,  95,  53,
83        36,  27,  22,  18,  16,  15,  13};
84    if (_n>=14)
85    {
86       if (_m>=14)
87          return 0;
88       else
89          return _n <= maxN[_m];
90    } else {
91       return _m <= maxM[_n];
92    }   
93 }
94
95 #define MASK32 (0xFFFFFFFF)
96
97 /*INV_TABLE[i] holds the multiplicative inverse of (2*i-1) mod 2**32.*/
98 static const unsigned INV_TABLE[128]={
99   0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7,
100   0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF,
101   0xF0F0F0F1,0x286BCA1B,0x3CF3CF3D,0xE9BD37A7,
102   0xC28F5C29,0x684BDA13,0x4F72C235,0xBDEF7BDF,
103   0x3E0F83E1,0x8AF8AF8B,0x914C1BAD,0x96F96F97,
104   0xC18F9C19,0x2FA0BE83,0xA4FA4FA5,0x677D46CF,
105   0x1A1F58D1,0xFAFAFAFB,0x8C13521D,0x586FB587,
106   0xB823EE09,0xA08AD8F3,0xC10C9715,0xBEFBEFBF,
107   0xC0FC0FC1,0x07A44C6B,0xA33F128D,0xE327A977,
108   0xC7E3F1F9,0x962FC963,0x3F2B3885,0x613716AF,
109   0x781948B1,0x2B2E43DB,0xFCFCFCFD,0x6FD0EB67,
110   0xFA3F47E9,0xD2FD2FD3,0x3F4FD3F5,0xD4E25B9F,
111   0x5F02A3A1,0xBF5A814B,0x7C32B16D,0xD3431B57,
112   0xD8FD8FD9,0x8D28AC43,0xDA6C0965,0xDB195E8F,
113   0x0FDBC091,0x61F2A4BB,0xDCFDCFDD,0x46FDD947,
114   0x56BE69C9,0xEB2FDEB3,0x26E978D5,0xEFDFBF7F,
115   0x0FE03F81,0xC9484E2B,0xE133F84D,0xE1A8C537,
116   0x077975B9,0x70586723,0xCD29C245,0xFAA11E6F,
117   0x0FE3C071,0x08B51D9B,0x8CE2CABD,0xBF937F27,
118   0xA8FE53A9,0x592FE593,0x2C0685B5,0x2EB11B5F,
119   0xFCD1E361,0x451AB30B,0x72CFE72D,0xDB35A717,
120   0xFB74A399,0xE80BFA03,0x0D516325,0x1BCB564F,
121   0xE02E4851,0xD962AE7B,0x10F8ED9D,0x95AEDD07,
122   0xE9DC0589,0xA18A4473,0xEA53FA95,0xEE936F3F,
123   0x90948F41,0xEAFEAFEB,0x3D137E0D,0xEF46C0F7,
124   0x028C1979,0x791064E3,0xC04FEC05,0xE115062F,
125   0x32385831,0x6E68575B,0xA10D387D,0x6FECF2E7,
126   0x3FB47F69,0xED4BFB53,0x74FED775,0xDB43BB1F,
127   0x87654321,0x9BA144CB,0x478BBCED,0xBFB912D7,
128   0x1FDCD759,0x14B2A7C3,0xCB125CE5,0x437B2E0F,
129   0x10FEF011,0xD2B3183B,0x386CAB5D,0xEF6AC0C7,
130   0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF
131 };
132
133 /*Computes (_a*_b-_c)/(2*_d-1) when the quotient is known to be exact.
134   _a, _b, _c, and _d may be arbitrary so long as the arbitrary precision result
135    fits in 32 bits, but currently the table for multiplicative inverses is only
136    valid for _d<128.*/
137 static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b,
138  celt_uint32_t _c,celt_uint32_t _d){
139   return (_a*_b-_c)*INV_TABLE[_d]&MASK32;
140 }
141
142 /*Computes (_a*_b-_c)/_d when the quotient is known to be exact.
143   _d does not actually have to be even, but imusdiv32odd will be faster when
144    it's odd, so you should use that instead.
145   _a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the
146    table for multiplicative inverses is only valid for _d<256).
147   _b and _c may be arbitrary so long as the arbitrary precision reuslt fits in
148    32 bits.*/
149 static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b,
150  celt_uint32_t _c,celt_uint32_t _d){
151   unsigned inv;
152   int      mask;
153   int      shift;
154   int      one;
155   shift=EC_ILOG(_d^_d-1);
156   inv=INV_TABLE[_d-1>>shift];
157   shift--;
158   one=1<<shift;
159   mask=one-1;
160   return (_a*(_b>>shift)-(_c>>shift)+
161    (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32;
162 }
163
164 /*Computes the next row/column of any recurrence that obeys the relation
165    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
166   _ui0 is the base case for the new row/column.*/
167 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
168   celt_uint32_t ui1;
169   int           j;
170   /* doing a do-while would overrun the array if we had less than 2 samples */
171   j=1; do {
172     ui1=UADD32(UADD32(_ui[j],_ui[j-1]),_ui0);
173     _ui[j-1]=_ui0;
174     _ui0=ui1;
175   } while (++j<_len);
176   _ui[j-1]=_ui0;
177 }
178
179 /*Computes the previous row/column of any recurrence that obeys the relation
180    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
181   _ui0 is the base case for the new row/column.*/
182 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
183   celt_uint32_t ui1;
184   int           j;
185   /* doing a do-while would overrun the array if we had less than 2 samples */
186   j=1; do {
187     ui1=USUB32(USUB32(_ui[j],_ui[j-1]),_ui0);
188     _ui[j-1]=_ui0;
189     _ui0=ui1;
190   } while (++j<_n);
191   _ui[j-1]=_ui0;
192 }
193
194 /*Returns the number of ways of choosing _m elements from a set of size _n with
195    replacement when a sign bit is needed for each unique element.
196   _u: On exit, _u[i] contains U(_n,i) for i in [0..._m+1].*/
197 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
198   celt_uint32_t um2;
199   int           k;
200   int           len;
201   len=_m+2;
202   _u[0]=0;
203   _u[1]=um2=1;
204   if(_n<=6){
205     /*If _n==0, _u[0] should be 1 and the rest should be 0.*/
206     /*If _n==1, _u[i] should be 1 for i>1.*/
207     celt_assert(_n>=2);
208     /*If _m==0, the following do-while loop will overflow the buffer.*/
209     celt_assert(_m>0);
210     k=2;
211     do _u[k]=(k<<1)-1;
212     while(++k<len);
213     for(k=2;k<_n;k++)unext32(_u+2,_m,(k<<1)+1);
214   }
215   else{
216     celt_uint32_t um1;
217     celt_uint32_t n2m1;
218     _u[2]=n2m1=um1=(_n<<1)-1;
219     for(k=3;k<len;k++){
220       /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/
221       _u[k]=um2=imusdiv32even(n2m1,um1,um2,k-1)+um2;
222       if(++k>=len)break;
223       _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k-1>>1)+um1;
224     }
225   }
226   return _u[_m]+_u[_m+1];
227 }
228
229
230
231 /*Returns the _i'th combination of _m elements chosen from a set of size _n
232    with associated sign bits.
233   _y: Returns the vector of pulses.
234   _u: Must contain entries [0..._m+1] of row _n of U() on input.
235       Its contents will be destructively modified.*/
236 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u){
237   int j;
238   int k;
239   celt_assert(_n>0);
240   j=0;
241   k=_m;
242   do{
243     celt_uint32_t p;
244     int           s;
245     int           yj;
246     p=_u[k+1];
247     s=_i>=p;
248     if(s)_i-=p;
249     yj=k;
250     p=_u[k];
251     while(p>_i)p=_u[--k];
252     _i-=p;
253     yj-=k;
254     _y[j]=yj-(yj<<1&-s);
255     uprev32(_u,k+2,0);
256   }
257   while(++j<_n);
258 }
259
260
261 /*Returns the index of the given combination of _m elements chosen from a set
262    of size _n with associated sign bits.
263   _y:  The vector of pulses, whose sum of absolute values must be _m.
264   _nc: Returns V(_n,_m).*/
265 celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
266  celt_uint32_t *_u){
267   celt_uint32_t i;
268   int           j;
269   int           k;
270   /*We can't unroll the first two iterations of the loop unless _n>=2.*/
271   celt_assert(_n>=2);
272   i=_y[_n-1]<0;
273   _u[0]=0;
274   for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
275   k=abs(_y[_n-1]);
276   j=_n-2;
277   i+=_u[k];
278   k+=abs(_y[j]);
279   if(_y[j]<0)i+=_u[k+1];
280   while(j-->0){
281     unext32(_u,_m+2,0);
282     i+=_u[k];
283     k+=abs(_y[j]);
284     if(_y[j]<0)i+=_u[k+1];
285   }
286   *_nc=_u[_m]+_u[_m+1];
287   return i;
288 }
289
290 static inline void encode_pulse32(int _n,int _m,const int *_y,ec_enc *_enc){
291   VARDECL(celt_uint32_t,u);
292   celt_uint32_t nc;
293   celt_uint32_t i;
294   SAVE_STACK;
295   ALLOC(u,_m+2,celt_uint32_t);
296   i=icwrs32(_n,_m,&nc,_y,u);
297   ec_enc_uint(_enc,i,nc);
298   RESTORE_STACK;
299 }
300
301 int get_required_bits(int N, int K, int frac)
302 {
303    int nbits = 0;
304    if(fits_in32(N,K))
305    {
306       VARDECL(celt_uint32_t,u);
307       SAVE_STACK;
308       ALLOC(u,K+2,celt_uint32_t);
309       nbits = log2_frac(ncwrs_u32(N,K,u), frac);
310       RESTORE_STACK;
311    } else {
312       nbits = log2_frac(N, frac);
313       nbits += get_required_bits(N/2+1, (K+1)/2, frac);
314       nbits += get_required_bits(N/2+1, K/2, frac);
315    }
316    return nbits;
317 }
318
319
320 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
321 {
322    if (K==0) {
323    } else if (N==1)
324    {
325       ec_enc_bits(enc, _y[0]<0, 1);
326    } else if(fits_in32(N,K))
327    {
328       encode_pulse32(N, K, _y, enc);
329    } else {
330      int i;
331      int count=0;
332      int split;
333      split = (N+1)/2;
334      for (i=0;i<split;i++)
335         count += abs(_y[i]);
336      ec_enc_uint(enc,count,K+1);
337      encode_pulses(_y, split, count, enc);
338      encode_pulses(_y+split, N-split, K-count, enc);
339    }
340 }
341
342 static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){
343   VARDECL(celt_uint32_t,u);
344   SAVE_STACK;
345   ALLOC(u,_m+2,celt_uint32_t);
346   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_y,u);
347   RESTORE_STACK;
348 }
349
350 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
351 {
352    if (K==0) {
353       int i;
354       for (i=0;i<N;i++)
355          _y[i] = 0;
356    } else if (N==1)
357    {
358       int s = ec_dec_bits(dec, 1);
359       if (s==0)
360          _y[0] = K;
361       else
362          _y[0] = -K;
363    } else if(fits_in32(N,K))
364    {
365       decode_pulse32(N, K, _y, dec);
366    } else {
367      int split;
368      int count = ec_dec_uint(dec,K+1);
369      split = (N+1)/2;
370      decode_pulses(_y, split, count, dec);
371      decode_pulses(_y+split, N-split, K-count, dec);
372    }
373 }