Doing the cwrs split in dimensions should save a few bits.
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50
51 #if 0
52 int log2_frac(ec_uint32 val, int frac)
53 {
54    int i;
55    /* EC_ILOG() actually returns log2()+1, go figure */
56    int L = EC_ILOG(val)-1;
57    /*printf ("in: %d %d ", val, L);*/
58    if (L>14)
59       val >>= L-14;
60    else if (L<14)
61       val <<= 14-L;
62    L <<= frac;
63    /*printf ("%d\n", val);*/
64    for (i=0;i<frac;i++)
65 {
66       val = (val*val) >> 15;
67       /*printf ("%d\n", val);*/
68       if (val > 16384)
69          L |= (1<<(frac-i-1));
70       else   
71          val <<= 1;
72 }
73    return L;
74 }
75 #endif
76
77 int log2_frac64(ec_uint64 val, int frac)
78 {
79    int i;
80    /* EC_ILOG64() actually returns log2()+1, go figure */
81    int L = EC_ILOG64(val)-1;
82    /*printf ("in: %d %d ", val, L);*/
83    if (L>14)
84       val >>= L-14;
85    else if (L<14)
86       val <<= 14-L;
87    L <<= frac;
88    /*printf ("%d\n", val);*/
89    for (i=0;i<frac;i++)
90    {
91       val = (val*val) >> 15;
92       /*printf ("%d\n", val);*/
93       if (val > 16384)
94          L |= (1<<(frac-i-1));
95       else   
96          val <<= 1;
97    }
98    return L;
99 }
100
101 int fits_in32(int _n, int _m)
102 {
103    static const celt_int16_t maxN[15] = {
104       255, 255, 255, 255, 255, 109,  60,  40,
105        29,  24,  20,  18,  16,  14,  13};
106    static const celt_int16_t maxM[15] = {
107       255, 255, 255, 255, 255, 238,  95,  53,
108        36,  27,  22,  18,  16,  15,  13};
109    if (_n>=14)
110    {
111       if (_m>=14)
112          return 0;
113       else
114          return _n <= maxN[_m];
115    } else {
116       return _m <= maxM[_n];
117    }   
118 }
119 int fits_in64(int _n, int _m)
120 {
121    static const celt_int16_t maxN[28] = {
122       255, 255, 255, 255, 255, 255, 255, 255,
123       255, 255, 178, 129, 100,  81,  68,  58,
124        51,  46,  42,  38,  36,  33,  31,  30,
125        28, 27, 26, 25};
126    static const celt_int16_t maxM[28] = {
127       255, 255, 255, 255, 255, 255, 255, 255, 
128       255, 255, 245, 166, 122,  94,  77,  64, 
129        56,  49,  44,  40,  37,  34,  32,  30,
130        29,  27,  26,  25};
131    if (_n>=27)
132    {
133       if (_m>=27)
134          return 0;
135       else
136          return _n <= maxN[_m];
137    } else {
138       return _m <= maxM[_n];
139    }
140 }
141
142 /*Computes the next row/column of any recurrence that obeys the relation
143    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
144   _ui0 is the base case for the new row/column.*/
145 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
146   celt_uint32_t ui1;
147   int           j;
148   /* doing a do-while would overrun the array if we had less than 2 samples */
149   j=1; do {
150     ui1=_ui[j]+_ui[j-1]+_ui0;
151     _ui[j-1]=_ui0;
152     _ui0=ui1;
153   } while (++j<_len);
154   _ui[j-1]=_ui0;
155 }
156
157 static inline void unext64(celt_uint64_t *_ui,int _len,celt_uint64_t _ui0){
158   celt_uint64_t ui1;
159   int           j;
160   /* doing a do-while would overrun the array if we had less than 2 samples */
161   j=1; do {
162     ui1=_ui[j]+_ui[j-1]+_ui0;
163     _ui[j-1]=_ui0;
164     _ui0=ui1;
165   } while (++j<_len);
166   _ui[j-1]=_ui0;
167 }
168
169 /*Computes the previous row/column of any recurrence that obeys the relation
170    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
171   _ui0 is the base case for the new row/column.*/
172 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
173   celt_uint32_t ui1;
174   int           j;
175   /* doing a do-while would overrun the array if we had less than 2 samples */
176   j=1; do {
177     ui1=_ui[j]-_ui[j-1]-_ui0;
178     _ui[j-1]=_ui0;
179     _ui0=ui1;
180   } while (++j<_n);
181   _ui[j-1]=_ui0;
182 }
183
184 static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){
185   celt_uint64_t ui1;
186   int           j;
187   /* doing a do-while would overrun the array if we had less than 2 samples */
188   j=1; do {
189     ui1=_ui[j]-_ui[j-1]-_ui0;
190     _ui[j-1]=_ui0;
191     _ui0=ui1;
192   } while (++j<_n);
193   _ui[j-1]=_ui0;
194 }
195
196 /*Returns the number of ways of choosing _m elements from a set of size _n with
197    replacement when a sign bit is needed for each unique element.
198   On input, _u should be initialized to column (_m-1) of U(n,m).
199   On exit, _u will be initialized to column _m of U(n,m).*/
200 celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
201   celt_uint32_t ret;
202   celt_uint32_t ui0;
203   celt_uint32_t ui1;
204   int           j;
205   ret=ui0=2;
206   celt_assert(_n>=2);
207   j=1; do {
208     ui1=_ui[j]+_ui[j-1]+ui0;
209     _ui[j-1]=ui0;
210     ui0=ui1;
211     ret+=ui0;
212   } while (++j<_n);
213   _ui[j-1]=ui0;
214   return ret;
215 }
216
217 celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){
218   celt_uint64_t ret;
219   celt_uint64_t ui0;
220   celt_uint64_t ui1;
221   int           j;
222   ret=ui0=2;
223   celt_assert(_n>=2);
224   j=1; do {
225     ui1=_ui[j]+_ui[j-1]+ui0;
226     _ui[j-1]=ui0;
227     ui0=ui1;
228     ret+=ui0;
229   } while (++j<_n);
230   _ui[j-1]=ui0;
231   return ret;
232 }
233
234 /*Returns the number of ways of choosing _m elements from a set of size _n with
235    replacement when a sign bit is needed for each unique element.
236   On exit, _u will be initialized to column _m of U(n,m).*/
237 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
238   int k;
239   CELT_MEMSET(_u,0,_n);
240   if(_m<=0)return 1;
241   if(_n<=0)return 0;
242   for(k=1;k<_m;k++)unext32(_u,_n,2);
243   return ncwrs_unext32(_n,_u);
244 }
245
246 celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){
247   int k;
248   CELT_MEMSET(_u,0,_n);
249   if(_m<=0)return 1;
250   if(_n<=0)return 0;
251   for(k=1;k<_m;k++)unext64(_u,_n,2);
252   return ncwrs_unext64(_n,_u);
253 }
254
255 /*Returns the _i'th combination of _m elements chosen from a set of size _n
256    with associated sign bits.
257   _x: Returns the combination with elements sorted in ascending order.
258   _s: Returns the associated sign bits.
259   _u: Temporary storage already initialized to column _m of U(n,m).
260       Its contents will be overwritten.*/
261 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
262   int j;
263   int k;
264   for(k=j=0;k<_m;k++){
265     celt_uint32_t p;
266     celt_uint32_t t;
267     p=_u[_n-j-1];
268     if(k>0){
269       t=p>>1;
270       if(t<=_i||_s[k-1])_i+=t;
271     }
272     while(p<=_i){
273       _i-=p;
274       j++;
275       p=_u[_n-j-1];
276     }
277     t=p>>1;
278     _s[k]=_i>=t;
279     _x[k]=j;
280     if(_s[k])_i-=t;
281     uprev32(_u,_n-j,2);
282   }
283 }
284
285 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,celt_uint64_t *_u){
286   int j;
287   int k;
288   for(k=j=0;k<_m;k++){
289     celt_uint64_t p;
290     celt_uint64_t t;
291     p=_u[_n-j-1];
292     if(k>0){
293       t=p>>1;
294       if(t<=_i||_s[k-1])_i+=t;
295     }
296     while(p<=_i){
297       _i-=p;
298       j++;
299       p=_u[_n-j-1];
300     }
301     t=p>>1;
302     _s[k]=_i>=t;
303     _x[k]=j;
304     if(_s[k])_i-=t;
305     uprev64(_u,_n-j,2);
306   }
307 }
308
309 /*Returns the index of the given combination of _m elements chosen from a set
310    of size _n with associated sign bits.
311   _x: The combination with elements sorted in ascending order.
312   _s: The associated sign bits.
313   _u: Temporary storage already initialized to column _m of U(n,m).
314       Its contents will be overwritten.*/
315 celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
316  celt_uint32_t *_u){
317   celt_uint32_t i;
318   int           j;
319   int           k;
320   i=0;
321   for(k=j=0;k<_m;k++){
322     celt_uint32_t p;
323     p=_u[_n-j-1];
324     if(k>0)p>>=1;
325     while(j<_x[k]){
326       i+=p;
327       j++;
328       p=_u[_n-j-1];
329     }
330     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
331     uprev32(_u,_n-j,2);
332   }
333   return i;
334 }
335
336 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
337  celt_uint64_t *_u){
338   celt_uint64_t i;
339   int           j;
340   int           k;
341   i=0;
342   for(k=j=0;k<_m;k++){
343     celt_uint64_t p;
344     p=_u[_n-j-1];
345     if(k>0)p>>=1;
346     while(j<_x[k]){
347       i+=p;
348       j++;
349       p=_u[_n-j-1];
350     }
351     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
352     uprev64(_u,_n-j,2);
353   }
354   return i;
355 }
356
357 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
358    a pulse vector _y of length _n.
359   _y: Returns the vector of pulses.
360   _x: The combination with elements sorted in ascending order. _x[_m] = -1
361   _s: The associated sign bits.*/
362 void comb2pulse(int _n,int _m,int * restrict _y,const int *_x,const int *_s){
363   int k;
364   const int signs[2]={1,-1};
365   CELT_MEMSET(_y, 0, _n);
366   k=0; do {
367     _y[_x[k]]+=signs[_s[k]];
368   } while (++k<_m);
369 }
370
371 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
372    pulses with associated sign bits _s.
373   _x: Returns the combination with elements sorted in ascending order.
374   _s: Returns the associated sign bits.
375   _y: The vector of pulses, whose sum of absolute values must be _m.*/
376 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
377   int j;
378   int k;
379   for(k=j=0;j<_n;j++){
380     if(_y[j]){
381       int n;
382       int s;
383       n=abs(_y[j]);
384       s=_y[j]<0;
385       do {
386         _x[k]=j;
387         _s[k]=s;
388         k++;
389       } while (--n>0);
390     }
391   }
392 }
393
394 static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
395  ec_enc *_enc){
396   VARDECL(celt_uint32_t,u);
397   celt_uint32_t nc;
398   celt_uint32_t i;
399   SAVE_STACK;
400   ALLOC(u,_n,celt_uint32_t);
401   nc=ncwrs_u32(_n,_m,u);
402   i=icwrs32(_n,_m,_x,_s,u);
403   ec_enc_uint(_enc,i,nc);
404   RESTORE_STACK;
405 }
406
407 static inline void encode_comb64(int _n,int _m,const int *_x,const int *_s,
408  ec_enc *_enc){
409   VARDECL(celt_uint64_t,u);
410   celt_uint64_t nc;
411   celt_uint64_t i;
412   SAVE_STACK;
413   ALLOC(u,_n,celt_uint64_t);
414   nc=ncwrs_u64(_n,_m,u);
415   i=icwrs64(_n,_m,_x,_s,u);
416   ec_enc_uint64(_enc,i,nc);
417   RESTORE_STACK;
418 }
419
420 int get_required_bits(int N, int K, int frac)
421 {
422    int nbits = 0;
423    if(fits_in64(N,K))
424    {
425       VARDECL(celt_uint64_t,u);
426       SAVE_STACK;
427       ALLOC(u,N,celt_uint64_t);
428       nbits = log2_frac64(ncwrs_u64(N,K,u), frac);
429       RESTORE_STACK;
430    } else {
431       nbits = log2_frac64(N, frac);
432       nbits += get_required_bits((N+1)/2, (K+1)/2, frac);
433       nbits += get_required_bits((N+1)/2, K/2, frac);
434    }
435    return nbits;
436 }
437
438
439 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
440 {
441    VARDECL(int, comb);
442    VARDECL(int, signs);
443    SAVE_STACK;
444
445    ALLOC(comb, K, int);
446    ALLOC(signs, K, int);
447
448    pulse2comb(N, K, comb, signs, _y);
449    if (K==0) {
450    } else if (N==1)
451    {
452       ec_enc_bits(enc, _y[0]<0, 1);
453    } else if(fits_in32(N,K))
454    {
455       encode_comb32(N, K, comb, signs, enc);
456    } else if(fits_in64(N,K)) {
457       encode_comb64(N, K, comb, signs, enc);
458    } else {
459      int i;
460      int count=0;
461      int split;
462      split = (N+1)/2;
463      for (i=0;i<split;i++)
464         count += abs(_y[i]);
465      ec_enc_uint(enc,count,K+1);
466      encode_pulses(_y, split, count, enc);
467      encode_pulses(_y+split, N-split, K-count, enc);
468    }
469    RESTORE_STACK;
470 }
471
472 static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
473   VARDECL(celt_uint32_t,u);
474   SAVE_STACK;
475   ALLOC(u,_n,celt_uint32_t);
476   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
477   RESTORE_STACK;
478 }
479
480 static inline void decode_comb64(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
481   VARDECL(celt_uint64_t,u);
482   SAVE_STACK;
483   ALLOC(u,_n,celt_uint64_t);
484   cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_x,_s,u);
485   RESTORE_STACK;
486 }
487
488 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
489 {
490    VARDECL(int, comb);
491    VARDECL(int, signs);
492    SAVE_STACK;
493
494    ALLOC(comb, K, int);
495    ALLOC(signs, K, int);
496    if (K==0) {
497       int i;
498       for (i=0;i<N;i++)
499          _y[i] = 0;
500    } else if (N==1)
501    {
502       int s = ec_dec_bits(dec, 1);
503       if (s==0)
504          _y[0] = K;
505       else
506          _y[0] = -K;
507    } else if(fits_in32(N,K))
508    {
509       decode_comb32(N, K, comb, signs, dec);
510       comb2pulse(N, K, _y, comb, signs);
511    } else if(fits_in64(N,K)) {
512       decode_comb64(N, K, comb, signs, dec);
513       comb2pulse(N, K, _y, comb, signs);
514    } else {
515      int split;
516      int count = ec_dec_uint(dec,K+1);
517      split = (N+1)/2;
518      decode_pulses(_y, split, count, dec);
519      decode_pulses(_y+split, N-split, K-count, dec);
520    }
521    RESTORE_STACK;
522 }