Converted everything to 32-bit CWRS (using split after that)
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50 #include "arch.h"
51
52 int log2_frac(ec_uint32 val, int frac)
53 {
54    int i;
55    /* EC_ILOG() actually returns log2()+1, go figure */
56    int L = EC_ILOG(val)-1;
57    /*printf ("in: %d %d ", val, L);*/
58    if (L>14)
59       val >>= L-14;
60    else if (L<14)
61       val <<= 14-L;
62    L <<= frac;
63    /*printf ("%d\n", val);*/
64    for (i=0;i<frac;i++)
65 {
66       val = (val*val) >> 15;
67       /*printf ("%d\n", val);*/
68       if (val > 16384)
69          L |= (1<<(frac-i-1));
70       else   
71          val <<= 1;
72 }
73    return L;
74 }
75
76 int fits_in32(int _n, int _m)
77 {
78    static const celt_int16_t maxN[15] = {
79       255, 255, 255, 255, 255, 109,  60,  40,
80        29,  24,  20,  18,  16,  14,  13};
81    static const celt_int16_t maxM[15] = {
82       255, 255, 255, 255, 255, 238,  95,  53,
83        36,  27,  22,  18,  16,  15,  13};
84    if (_n>=14)
85    {
86       if (_m>=14)
87          return 0;
88       else
89          return _n <= maxN[_m];
90    } else {
91       return _m <= maxM[_n];
92    }   
93 }
94
95 /*Computes the next row/column of any recurrence that obeys the relation
96    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
97   _ui0 is the base case for the new row/column.*/
98 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
99   celt_uint32_t ui1;
100   int           j;
101   /* doing a do-while would overrun the array if we had less than 2 samples */
102   j=1; do {
103     ui1=UADD32(UADD32(_ui[j],_ui[j-1]),_ui0);
104     _ui[j-1]=_ui0;
105     _ui0=ui1;
106   } while (++j<_len);
107   _ui[j-1]=_ui0;
108 }
109
110 /*Computes the previous row/column of any recurrence that obeys the relation
111    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
112   _ui0 is the base case for the new row/column.*/
113 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
114   celt_uint32_t ui1;
115   int           j;
116   /* doing a do-while would overrun the array if we had less than 2 samples */
117   j=1; do {
118     ui1=USUB32(USUB32(_ui[j],_ui[j-1]),_ui0);
119     _ui[j-1]=_ui0;
120     _ui0=ui1;
121   } while (++j<_n);
122   _ui[j-1]=_ui0;
123 }
124
125 /*Returns the number of ways of choosing _m elements from a set of size _n with
126    replacement when a sign bit is needed for each unique element.
127   On input, _u should be initialized to column (_m-1) of U(n,m).
128   On exit, _u will be initialized to column _m of U(n,m).*/
129 celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
130   celt_uint32_t ret;
131   celt_uint32_t ui0;
132   celt_uint32_t ui1;
133   int           j;
134   ret=ui0=2;
135   celt_assert(_n>=2);
136   j=1; do {
137     ui1=_ui[j]+_ui[j-1]+ui0;
138     _ui[j-1]=ui0;
139     ui0=ui1;
140     ret+=ui0;
141   } while (++j<_n);
142   _ui[j-1]=ui0;
143   return ret;
144 }
145
146 /*Returns the number of ways of choosing _m elements from a set of size _n with
147    replacement when a sign bit is needed for each unique element.
148   On exit, _u will be initialized to column _m of U(n,m).*/
149 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
150   int k;
151   CELT_MEMSET(_u,0,_n);
152   if(_m<=0)return 1;
153   if(_n<=0)return 0;
154   for(k=1;k<_m;k++)unext32(_u,_n,2);
155   return ncwrs_unext32(_n,_u);
156 }
157
158 /*Returns the _i'th combination of _m elements chosen from a set of size _n
159    with associated sign bits.
160   _x: Returns the combination with elements sorted in ascending order.
161   _s: Returns the associated sign bits.
162   _u: Temporary storage already initialized to column _m of U(n,m).
163       Its contents will be overwritten.*/
164 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
165   int j;
166   int k;
167   for(k=j=0;k<_m;k++){
168     celt_uint32_t p;
169     celt_uint32_t t;
170     p=_u[_n-j-1];
171     if(k>0){
172       t=p>>1;
173       if(t<=_i||_s[k-1])_i+=t;
174     }
175     while(p<=_i){
176       _i-=p;
177       j++;
178       p=_u[_n-j-1];
179     }
180     t=p>>1;
181     _s[k]=_i>=t;
182     _x[k]=j;
183     if(_s[k])_i-=t;
184     uprev32(_u,_n-j,2);
185   }
186 }
187
188 /*Returns the index of the given combination of _m elements chosen from a set
189    of size _n with associated sign bits.
190   _x: The combination with elements sorted in ascending order.
191   _s: The associated sign bits.
192   _u: Temporary storage already initialized to column _m of U(n,m).
193       Its contents will be overwritten.*/
194 celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
195  celt_uint32_t *_u){
196   celt_uint32_t i;
197   int           j;
198   int           k;
199   i=0;
200   for(k=j=0;k<_m;k++){
201     celt_uint32_t p;
202     p=_u[_n-j-1];
203     if(k>0)p>>=1;
204     while(j<_x[k]){
205       i+=p;
206       j++;
207       p=_u[_n-j-1];
208     }
209     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
210     uprev32(_u,_n-j,2);
211   }
212   return i;
213 }
214
215 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
216    a pulse vector _y of length _n.
217   _y: Returns the vector of pulses.
218   _x: The combination with elements sorted in ascending order. _x[_m] = -1
219   _s: The associated sign bits.*/
220 void comb2pulse(int _n,int _m,int * restrict _y,const int *_x,const int *_s){
221   int k;
222   const int signs[2]={1,-1};
223   CELT_MEMSET(_y, 0, _n);
224   k=0; do {
225     _y[_x[k]]+=signs[_s[k]];
226   } while (++k<_m);
227 }
228
229 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
230    pulses with associated sign bits _s.
231   _x: Returns the combination with elements sorted in ascending order.
232   _s: Returns the associated sign bits.
233   _y: The vector of pulses, whose sum of absolute values must be _m.*/
234 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
235   int j;
236   int k;
237   for(k=j=0;j<_n;j++){
238     if(_y[j]){
239       int n;
240       int s;
241       n=abs(_y[j]);
242       s=_y[j]<0;
243       do {
244         _x[k]=j;
245         _s[k]=s;
246         k++;
247       } while (--n>0);
248     }
249   }
250 }
251
252 static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
253  ec_enc *_enc){
254   VARDECL(celt_uint32_t,u);
255   celt_uint32_t nc;
256   celt_uint32_t i;
257   SAVE_STACK;
258   ALLOC(u,_n,celt_uint32_t);
259   nc=ncwrs_u32(_n,_m,u);
260   i=icwrs32(_n,_m,_x,_s,u);
261   ec_enc_uint(_enc,i,nc);
262   RESTORE_STACK;
263 }
264
265 int get_required_bits(int N, int K, int frac)
266 {
267    int nbits = 0;
268    if(fits_in32(N,K))
269    {
270       VARDECL(celt_uint32_t,u);
271       SAVE_STACK;
272       ALLOC(u,N,celt_uint32_t);
273       nbits = log2_frac(ncwrs_u32(N,K,u), frac);
274       RESTORE_STACK;
275    } else {
276       nbits = log2_frac(N, frac);
277       nbits += get_required_bits(N/2+1, (K+1)/2, frac);
278       nbits += get_required_bits(N/2+1, K/2, frac);
279    }
280    return nbits;
281 }
282
283
284 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
285 {
286    VARDECL(int, comb);
287    VARDECL(int, signs);
288    SAVE_STACK;
289
290    ALLOC(comb, K, int);
291    ALLOC(signs, K, int);
292
293    pulse2comb(N, K, comb, signs, _y);
294    if (K==0) {
295    } else if (N==1)
296    {
297       ec_enc_bits(enc, _y[0]<0, 1);
298    } else if(fits_in32(N,K))
299    {
300       encode_comb32(N, K, comb, signs, enc);
301    } else {
302      int i;
303      int count=0;
304      int split;
305      split = (N+1)/2;
306      for (i=0;i<split;i++)
307         count += abs(_y[i]);
308      ec_enc_uint(enc,count,K+1);
309      encode_pulses(_y, split, count, enc);
310      encode_pulses(_y+split, N-split, K-count, enc);
311    }
312    RESTORE_STACK;
313 }
314
315 static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
316   VARDECL(celt_uint32_t,u);
317   SAVE_STACK;
318   ALLOC(u,_n,celt_uint32_t);
319   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
320   RESTORE_STACK;
321 }
322
323 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
324 {
325    VARDECL(int, comb);
326    VARDECL(int, signs);
327    SAVE_STACK;
328
329    ALLOC(comb, K, int);
330    ALLOC(signs, K, int);
331    if (K==0) {
332       int i;
333       for (i=0;i<N;i++)
334          _y[i] = 0;
335    } else if (N==1)
336    {
337       int s = ec_dec_bits(dec, 1);
338       if (s==0)
339          _y[0] = K;
340       else
341          _y[0] = -K;
342    } else if(fits_in32(N,K))
343    {
344       decode_comb32(N, K, comb, signs, dec);
345       comb2pulse(N, K, _y, comb, signs);
346    } else {
347      int split;
348      int count = ec_dec_uint(dec,K+1);
349      split = (N+1)/2;
350      decode_pulses(_y, split, count, dec);
351      decode_pulses(_y+split, N-split, K-count, dec);
352    }
353    RESTORE_STACK;
354 }