Generate slightly more accurate WMOPS figures
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50 #include "arch.h"
51
52 #if 0
53 int log2_frac(ec_uint32 val, int frac)
54 {
55    int i;
56    /* EC_ILOG() actually returns log2()+1, go figure */
57    int L = EC_ILOG(val)-1;
58    /*printf ("in: %d %d ", val, L);*/
59    if (L>14)
60       val >>= L-14;
61    else if (L<14)
62       val <<= 14-L;
63    L <<= frac;
64    /*printf ("%d\n", val);*/
65    for (i=0;i<frac;i++)
66 {
67       val = (val*val) >> 15;
68       /*printf ("%d\n", val);*/
69       if (val > 16384)
70          L |= (1<<(frac-i-1));
71       else   
72          val <<= 1;
73 }
74    return L;
75 }
76 #endif
77
78 int log2_frac64(ec_uint64 val, int frac)
79 {
80    int i;
81    /* EC_ILOG64() actually returns log2()+1, go figure */
82    int L = EC_ILOG64(val)-1;
83    /*printf ("in: %d %d ", val, L);*/
84    if (L>14)
85       val >>= L-14;
86    else if (L<14)
87       val <<= 14-L;
88    L <<= frac;
89    /*printf ("%d\n", val);*/
90    for (i=0;i<frac;i++)
91    {
92       val = (val*val) >> 15;
93       /*printf ("%d\n", val);*/
94       if (val > 16384)
95          L |= (1<<(frac-i-1));
96       else   
97          val <<= 1;
98    }
99    return L;
100 }
101
102 int fits_in32(int _n, int _m)
103 {
104    static const celt_int16_t maxN[15] = {
105       255, 255, 255, 255, 255, 109,  60,  40,
106        29,  24,  20,  18,  16,  14,  13};
107    static const celt_int16_t maxM[15] = {
108       255, 255, 255, 255, 255, 238,  95,  53,
109        36,  27,  22,  18,  16,  15,  13};
110    if (_n>=14)
111    {
112       if (_m>=14)
113          return 0;
114       else
115          return _n <= maxN[_m];
116    } else {
117       return _m <= maxM[_n];
118    }   
119 }
120 int fits_in64(int _n, int _m)
121 {
122    static const celt_int16_t maxN[28] = {
123       255, 255, 255, 255, 255, 255, 255, 255,
124       255, 255, 178, 129, 100,  81,  68,  58,
125        51,  46,  42,  38,  36,  33,  31,  30,
126        28, 27, 26, 25};
127    static const celt_int16_t maxM[28] = {
128       255, 255, 255, 255, 255, 255, 255, 255, 
129       255, 255, 245, 166, 122,  94,  77,  64, 
130        56,  49,  44,  40,  37,  34,  32,  30,
131        29,  27,  26,  25};
132    if (_n>=27)
133    {
134       if (_m>=27)
135          return 0;
136       else
137          return _n <= maxN[_m];
138    } else {
139       return _m <= maxM[_n];
140    }
141 }
142
143 /*Computes the next row/column of any recurrence that obeys the relation
144    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
145   _ui0 is the base case for the new row/column.*/
146 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
147   celt_uint32_t ui1;
148   int           j;
149   /* doing a do-while would overrun the array if we had less than 2 samples */
150   j=1; do {
151     ui1=UADD32(UADD32(_ui[j],_ui[j-1]),_ui0);
152     _ui[j-1]=_ui0;
153     _ui0=ui1;
154   } while (++j<_len);
155   _ui[j-1]=_ui0;
156 }
157
158 static inline void unext64(celt_uint64_t *_ui,int _len,celt_uint64_t _ui0){
159   celt_uint64_t ui1;
160   int           j;
161   /* doing a do-while would overrun the array if we had less than 2 samples */
162   j=1; do {
163     ui1=_ui[j]+_ui[j-1]+_ui0;
164     _ui[j-1]=_ui0;
165     _ui0=ui1;
166   } while (++j<_len);
167   _ui[j-1]=_ui0;
168 }
169
170 /*Computes the previous row/column of any recurrence that obeys the relation
171    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
172   _ui0 is the base case for the new row/column.*/
173 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
174   celt_uint32_t ui1;
175   int           j;
176   /* doing a do-while would overrun the array if we had less than 2 samples */
177   j=1; do {
178     ui1=USUB32(USUB32(_ui[j],_ui[j-1]),_ui0);
179     _ui[j-1]=_ui0;
180     _ui0=ui1;
181   } while (++j<_n);
182   _ui[j-1]=_ui0;
183 }
184
185 static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){
186   celt_uint64_t ui1;
187   int           j;
188   /* doing a do-while would overrun the array if we had less than 2 samples */
189   j=1; do {
190     ui1=_ui[j]-_ui[j-1]-_ui0;
191     _ui[j-1]=_ui0;
192     _ui0=ui1;
193   } while (++j<_n);
194   _ui[j-1]=_ui0;
195 }
196
197 /*Returns the number of ways of choosing _m elements from a set of size _n with
198    replacement when a sign bit is needed for each unique element.
199   On input, _u should be initialized to column (_m-1) of U(n,m).
200   On exit, _u will be initialized to column _m of U(n,m).*/
201 celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
202   celt_uint32_t ret;
203   celt_uint32_t ui0;
204   celt_uint32_t ui1;
205   int           j;
206   ret=ui0=2;
207   celt_assert(_n>=2);
208   j=1; do {
209     ui1=_ui[j]+_ui[j-1]+ui0;
210     _ui[j-1]=ui0;
211     ui0=ui1;
212     ret+=ui0;
213   } while (++j<_n);
214   _ui[j-1]=ui0;
215   return ret;
216 }
217
218 celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){
219   celt_uint64_t ret;
220   celt_uint64_t ui0;
221   celt_uint64_t ui1;
222   int           j;
223   ret=ui0=2;
224   celt_assert(_n>=2);
225   j=1; do {
226     ui1=_ui[j]+_ui[j-1]+ui0;
227     _ui[j-1]=ui0;
228     ui0=ui1;
229     ret+=ui0;
230   } while (++j<_n);
231   _ui[j-1]=ui0;
232   return ret;
233 }
234
235 /*Returns the number of ways of choosing _m elements from a set of size _n with
236    replacement when a sign bit is needed for each unique element.
237   On exit, _u will be initialized to column _m of U(n,m).*/
238 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
239   int k;
240   CELT_MEMSET(_u,0,_n);
241   if(_m<=0)return 1;
242   if(_n<=0)return 0;
243   for(k=1;k<_m;k++)unext32(_u,_n,2);
244   return ncwrs_unext32(_n,_u);
245 }
246
247 celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){
248   int k;
249   CELT_MEMSET(_u,0,_n);
250   if(_m<=0)return 1;
251   if(_n<=0)return 0;
252   for(k=1;k<_m;k++)unext64(_u,_n,2);
253   return ncwrs_unext64(_n,_u);
254 }
255
256 /*Returns the _i'th combination of _m elements chosen from a set of size _n
257    with associated sign bits.
258   _x: Returns the combination with elements sorted in ascending order.
259   _s: Returns the associated sign bits.
260   _u: Temporary storage already initialized to column _m of U(n,m).
261       Its contents will be overwritten.*/
262 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
263   int j;
264   int k;
265   for(k=j=0;k<_m;k++){
266     celt_uint32_t p;
267     celt_uint32_t t;
268     p=_u[_n-j-1];
269     if(k>0){
270       t=p>>1;
271       if(t<=_i||_s[k-1])_i+=t;
272     }
273     while(p<=_i){
274       _i-=p;
275       j++;
276       p=_u[_n-j-1];
277     }
278     t=p>>1;
279     _s[k]=_i>=t;
280     _x[k]=j;
281     if(_s[k])_i-=t;
282     uprev32(_u,_n-j,2);
283   }
284 }
285
286 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,celt_uint64_t *_u){
287   int j;
288   int k;
289   for(k=j=0;k<_m;k++){
290     celt_uint64_t p;
291     celt_uint64_t t;
292     p=_u[_n-j-1];
293     if(k>0){
294       t=p>>1;
295       if(t<=_i||_s[k-1])_i+=t;
296     }
297     while(p<=_i){
298       _i-=p;
299       j++;
300       p=_u[_n-j-1];
301     }
302     t=p>>1;
303     _s[k]=_i>=t;
304     _x[k]=j;
305     if(_s[k])_i-=t;
306     uprev64(_u,_n-j,2);
307   }
308 }
309
310 /*Returns the index of the given combination of _m elements chosen from a set
311    of size _n with associated sign bits.
312   _x: The combination with elements sorted in ascending order.
313   _s: The associated sign bits.
314   _u: Temporary storage already initialized to column _m of U(n,m).
315       Its contents will be overwritten.*/
316 celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
317  celt_uint32_t *_u){
318   celt_uint32_t i;
319   int           j;
320   int           k;
321   i=0;
322   for(k=j=0;k<_m;k++){
323     celt_uint32_t p;
324     p=_u[_n-j-1];
325     if(k>0)p>>=1;
326     while(j<_x[k]){
327       i+=p;
328       j++;
329       p=_u[_n-j-1];
330     }
331     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
332     uprev32(_u,_n-j,2);
333   }
334   return i;
335 }
336
337 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
338  celt_uint64_t *_u){
339   celt_uint64_t i;
340   int           j;
341   int           k;
342   i=0;
343   for(k=j=0;k<_m;k++){
344     celt_uint64_t p;
345     p=_u[_n-j-1];
346     if(k>0)p>>=1;
347     while(j<_x[k]){
348       i+=p;
349       j++;
350       p=_u[_n-j-1];
351     }
352     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
353     uprev64(_u,_n-j,2);
354   }
355   return i;
356 }
357
358 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
359    a pulse vector _y of length _n.
360   _y: Returns the vector of pulses.
361   _x: The combination with elements sorted in ascending order. _x[_m] = -1
362   _s: The associated sign bits.*/
363 void comb2pulse(int _n,int _m,int * restrict _y,const int *_x,const int *_s){
364   int k;
365   const int signs[2]={1,-1};
366   CELT_MEMSET(_y, 0, _n);
367   k=0; do {
368     _y[_x[k]]+=signs[_s[k]];
369   } while (++k<_m);
370 }
371
372 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
373    pulses with associated sign bits _s.
374   _x: Returns the combination with elements sorted in ascending order.
375   _s: Returns the associated sign bits.
376   _y: The vector of pulses, whose sum of absolute values must be _m.*/
377 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
378   int j;
379   int k;
380   for(k=j=0;j<_n;j++){
381     if(_y[j]){
382       int n;
383       int s;
384       n=abs(_y[j]);
385       s=_y[j]<0;
386       do {
387         _x[k]=j;
388         _s[k]=s;
389         k++;
390       } while (--n>0);
391     }
392   }
393 }
394
395 static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
396  ec_enc *_enc){
397   VARDECL(celt_uint32_t,u);
398   celt_uint32_t nc;
399   celt_uint32_t i;
400   SAVE_STACK;
401   ALLOC(u,_n,celt_uint32_t);
402   nc=ncwrs_u32(_n,_m,u);
403   i=icwrs32(_n,_m,_x,_s,u);
404   ec_enc_uint(_enc,i,nc);
405   RESTORE_STACK;
406 }
407
408 static inline void encode_comb64(int _n,int _m,const int *_x,const int *_s,
409  ec_enc *_enc){
410   VARDECL(celt_uint64_t,u);
411   celt_uint64_t nc;
412   celt_uint64_t i;
413   SAVE_STACK;
414   ALLOC(u,_n,celt_uint64_t);
415   nc=ncwrs_u64(_n,_m,u);
416   i=icwrs64(_n,_m,_x,_s,u);
417   ec_enc_uint64(_enc,i,nc);
418   RESTORE_STACK;
419 }
420
421 int get_required_bits(int N, int K, int frac)
422 {
423    int nbits = 0;
424    if(fits_in64(N,K))
425    {
426       VARDECL(celt_uint64_t,u);
427       SAVE_STACK;
428       ALLOC(u,N,celt_uint64_t);
429       nbits = log2_frac64(ncwrs_u64(N,K,u), frac);
430       RESTORE_STACK;
431    } else {
432       nbits = log2_frac64(N, frac);
433       nbits += get_required_bits(N/2+1, (K+1)/2, frac);
434       nbits += get_required_bits(N/2+1, K/2, frac);
435    }
436    return nbits;
437 }
438
439
440 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
441 {
442    VARDECL(int, comb);
443    VARDECL(int, signs);
444    SAVE_STACK;
445
446    ALLOC(comb, K, int);
447    ALLOC(signs, K, int);
448
449    pulse2comb(N, K, comb, signs, _y);
450    if (K==0) {
451    } else if (N==1)
452    {
453       ec_enc_bits(enc, _y[0]<0, 1);
454    } else if(fits_in32(N,K))
455    {
456       encode_comb32(N, K, comb, signs, enc);
457    } else if(fits_in64(N,K)) {
458       encode_comb64(N, K, comb, signs, enc);
459    } else {
460      int i;
461      int count=0;
462      int split;
463      split = (N+1)/2;
464      for (i=0;i<split;i++)
465         count += abs(_y[i]);
466      ec_enc_uint(enc,count,K+1);
467      encode_pulses(_y, split, count, enc);
468      encode_pulses(_y+split, N-split, K-count, enc);
469    }
470    RESTORE_STACK;
471 }
472
473 static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
474   VARDECL(celt_uint32_t,u);
475   SAVE_STACK;
476   ALLOC(u,_n,celt_uint32_t);
477   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
478   RESTORE_STACK;
479 }
480
481 static inline void decode_comb64(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
482   VARDECL(celt_uint64_t,u);
483   SAVE_STACK;
484   ALLOC(u,_n,celt_uint64_t);
485   cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_x,_s,u);
486   RESTORE_STACK;
487 }
488
489 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
490 {
491    VARDECL(int, comb);
492    VARDECL(int, signs);
493    SAVE_STACK;
494
495    ALLOC(comb, K, int);
496    ALLOC(signs, K, int);
497    if (K==0) {
498       int i;
499       for (i=0;i<N;i++)
500          _y[i] = 0;
501    } else if (N==1)
502    {
503       int s = ec_dec_bits(dec, 1);
504       if (s==0)
505          _y[0] = K;
506       else
507          _y[0] = -K;
508    } else if(fits_in32(N,K))
509    {
510       decode_comb32(N, K, comb, signs, dec);
511       comb2pulse(N, K, _y, comb, signs);
512    } else if(fits_in64(N,K)) {
513       decode_comb64(N, K, comb, signs, dec);
514       comb2pulse(N, K, _y, comb, signs);
515    } else {
516      int split;
517      int count = ec_dec_uint(dec,K+1);
518      split = (N+1)/2;
519      decode_pulses(_y, split, count, dec);
520      decode_pulses(_y+split, N-split, K-count, dec);
521    }
522    RESTORE_STACK;
523 }