optimisation: another bunch of simplifications to the "simple case" of the
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50
51 /*Computes the next row/column of any recurrence that obeys the relation
52    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
53   _ui0 is the base case for the new row/column.*/
54 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
55   celt_uint32_t ui1;
56   int           j;
57   for(j=1;j<_len;j++){
58     ui1=_ui[j]+_ui[j-1]+_ui0;
59     _ui[j-1]=_ui0;
60     _ui0=ui1;
61   }
62   _ui[j-1]=_ui0;
63 }
64
65 static inline void unext64(celt_uint64_t *_ui,int _len,celt_uint64_t _ui0){
66   celt_uint64_t ui1;
67   int           j;
68   for(j=1;j<_len;j++){
69     ui1=_ui[j]+_ui[j-1]+_ui0;
70     _ui[j-1]=_ui0;
71     _ui0=ui1;
72   }
73   _ui[j-1]=_ui0;
74 }
75
76 /*Computes the previous row/column of any recurrence that obeys the relation
77    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
78   _ui0 is the base case for the new row/column.*/
79 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
80   celt_uint32_t ui1;
81   int           j;
82   for(j=1;j<_n;j++){
83     ui1=_ui[j]-_ui[j-1]-_ui0;
84     _ui[j-1]=_ui0;
85     _ui0=ui1;
86   }
87   _ui[j-1]=_ui0;
88 }
89
90 static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){
91   celt_uint64_t ui1;
92   int           j;
93   for(j=1;j<_n;j++){
94     ui1=_ui[j]-_ui[j-1]-_ui0;
95     _ui[j-1]=_ui0;
96     _ui0=ui1;
97   }
98   _ui[j-1]=_ui0;
99 }
100
101 /*Returns the number of ways of choosing _m elements from a set of size _n with
102    replacement when a sign bit is needed for each unique element.
103   On input, _u should be initialized to column (_m-1) of U(n,m).
104   On exit, _u will be initialized to column _m of U(n,m).*/
105 celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
106   celt_uint32_t ret;
107   celt_uint32_t ui0;
108   celt_uint32_t ui1;
109   int           j;
110   ret=ui0=2;
111   for(j=1;j<_n;j++){
112     ui1=_ui[j]+_ui[j-1]+ui0;
113     _ui[j-1]=ui0;
114     ui0=ui1;
115     ret+=ui0;
116   }
117   _ui[j-1]=ui0;
118   return ret;
119 }
120
121 celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){
122   celt_uint64_t ret;
123   celt_uint64_t ui0;
124   celt_uint64_t ui1;
125   int           j;
126   ret=ui0=2;
127   for(j=1;j<_n;j++){
128     ui1=_ui[j]+_ui[j-1]+ui0;
129     _ui[j-1]=ui0;
130     ui0=ui1;
131     ret+=ui0;
132   }
133   _ui[j-1]=ui0;
134   return ret;
135 }
136
137 /*Returns the number of ways of choosing _m elements from a set of size _n with
138    replacement when a sign bit is needed for each unique element.
139   On exit, _u will be initialized to column _m of U(n,m).*/
140 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
141   int k;
142   CELT_MEMSET(_u,0,_n);
143   if(_m<=0)return 1;
144   if(_n<=0)return 0;
145   for(k=1;k<_m;k++)unext32(_u,_n,2);
146   return ncwrs_unext32(_n,_u);
147 }
148
149 celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){
150   int k;
151   CELT_MEMSET(_u,0,_n);
152   if(_m<=0)return 1;
153   if(_n<=0)return 0;
154   for(k=1;k<_m;k++)unext64(_u,_n,2);
155   return ncwrs_unext64(_n,_u);
156 }
157
158 /*Returns the _i'th combination of _m elements chosen from a set of size _n
159    with associated sign bits.
160   _x: Returns the combination with elements sorted in ascending order.
161   _s: Returns the associated sign bits.
162   _u: Temporary storage already initialized to column _m of U(n,m).
163       Its contents will be overwritten.*/
164 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
165   int j;
166   int k;
167   for(k=j=0;k<_m;k++){
168     celt_uint32_t p;
169     celt_uint32_t t;
170     p=_u[_n-j-1];
171     if(k>0){
172       t=p>>1;
173       if(t<=_i||_s[k-1])_i+=t;
174     }
175     while(p<=_i){
176       _i-=p;
177       j++;
178       p=_u[_n-j-1];
179     }
180     t=p>>1;
181     _s[k]=_i>=t;
182     _x[k]=j;
183     if(_s[k])_i-=t;
184     uprev32(_u,_n-j,2);
185   }
186 }
187
188 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,celt_uint64_t *_u){
189   int j;
190   int k;
191   for(k=j=0;k<_m;k++){
192     celt_uint64_t p;
193     celt_uint64_t t;
194     p=_u[_n-j-1];
195     if(k>0){
196       t=p>>1;
197       if(t<=_i||_s[k-1])_i+=t;
198     }
199     while(p<=_i){
200       _i-=p;
201       j++;
202       p=_u[_n-j-1];
203     }
204     t=p>>1;
205     _s[k]=_i>=t;
206     _x[k]=j;
207     if(_s[k])_i-=t;
208     uprev64(_u,_n-j,2);
209   }
210 }
211
212 /*Returns the index of the given combination of _m elements chosen from a set
213    of size _n with associated sign bits.
214   _x: The combination with elements sorted in ascending order.
215   _s: The associated sign bits.
216   _u: Temporary storage already initialized to column _m of U(n,m).
217       Its contents will be overwritten.*/
218 celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
219  celt_uint32_t *_u){
220   celt_uint32_t i;
221   int           j;
222   int           k;
223   i=0;
224   for(k=j=0;k<_m;k++){
225     celt_uint32_t p;
226     p=_u[_n-j-1];
227     if(k>0)p>>=1;
228     while(j<_x[k]){
229       i+=p;
230       j++;
231       p=_u[_n-j-1];
232     }
233     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
234     uprev32(_u,_n-j,2);
235   }
236   return i;
237 }
238
239 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
240  celt_uint64_t *_u){
241   celt_uint64_t i;
242   int           j;
243   int           k;
244   i=0;
245   for(k=j=0;k<_m;k++){
246     celt_uint64_t p;
247     p=_u[_n-j-1];
248     if(k>0)p>>=1;
249     while(j<_x[k]){
250       i+=p;
251       j++;
252       p=_u[_n-j-1];
253     }
254     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
255     uprev64(_u,_n-j,2);
256   }
257   return i;
258 }
259
260 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
261    a pulse vector _y of length _n.
262   _y: Returns the vector of pulses.
263   _x: The combination with elements sorted in ascending order.
264   _s: The associated sign bits.*/
265 void comb2pulse(int _n,int _m,int *_y,const int *_x,const int *_s){
266   int j;
267   int k;
268   int n;
269   for(k=j=0;k<_m;k+=n){
270     for(n=1;k+n<_m&&_x[k+n]==_x[k];n++);
271     while(j<_x[k])_y[j++]=0;
272     _y[j++]=_s[k]?-n:n;
273   }
274   while(j<_n)_y[j++]=0;
275 }
276
277 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
278    pulses with associated sign bits _s.
279   _x: Returns the combination with elements sorted in ascending order.
280   _s: Returns the associated sign bits.
281   _y: The vector of pulses, whose sum of absolute values must be _m.*/
282 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
283   int j;
284   int k;
285   for(k=j=0;j<_n;j++){
286     if(_y[j]){
287       int n;
288       int s;
289       n=abs(_y[j]);
290       s=_y[j]<0;
291       for(;n-->0;k++){
292         _x[k]=j;
293         _s[k]=s;
294       }
295     }
296   }
297 }
298
299 static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
300  ec_enc *_enc){
301   VARDECL(celt_uint32_t,u);
302   celt_uint32_t nc;
303   celt_uint32_t i;
304   SAVE_STACK;
305   ALLOC(u,_n,celt_uint32_t);
306   nc=ncwrs_u32(_n,_m,u);
307   i=icwrs32(_n,_m,_x,_s,u);
308   ec_enc_uint(_enc,i,nc);
309   RESTORE_STACK;
310 }
311
312 static inline void encode_comb64(int _n,int _m,const int *_x,const int *_s,
313  ec_enc *_enc){
314   VARDECL(celt_uint64_t,u);
315   celt_uint64_t nc;
316   celt_uint64_t i;
317   SAVE_STACK;
318   ALLOC(u,_n,celt_uint64_t);
319   nc=ncwrs_u64(_n,_m,u);
320   i=icwrs64(_n,_m,_x,_s,u);
321   ec_enc_uint64(_enc,i,nc);
322   RESTORE_STACK;
323 }
324
325 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
326 {
327    VARDECL(int, comb);
328    VARDECL(int, signs);
329    SAVE_STACK;
330
331    ALLOC(comb, K, int);
332    ALLOC(signs, K, int);
333
334    pulse2comb(N, K, comb, signs, _y);
335    /* Simple heuristic to figure out whether it fits in 32 bits */
336    if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
337    {
338       encode_comb32(N, K, comb, signs, enc);
339    } else {
340       encode_comb64(N, K, comb, signs, enc);
341    }
342    RESTORE_STACK;
343 }
344
345 static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
346   VARDECL(celt_uint32_t,u);
347   SAVE_STACK;
348   ALLOC(u,_n,celt_uint32_t);
349   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
350   RESTORE_STACK;
351 }
352
353 static inline void decode_comb64(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
354   VARDECL(celt_uint64_t,u);
355   SAVE_STACK;
356   ALLOC(u,_n,celt_uint64_t);
357   cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_x,_s,u);
358   RESTORE_STACK;
359 }
360
361 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
362 {
363    VARDECL(int, comb);
364    VARDECL(int, signs);
365    SAVE_STACK;
366
367    ALLOC(comb, K, int);
368    ALLOC(signs, K, int);
369    /* Simple heuristic to figure out whether it fits in 32 bits */
370    if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
371    {
372       decode_comb32(N, K, comb, signs, dec);
373    } else {
374       decode_comb64(N, K, comb, signs, dec);
375    }
376    comb2pulse(N, K, _y, comb, signs);
377    RESTORE_STACK;
378 }