Implemented a cleaner way to detect whether CWRS codebooks fit in 32 or 64 bits
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50
51 int fits_in32(int _n, int _m)
52 {
53    static const celt_int16_t maxN[15] = {
54       255, 255, 255, 255, 255, 238,  95,  53,
55        36,  27,  22,  18,  16,  15,  13};
56    static const celt_int16_t maxM[28] = {
57       255, 255, 255, 255, 255, 109,  60,  40,
58        29,  24,  20,  18,  16,  14,  13};
59    if (_n>=14)
60    {
61       if (_m>=14)
62          return 0;
63       else
64          return _n <= maxN[_m];
65    } else {
66       return _m <= maxM[_n];
67    }   
68 }
69 int fits_in64(int _n, int _m)
70 {
71    static const celt_int16_t maxN[28] = {
72       255, 255, 255, 255, 255, 255, 255, 255, 
73       255, 255, 245, 166, 122,  94,  77,  64, 
74        56,  49,  44,  40,  37,  34,  32,  30,
75        29,  27,  26,  25};
76    static const celt_int16_t maxM[28] = {
77       255, 255, 255, 255, 255, 255, 255, 255,
78       255, 255, 178, 129, 100,  81,  68,  58,
79        51,  46,  42,  38,  36,  33,  31,  30,
80        28, 27, 26, 25};
81    if (_n>=27)
82    {
83       if (_m>=27)
84          return 0;
85       else
86          return _n <= maxN[_m];
87    } else {
88       return _m <= maxM[_n];
89    }
90 }
91
92 /*Computes the next row/column of any recurrence that obeys the relation
93    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
94   _ui0 is the base case for the new row/column.*/
95 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
96   celt_uint32_t ui1;
97   int           j;
98   /* doing a do-while would overrun the array if we had less than 2 samples */
99   j=1; do {
100     ui1=_ui[j]+_ui[j-1]+_ui0;
101     _ui[j-1]=_ui0;
102     _ui0=ui1;
103   } while (++j<_len);
104   _ui[j-1]=_ui0;
105 }
106
107 static inline void unext64(celt_uint64_t *_ui,int _len,celt_uint64_t _ui0){
108   celt_uint64_t ui1;
109   int           j;
110   /* doing a do-while would overrun the array if we had less than 2 samples */
111   j=1; do {
112     ui1=_ui[j]+_ui[j-1]+_ui0;
113     _ui[j-1]=_ui0;
114     _ui0=ui1;
115   } while (++j<_len);
116   _ui[j-1]=_ui0;
117 }
118
119 /*Computes the previous row/column of any recurrence that obeys the relation
120    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
121   _ui0 is the base case for the new row/column.*/
122 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
123   celt_uint32_t ui1;
124   int           j;
125   /* doing a do-while would overrun the array if we had less than 2 samples */
126   j=1; do {
127     ui1=_ui[j]-_ui[j-1]-_ui0;
128     _ui[j-1]=_ui0;
129     _ui0=ui1;
130   } while (++j<_n);
131   _ui[j-1]=_ui0;
132 }
133
134 static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){
135   celt_uint64_t ui1;
136   int           j;
137   /* doing a do-while would overrun the array if we had less than 2 samples */
138   j=1; do {
139     ui1=_ui[j]-_ui[j-1]-_ui0;
140     _ui[j-1]=_ui0;
141     _ui0=ui1;
142   } while (++j<_n);
143   _ui[j-1]=_ui0;
144 }
145
146 /*Returns the number of ways of choosing _m elements from a set of size _n with
147    replacement when a sign bit is needed for each unique element.
148   On input, _u should be initialized to column (_m-1) of U(n,m).
149   On exit, _u will be initialized to column _m of U(n,m).*/
150 celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
151   celt_uint32_t ret;
152   celt_uint32_t ui0;
153   celt_uint32_t ui1;
154   int           j;
155   ret=ui0=2;
156   celt_assert(_n>=2);
157   j=1; do {
158     ui1=_ui[j]+_ui[j-1]+ui0;
159     _ui[j-1]=ui0;
160     ui0=ui1;
161     ret+=ui0;
162   } while (++j<_n);
163   _ui[j-1]=ui0;
164   return ret;
165 }
166
167 celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){
168   celt_uint64_t ret;
169   celt_uint64_t ui0;
170   celt_uint64_t ui1;
171   int           j;
172   ret=ui0=2;
173   celt_assert(_n>=2);
174   j=1; do {
175     ui1=_ui[j]+_ui[j-1]+ui0;
176     _ui[j-1]=ui0;
177     ui0=ui1;
178     ret+=ui0;
179   } while (++j<_n);
180   _ui[j-1]=ui0;
181   return ret;
182 }
183
184 /*Returns the number of ways of choosing _m elements from a set of size _n with
185    replacement when a sign bit is needed for each unique element.
186   On exit, _u will be initialized to column _m of U(n,m).*/
187 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
188   int k;
189   CELT_MEMSET(_u,0,_n);
190   if(_m<=0)return 1;
191   if(_n<=0)return 0;
192   for(k=1;k<_m;k++)unext32(_u,_n,2);
193   return ncwrs_unext32(_n,_u);
194 }
195
196 celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){
197   int k;
198   CELT_MEMSET(_u,0,_n);
199   if(_m<=0)return 1;
200   if(_n<=0)return 0;
201   for(k=1;k<_m;k++)unext64(_u,_n,2);
202   return ncwrs_unext64(_n,_u);
203 }
204
205 /*Returns the _i'th combination of _m elements chosen from a set of size _n
206    with associated sign bits.
207   _x: Returns the combination with elements sorted in ascending order.
208   _s: Returns the associated sign bits.
209   _u: Temporary storage already initialized to column _m of U(n,m).
210       Its contents will be overwritten.*/
211 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
212   int j;
213   int k;
214   for(k=j=0;k<_m;k++){
215     celt_uint32_t p;
216     celt_uint32_t t;
217     p=_u[_n-j-1];
218     if(k>0){
219       t=p>>1;
220       if(t<=_i||_s[k-1])_i+=t;
221     }
222     while(p<=_i){
223       _i-=p;
224       j++;
225       p=_u[_n-j-1];
226     }
227     t=p>>1;
228     _s[k]=_i>=t;
229     _x[k]=j;
230     if(_s[k])_i-=t;
231     uprev32(_u,_n-j,2);
232   }
233 }
234
235 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,celt_uint64_t *_u){
236   int j;
237   int k;
238   for(k=j=0;k<_m;k++){
239     celt_uint64_t p;
240     celt_uint64_t t;
241     p=_u[_n-j-1];
242     if(k>0){
243       t=p>>1;
244       if(t<=_i||_s[k-1])_i+=t;
245     }
246     while(p<=_i){
247       _i-=p;
248       j++;
249       p=_u[_n-j-1];
250     }
251     t=p>>1;
252     _s[k]=_i>=t;
253     _x[k]=j;
254     if(_s[k])_i-=t;
255     uprev64(_u,_n-j,2);
256   }
257 }
258
259 /*Returns the index of the given combination of _m elements chosen from a set
260    of size _n with associated sign bits.
261   _x: The combination with elements sorted in ascending order.
262   _s: The associated sign bits.
263   _u: Temporary storage already initialized to column _m of U(n,m).
264       Its contents will be overwritten.*/
265 celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
266  celt_uint32_t *_u){
267   celt_uint32_t i;
268   int           j;
269   int           k;
270   i=0;
271   for(k=j=0;k<_m;k++){
272     celt_uint32_t p;
273     p=_u[_n-j-1];
274     if(k>0)p>>=1;
275     while(j<_x[k]){
276       i+=p;
277       j++;
278       p=_u[_n-j-1];
279     }
280     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
281     uprev32(_u,_n-j,2);
282   }
283   return i;
284 }
285
286 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
287  celt_uint64_t *_u){
288   celt_uint64_t i;
289   int           j;
290   int           k;
291   i=0;
292   for(k=j=0;k<_m;k++){
293     celt_uint64_t p;
294     p=_u[_n-j-1];
295     if(k>0)p>>=1;
296     while(j<_x[k]){
297       i+=p;
298       j++;
299       p=_u[_n-j-1];
300     }
301     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
302     uprev64(_u,_n-j,2);
303   }
304   return i;
305 }
306
307 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
308    a pulse vector _y of length _n.
309   _y: Returns the vector of pulses.
310   _x: The combination with elements sorted in ascending order. _x[_m] = -1
311   _s: The associated sign bits.*/
312 void comb2pulse(int _n,int _m,int * restrict _y,const int *_x,const int *_s){
313   int k;
314   const int signs[2]={1,-1};
315   CELT_MEMSET(_y, 0, _n);
316   k=0; do {
317     _y[_x[k]]+=signs[_s[k]];
318   } while (++k<_m);
319 }
320
321 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
322    pulses with associated sign bits _s.
323   _x: Returns the combination with elements sorted in ascending order.
324   _s: Returns the associated sign bits.
325   _y: The vector of pulses, whose sum of absolute values must be _m.*/
326 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
327   int j;
328   int k;
329   for(k=j=0;j<_n;j++){
330     if(_y[j]){
331       int n;
332       int s;
333       n=abs(_y[j]);
334       s=_y[j]<0;
335       do {
336         _x[k]=j;
337         _s[k]=s;
338         k++;
339       } while (--n>0);
340     }
341   }
342 }
343
344 static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
345  ec_enc *_enc){
346   VARDECL(celt_uint32_t,u);
347   celt_uint32_t nc;
348   celt_uint32_t i;
349   SAVE_STACK;
350   ALLOC(u,_n,celt_uint32_t);
351   nc=ncwrs_u32(_n,_m,u);
352   i=icwrs32(_n,_m,_x,_s,u);
353   ec_enc_uint(_enc,i,nc);
354   RESTORE_STACK;
355 }
356
357 static inline void encode_comb64(int _n,int _m,const int *_x,const int *_s,
358  ec_enc *_enc){
359   VARDECL(celt_uint64_t,u);
360   celt_uint64_t nc;
361   celt_uint64_t i;
362   SAVE_STACK;
363   ALLOC(u,_n,celt_uint64_t);
364   nc=ncwrs_u64(_n,_m,u);
365   i=icwrs64(_n,_m,_x,_s,u);
366   ec_enc_uint64(_enc,i,nc);
367   RESTORE_STACK;
368 }
369
370 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
371 {
372    VARDECL(int, comb);
373    VARDECL(int, signs);
374    SAVE_STACK;
375
376    ALLOC(comb, K, int);
377    ALLOC(signs, K, int);
378
379    pulse2comb(N, K, comb, signs, _y);
380    /* Simple heuristic to figure out whether it fits in 32 bits */
381    if(fits_in32(N,K))
382    {
383       encode_comb32(N, K, comb, signs, enc);
384    } else {
385       encode_comb64(N, K, comb, signs, enc);
386    }
387    RESTORE_STACK;
388 }
389
390 static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
391   VARDECL(celt_uint32_t,u);
392   SAVE_STACK;
393   ALLOC(u,_n,celt_uint32_t);
394   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
395   RESTORE_STACK;
396 }
397
398 static inline void decode_comb64(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
399   VARDECL(celt_uint64_t,u);
400   SAVE_STACK;
401   ALLOC(u,_n,celt_uint64_t);
402   cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_x,_s,u);
403   RESTORE_STACK;
404 }
405
406 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
407 {
408    VARDECL(int, comb);
409    VARDECL(int, signs);
410    SAVE_STACK;
411
412    ALLOC(comb, K, int);
413    ALLOC(signs, K, int);
414    /* Simple heuristic to figure out whether it fits in 32 bits */
415    if(fits_in32(N,K))
416    {
417       decode_comb32(N, K, comb, signs, dec);
418    } else {
419       decode_comb64(N, K, comb, signs, dec);
420    }
421    comb2pulse(N, K, _y, comb, signs);
422    RESTORE_STACK;
423 }