Fixed rsqrt testcase for float
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include <stdlib.h>
46 #include <string.h>
47 #include "cwrs.h"
48 #include "mathops.h"
49
50 /*Computes the next row/column of any recurrence that obeys the relation
51    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
52   _ui0 is the base case for the new row/column.*/
53 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
54   celt_uint32_t ui1;
55   int           j;
56   for(j=1;j<_len;j++){
57     ui1=_ui[j]+_ui[j-1]+_ui0;
58     _ui[j-1]=_ui0;
59     _ui0=ui1;
60   }
61   _ui[j-1]=_ui0;
62 }
63
64 static inline void unext64(celt_uint64_t *_ui,int _len,celt_uint64_t _ui0){
65   celt_uint64_t ui1;
66   int           j;
67   for(j=1;j<_len;j++){
68     ui1=_ui[j]+_ui[j-1]+_ui0;
69     _ui[j-1]=_ui0;
70     _ui0=ui1;
71   }
72   _ui[j-1]=_ui0;
73 }
74
75 /*Computes the previous row/column of any recurrence that obeys the relation
76    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
77   _ui0 is the base case for the new row/column.*/
78 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
79   celt_uint32_t ui1;
80   int           j;
81   for(j=1;j<_n;j++){
82     ui1=_ui[j]-_ui[j-1]-_ui0;
83     _ui[j-1]=_ui0;
84     _ui0=ui1;
85   }
86   _ui[j-1]=_ui0;
87 }
88
89 static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){
90   celt_uint64_t ui1;
91   int           j;
92   for(j=1;j<_n;j++){
93     ui1=_ui[j]-_ui[j-1]-_ui0;
94     _ui[j-1]=_ui0;
95     _ui0=ui1;
96   }
97   _ui[j-1]=_ui0;
98 }
99
100 /*Returns the number of ways of choosing _m elements from a set of size _n with
101    replacement when a sign bit is needed for each unique element.
102   On input, _u should be initialized to column (_m-1) of U(n,m).
103   On exit, _u will be initialized to column _m of U(n,m).*/
104 celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
105   celt_uint32_t ret;
106   celt_uint32_t ui0;
107   celt_uint32_t ui1;
108   int           j;
109   ret=ui0=2;
110   for(j=1;j<_n;j++){
111     ui1=_ui[j]+_ui[j-1]+ui0;
112     _ui[j-1]=ui0;
113     ui0=ui1;
114     ret+=ui0;
115   }
116   _ui[j-1]=ui0;
117   return ret;
118 }
119
120 celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){
121   celt_uint64_t ret;
122   celt_uint64_t ui0;
123   celt_uint64_t ui1;
124   int           j;
125   ret=ui0=2;
126   for(j=1;j<_n;j++){
127     ui1=_ui[j]+_ui[j-1]+ui0;
128     _ui[j-1]=ui0;
129     ui0=ui1;
130     ret+=ui0;
131   }
132   _ui[j-1]=ui0;
133   return ret;
134 }
135
136 /*Returns the number of ways of choosing _m elements from a set of size _n with
137    replacement when a sign bit is needed for each unique element.
138   On exit, _u will be initialized to column _m of U(n,m).*/
139 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
140   int k;
141   memset(_u,0,_n*sizeof(*_u));
142   if(_m<=0)return 1;
143   if(_n<=0)return 0;
144   for(k=1;k<_m;k++)unext32(_u,_n,2);
145   return ncwrs_unext32(_n,_u);
146 }
147
148 celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){
149   int k;
150   memset(_u,0,_n*sizeof(*_u));
151   if(_m<=0)return 1;
152   if(_n<=0)return 0;
153   for(k=1;k<_m;k++)unext64(_u,_n,2);
154   return ncwrs_unext64(_n,_u);
155 }
156
157 /*Returns the _i'th combination of _m elements chosen from a set of size _n
158    with associated sign bits.
159   _x: Returns the combination with elements sorted in ascending order.
160   _s: Returns the associated sign bits.
161   _u: Temporary storage already initialized to column _m of U(n,m).
162       Its contents will be overwritten.*/
163 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
164   int j;
165   int k;
166   for(k=j=0;k<_m;k++){
167     celt_uint32_t p;
168     celt_uint32_t t;
169     p=_u[_n-j-1];
170     if(k>0){
171       t=p>>1;
172       if(t<=_i||_s[k-1])_i+=t;
173     }
174     while(p<=_i){
175       _i-=p;
176       j++;
177       p=_u[_n-j-1];
178     }
179     t=p>>1;
180     _s[k]=_i>=t;
181     _x[k]=j;
182     if(_s[k])_i-=t;
183     uprev32(_u,_n-j,2);
184   }
185 }
186
187 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,celt_uint64_t *_u){
188   int j;
189   int k;
190   for(k=j=0;k<_m;k++){
191     celt_uint64_t p;
192     celt_uint64_t t;
193     p=_u[_n-j-1];
194     if(k>0){
195       t=p>>1;
196       if(t<=_i||_s[k-1])_i+=t;
197     }
198     while(p<=_i){
199       _i-=p;
200       j++;
201       p=_u[_n-j-1];
202     }
203     t=p>>1;
204     _s[k]=_i>=t;
205     _x[k]=j;
206     if(_s[k])_i-=t;
207     uprev64(_u,_n-j,2);
208   }
209 }
210
211 /*Returns the index of the given combination of _m elements chosen from a set
212    of size _n with associated sign bits.
213   _x: The combination with elements sorted in ascending order.
214   _s: The associated sign bits.
215   _u: Temporary storage already initialized to column _m of U(n,m).
216       Its contents will be overwritten.*/
217 celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
218  celt_uint32_t *_u){
219   celt_uint32_t i;
220   int           j;
221   int           k;
222   i=0;
223   for(k=j=0;k<_m;k++){
224     celt_uint32_t p;
225     p=_u[_n-j-1];
226     if(k>0)p>>=1;
227     while(j<_x[k]){
228       i+=p;
229       j++;
230       p=_u[_n-j-1];
231     }
232     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
233     uprev32(_u,_n-j,2);
234   }
235   return i;
236 }
237
238 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
239  celt_uint64_t *_u){
240   celt_uint64_t i;
241   int           j;
242   int           k;
243   i=0;
244   for(k=j=0;k<_m;k++){
245     celt_uint64_t p;
246     p=_u[_n-j-1];
247     if(k>0)p>>=1;
248     while(j<_x[k]){
249       i+=p;
250       j++;
251       p=_u[_n-j-1];
252     }
253     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
254     uprev64(_u,_n-j,2);
255   }
256   return i;
257 }
258
259 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
260    a pulse vector _y of length _n.
261   _y: Returns the vector of pulses.
262   _x: The combination with elements sorted in ascending order.
263   _s: The associated sign bits.*/
264 void comb2pulse(int _n,int _m,int *_y,const int *_x,const int *_s){
265   int j;
266   int k;
267   int n;
268   for(k=j=0;k<_m;k+=n){
269     for(n=1;k+n<_m&&_x[k+n]==_x[k];n++);
270     while(j<_x[k])_y[j++]=0;
271     _y[j++]=_s[k]?-n:n;
272   }
273   while(j<_n)_y[j++]=0;
274 }
275
276 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
277    pulses with associated sign bits _s.
278   _x: Returns the combination with elements sorted in ascending order.
279   _s: Returns the associated sign bits.
280   _y: The vector of pulses, whose sum of absolute values must be _m.*/
281 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
282   int j;
283   int k;
284   for(k=j=0;j<_n;j++){
285     if(_y[j]){
286       int n;
287       int s;
288       n=abs(_y[j]);
289       s=_y[j]<0;
290       for(;n-->0;k++){
291         _x[k]=j;
292         _s[k]=s;
293       }
294     }
295   }
296 }
297
298 static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
299  ec_enc *_enc){
300   VARDECL(celt_uint32_t,u);
301   celt_uint32_t nc;
302   celt_uint32_t i;
303   SAVE_STACK;
304   ALLOC(u,_n,celt_uint32_t);
305   nc=ncwrs_u32(_n,_m,u);
306   i=icwrs32(_n,_m,_x,_s,u);
307   ec_enc_uint(_enc,i,nc);
308   RESTORE_STACK;
309 }
310
311 static inline void encode_comb64(int _n,int _m,const int *_x,const int *_s,
312  ec_enc *_enc){
313   VARDECL(celt_uint64_t,u);
314   celt_uint64_t nc;
315   celt_uint64_t i;
316   SAVE_STACK;
317   ALLOC(u,_n,celt_uint64_t);
318   nc=ncwrs_u64(_n,_m,u);
319   i=icwrs64(_n,_m,_x,_s,u);
320   ec_enc_uint64(_enc,i,nc);
321   RESTORE_STACK;
322 }
323
324 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
325 {
326    VARDECL(int, comb);
327    VARDECL(int, signs);
328    SAVE_STACK;
329
330    ALLOC(comb, K, int);
331    ALLOC(signs, K, int);
332
333    pulse2comb(N, K, comb, signs, _y);
334    /* Simple heuristic to figure out whether it fits in 32 bits */
335    if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
336    {
337       encode_comb32(N, K, comb, signs, enc);
338    } else {
339       encode_comb64(N, K, comb, signs, enc);
340    }
341    RESTORE_STACK;
342 }
343
344 static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
345   VARDECL(celt_uint32_t,u);
346   SAVE_STACK;
347   ALLOC(u,_n,celt_uint32_t);
348   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
349   RESTORE_STACK;
350 }
351
352 static inline void decode_comb64(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
353   VARDECL(celt_uint64_t,u);
354   SAVE_STACK;
355   ALLOC(u,_n,celt_uint64_t);
356   cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_x,_s,u);
357   RESTORE_STACK;
358 }
359
360 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
361 {
362    VARDECL(int, comb);
363    VARDECL(int, signs);
364    SAVE_STACK;
365
366    ALLOC(comb, K, int);
367    ALLOC(signs, K, int);
368    /* Simple heuristic to figure out whether it fits in 32 bits */
369    if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
370    {
371       decode_comb32(N, K, comb, signs, dec);
372    } else {
373       decode_comb64(N, K, comb, signs, dec);
374    }
375    comb2pulse(N, K, _y, comb, signs);
376    RESTORE_STACK;
377 }