optimisation: Removed a bunch of conditional branches from comb2pulse()
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors.
33    These are based on the function
34      U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
35      U(n,1) = U(1,m) = 2,
36     which counts the number of ways of placing m pulses in n dimensions, where
37      at least one pulse lies in dimension 0.
38    For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include "config.h"
43 #endif
44
45 #include "os_support.h"
46 #include <stdlib.h>
47 #include <string.h>
48 #include "cwrs.h"
49 #include "mathops.h"
50
51 /*Computes the next row/column of any recurrence that obeys the relation
52    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
53   _ui0 is the base case for the new row/column.*/
54 static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
55   celt_uint32_t ui1;
56   int           j;
57   /* doing a do-while would overrun the array if we had less than 2 samples */
58   j=1; do {
59     ui1=_ui[j]+_ui[j-1]+_ui0;
60     _ui[j-1]=_ui0;
61     _ui0=ui1;
62   } while (++j<_len);
63   _ui[j-1]=_ui0;
64 }
65
66 static inline void unext64(celt_uint64_t *_ui,int _len,celt_uint64_t _ui0){
67   celt_uint64_t ui1;
68   int           j;
69   /* doing a do-while would overrun the array if we had less than 2 samples */
70   j=1; do {
71     ui1=_ui[j]+_ui[j-1]+_ui0;
72     _ui[j-1]=_ui0;
73     _ui0=ui1;
74   } while (++j<_len);
75   _ui[j-1]=_ui0;
76 }
77
78 /*Computes the previous row/column of any recurrence that obeys the relation
79    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
80   _ui0 is the base case for the new row/column.*/
81 static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
82   celt_uint32_t ui1;
83   int           j;
84   /* doing a do-while would overrun the array if we had less than 2 samples */
85   j=1; do {
86     ui1=_ui[j]-_ui[j-1]-_ui0;
87     _ui[j-1]=_ui0;
88     _ui0=ui1;
89   } while (++j<_n);
90   _ui[j-1]=_ui0;
91 }
92
93 static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){
94   celt_uint64_t ui1;
95   int           j;
96   /* doing a do-while would overrun the array if we had less than 2 samples */
97   j=1; do {
98     ui1=_ui[j]-_ui[j-1]-_ui0;
99     _ui[j-1]=_ui0;
100     _ui0=ui1;
101   } while (++j<_n);
102   _ui[j-1]=_ui0;
103 }
104
105 /*Returns the number of ways of choosing _m elements from a set of size _n with
106    replacement when a sign bit is needed for each unique element.
107   On input, _u should be initialized to column (_m-1) of U(n,m).
108   On exit, _u will be initialized to column _m of U(n,m).*/
109 celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
110   celt_uint32_t ret;
111   celt_uint32_t ui0;
112   celt_uint32_t ui1;
113   int           j;
114   ret=ui0=2;
115   celt_assert(_n>=2);
116   j=1; do {
117     ui1=_ui[j]+_ui[j-1]+ui0;
118     _ui[j-1]=ui0;
119     ui0=ui1;
120     ret+=ui0;
121   } while (++j<_n);
122   _ui[j-1]=ui0;
123   return ret;
124 }
125
126 celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){
127   celt_uint64_t ret;
128   celt_uint64_t ui0;
129   celt_uint64_t ui1;
130   int           j;
131   ret=ui0=2;
132   celt_assert(_n>=2);
133   j=1; do {
134     ui1=_ui[j]+_ui[j-1]+ui0;
135     _ui[j-1]=ui0;
136     ui0=ui1;
137     ret+=ui0;
138   } while (++j<_n);
139   _ui[j-1]=ui0;
140   return ret;
141 }
142
143 /*Returns the number of ways of choosing _m elements from a set of size _n with
144    replacement when a sign bit is needed for each unique element.
145   On exit, _u will be initialized to column _m of U(n,m).*/
146 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
147   int k;
148   CELT_MEMSET(_u,0,_n);
149   if(_m<=0)return 1;
150   if(_n<=0)return 0;
151   for(k=1;k<_m;k++)unext32(_u,_n,2);
152   return ncwrs_unext32(_n,_u);
153 }
154
155 celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){
156   int k;
157   CELT_MEMSET(_u,0,_n);
158   if(_m<=0)return 1;
159   if(_n<=0)return 0;
160   for(k=1;k<_m;k++)unext64(_u,_n,2);
161   return ncwrs_unext64(_n,_u);
162 }
163
164 /*Returns the _i'th combination of _m elements chosen from a set of size _n
165    with associated sign bits.
166   _x: Returns the combination with elements sorted in ascending order.
167   _s: Returns the associated sign bits.
168   _u: Temporary storage already initialized to column _m of U(n,m).
169       Its contents will be overwritten.*/
170 void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
171   int j;
172   int k;
173   for(k=j=0;k<_m;k++){
174     celt_uint32_t p;
175     celt_uint32_t t;
176     p=_u[_n-j-1];
177     if(k>0){
178       t=p>>1;
179       if(t<=_i||_s[k-1])_i+=t;
180     }
181     while(p<=_i){
182       _i-=p;
183       j++;
184       p=_u[_n-j-1];
185     }
186     t=p>>1;
187     _s[k]=_i>=t;
188     _x[k]=j;
189     if(_s[k])_i-=t;
190     uprev32(_u,_n-j,2);
191   }
192 }
193
194 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,celt_uint64_t *_u){
195   int j;
196   int k;
197   for(k=j=0;k<_m;k++){
198     celt_uint64_t p;
199     celt_uint64_t t;
200     p=_u[_n-j-1];
201     if(k>0){
202       t=p>>1;
203       if(t<=_i||_s[k-1])_i+=t;
204     }
205     while(p<=_i){
206       _i-=p;
207       j++;
208       p=_u[_n-j-1];
209     }
210     t=p>>1;
211     _s[k]=_i>=t;
212     _x[k]=j;
213     if(_s[k])_i-=t;
214     uprev64(_u,_n-j,2);
215   }
216 }
217
218 /*Returns the index of the given combination of _m elements chosen from a set
219    of size _n with associated sign bits.
220   _x: The combination with elements sorted in ascending order.
221   _s: The associated sign bits.
222   _u: Temporary storage already initialized to column _m of U(n,m).
223       Its contents will be overwritten.*/
224 celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
225  celt_uint32_t *_u){
226   celt_uint32_t i;
227   int           j;
228   int           k;
229   i=0;
230   for(k=j=0;k<_m;k++){
231     celt_uint32_t p;
232     p=_u[_n-j-1];
233     if(k>0)p>>=1;
234     while(j<_x[k]){
235       i+=p;
236       j++;
237       p=_u[_n-j-1];
238     }
239     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
240     uprev32(_u,_n-j,2);
241   }
242   return i;
243 }
244
245 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
246  celt_uint64_t *_u){
247   celt_uint64_t i;
248   int           j;
249   int           k;
250   i=0;
251   for(k=j=0;k<_m;k++){
252     celt_uint64_t p;
253     p=_u[_n-j-1];
254     if(k>0)p>>=1;
255     while(j<_x[k]){
256       i+=p;
257       j++;
258       p=_u[_n-j-1];
259     }
260     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
261     uprev64(_u,_n-j,2);
262   }
263   return i;
264 }
265
266 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
267    a pulse vector _y of length _n.
268   _y: Returns the vector of pulses.
269   _x: The combination with elements sorted in ascending order. _x[_m] = -1
270   _s: The associated sign bits.*/
271 void comb2pulse(int _n,int _m,int * restrict _y,const int *_x,const int *_s){
272   int j;
273   int k;
274   int n;
275   CELT_MEMSET(_y, 0, _n);
276   for(k=j=0;k<_m;k+=n){
277      /* _x[_m] = -1 so we won't overflow */
278     for(n=1;_x[k+n]==_x[k];n++);
279     _y[_x[k]]=_s[k]?-n:n;
280   }
281 }
282
283 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
284    pulses with associated sign bits _s.
285   _x: Returns the combination with elements sorted in ascending order.
286   _s: Returns the associated sign bits.
287   _y: The vector of pulses, whose sum of absolute values must be _m.*/
288 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
289   int j;
290   int k;
291   for(k=j=0;j<_n;j++){
292     if(_y[j]){
293       int n;
294       int s;
295       n=abs(_y[j]);
296       s=_y[j]<0;
297       for(;n-->0;k++){
298         _x[k]=j;
299         _s[k]=s;
300       }
301     }
302   }
303 }
304
305 static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
306  ec_enc *_enc){
307   VARDECL(celt_uint32_t,u);
308   celt_uint32_t nc;
309   celt_uint32_t i;
310   SAVE_STACK;
311   ALLOC(u,_n,celt_uint32_t);
312   nc=ncwrs_u32(_n,_m,u);
313   i=icwrs32(_n,_m,_x,_s,u);
314   ec_enc_uint(_enc,i,nc);
315   RESTORE_STACK;
316 }
317
318 static inline void encode_comb64(int _n,int _m,const int *_x,const int *_s,
319  ec_enc *_enc){
320   VARDECL(celt_uint64_t,u);
321   celt_uint64_t nc;
322   celt_uint64_t i;
323   SAVE_STACK;
324   ALLOC(u,_n,celt_uint64_t);
325   nc=ncwrs_u64(_n,_m,u);
326   i=icwrs64(_n,_m,_x,_s,u);
327   ec_enc_uint64(_enc,i,nc);
328   RESTORE_STACK;
329 }
330
331 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
332 {
333    VARDECL(int, comb);
334    VARDECL(int, signs);
335    SAVE_STACK;
336
337    ALLOC(comb, K, int);
338    ALLOC(signs, K, int);
339
340    pulse2comb(N, K, comb, signs, _y);
341    /* Simple heuristic to figure out whether it fits in 32 bits */
342    if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
343    {
344       encode_comb32(N, K, comb, signs, enc);
345    } else {
346       encode_comb64(N, K, comb, signs, enc);
347    }
348    RESTORE_STACK;
349 }
350
351 static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
352   VARDECL(celt_uint32_t,u);
353   SAVE_STACK;
354   ALLOC(u,_n,celt_uint32_t);
355   cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
356   RESTORE_STACK;
357 }
358
359 static inline void decode_comb64(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
360   VARDECL(celt_uint64_t,u);
361   SAVE_STACK;
362   ALLOC(u,_n,celt_uint64_t);
363   cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_x,_s,u);
364   RESTORE_STACK;
365 }
366
367 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
368 {
369    VARDECL(int, comb);
370    VARDECL(int, signs);
371    SAVE_STACK;
372
373    ALLOC(comb, K+1, int);
374    ALLOC(signs, K, int);
375    /* Simple heuristic to figure out whether it fits in 32 bits */
376    if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
377    {
378       decode_comb32(N, K, comb, signs, dec);
379    } else {
380       decode_comb64(N, K, comb, signs, dec);
381    }
382    comb[K] = -1;
383    comb2pulse(N, K, _y, comb, signs);
384    RESTORE_STACK;
385 }