Fixed a bunch of warnings
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors. For more details, see:
33    http://people.xiph.org/~tterribe/notes/cwrs.html
34 */
35 #include <stdlib.h>
36 #include "cwrs.h"
37
38 /* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
39    compute ncwrs() for m+1, for all n. Could also be used when m and n are
40    swapped just by changing nc */
41 static void next_ncwrs32(celt_uint32_t *nc, int len, int nc0)
42 {
43    int i;
44    celt_uint32_t mem;
45    
46    mem = nc[0];
47    nc[0] = nc0;
48    for (i=1;i<len;i++)
49    {
50       celt_uint32_t tmp = nc[i]+nc[i-1]+mem;
51       mem = nc[i];
52       nc[i] = tmp;
53    }
54 }
55
56 /* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
57    compute ncwrs() for m-1, for all n. Could also be used when m and n are
58    swapped just by changing nc */
59 static void prev_ncwrs32(celt_uint32_t *nc, int len, int nc0)
60 {
61    int i;
62    celt_uint32_t mem;
63    
64    mem = nc[0];
65    nc[0] = nc0;
66    for (i=1;i<len;i++)
67    {
68       celt_uint32_t tmp = nc[i]-nc[i-1]-mem;
69       mem = nc[i];
70       nc[i] = tmp;
71    }
72 }
73
74 static void next_ncwrs64(celt_uint64_t *nc, int len, int nc0)
75 {
76    int i;
77    celt_uint64_t mem;
78    
79    mem = nc[0];
80    nc[0] = nc0;
81    for (i=1;i<len;i++)
82    {
83       celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
84       mem = nc[i];
85       nc[i] = tmp;
86    }
87 }
88
89 static void prev_ncwrs64(celt_uint64_t *nc, int len, int nc0)
90 {
91    int i;
92    celt_uint64_t mem;
93    
94    mem = nc[0];
95    nc[0] = nc0;
96    for (i=1;i<len;i++)
97    {
98       celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
99       mem = nc[i];
100       nc[i] = tmp;
101    }
102 }
103
104 /*Returns the numer of ways of choosing _m elements from a set of size _n with
105    replacement when a sign bit is needed for each unique element.*/
106 celt_uint32_t ncwrs(int _n,int _m)
107 {
108    int i;
109    celt_uint32_t nc[_n+1];
110    for (i=0;i<_n+1;i++)
111       nc[i] = 1;
112    for (i=0;i<_m;i++)
113       next_ncwrs32(nc, _n+1, 0);
114    return nc[_n];
115 }
116
117 /*Returns the numer of ways of choosing _m elements from a set of size _n with
118    replacement when a sign bit is needed for each unique element.*/
119 celt_uint64_t ncwrs64(int _n,int _m)
120 {
121    int i;
122    celt_uint64_t nc[_n+1];
123    for (i=0;i<_n+1;i++)
124       nc[i] = 1;
125    for (i=0;i<_m;i++)
126       next_ncwrs64(nc, _n+1, 0);
127    return nc[_n];
128 }
129
130
131 /*Returns the _i'th combination of _m elements chosen from a set of size _n
132    with associated sign bits.
133   _x:      Returns the combination with elements sorted in ascending order.
134   _s:      Returns the associated sign bits.*/
135 void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
136   int j;
137   int k;
138   celt_uint32_t nc[_n+1];
139   for (j=0;j<_n+1;j++)
140     nc[j] = 1;
141   for (k=0;k<_m-1;k++)
142     next_ncwrs32(nc, _n+1, 0);
143   for(k=j=0;k<_m;k++){
144     celt_uint32_t pn, p, t;
145     /*p=ncwrs(_n-j,_m-k-1);
146     pn=ncwrs(_n-j-1,_m-k-1);*/
147     p=nc[_n-j];
148     pn=nc[_n-j-1];
149     p+=pn;
150     if(k>0){
151       t=p>>1;
152       if(t<=_i||_s[k-1])_i+=t;
153     }
154     while(p<=_i){
155       _i-=p;
156       j++;
157       p=pn;
158       /*pn=ncwrs(_n-j-1,_m-k-1);*/
159       pn=nc[_n-j-1];
160       p+=pn;
161     }
162     t=p>>1;
163     _s[k]=_i>=t;
164     _x[k]=j;
165     if(_s[k])_i-=t;
166     if (k<_m-2)
167       prev_ncwrs32(nc, _n+1, 0);
168     else
169       prev_ncwrs32(nc, _n+1, 1);
170   }
171 }
172
173 /*Returns the index of the given combination of _m elements chosen from a set
174    of size _n with associated sign bits.
175   _x:      The combination with elements sorted in ascending order.
176   _s:      The associated sign bits.*/
177 celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound){
178   celt_uint32_t i;
179   int      j;
180   int      k;
181   celt_uint32_t nc[_n+1];
182   for (j=0;j<_n+1;j++)
183     nc[j] = 1;
184   for (k=0;k<_m;k++)
185     next_ncwrs32(nc, _n+1, 0);
186   if (bound)
187     *bound = nc[_n];
188   i=0;
189   for(k=j=0;k<_m;k++){
190     celt_uint32_t pn;
191     celt_uint32_t p;
192     if (k<_m-1)
193       prev_ncwrs32(nc, _n+1, 0);
194     else
195       prev_ncwrs32(nc, _n+1, 1);
196     /*p=ncwrs(_n-j,_m-k-1);
197     pn=ncwrs(_n-j-1,_m-k-1);*/
198     p=nc[_n-j];
199     pn=nc[_n-j-1];
200     p+=pn;
201     if(k>0)p>>=1;
202     while(j<_x[k]){
203       i+=p;
204       j++;
205       p=pn;
206       /*pn=ncwrs(_n-j-1,_m-k-1);*/
207       pn=nc[_n-j-1];
208       p+=pn;
209     }
210     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
211   }
212   return i;
213 }
214
215 /*Returns the _i'th combination of _m elements chosen from a set of size _n
216    with associated sign bits.
217   _x:      Returns the combination with elements sorted in ascending order.
218   _s:      Returns the associated sign bits.*/
219 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
220   int j;
221   int k;
222   celt_uint64_t nc[_n+1];
223   for (j=0;j<_n+1;j++)
224     nc[j] = 1;
225   for (k=0;k<_m-1;k++)
226     next_ncwrs64(nc, _n+1, 0);
227   for(k=j=0;k<_m;k++){
228     celt_uint64_t pn, p, t;
229     /*p=ncwrs64(_n-j,_m-k-1);
230     pn=ncwrs64(_n-j-1,_m-k-1);*/
231     p=nc[_n-j];
232     pn=nc[_n-j-1];
233     p+=pn;
234     if(k>0){
235       t=p>>1;
236       if(t<=_i||_s[k-1])_i+=t;
237     }
238     while(p<=_i){
239       _i-=p;
240       j++;
241       p=pn;
242       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
243       pn=nc[_n-j-1];
244       p+=pn;
245     }
246     t=p>>1;
247     _s[k]=_i>=t;
248     _x[k]=j;
249     if(_s[k])_i-=t;
250     if (k<_m-2)
251       prev_ncwrs64(nc, _n+1, 0);
252     else
253       prev_ncwrs64(nc, _n+1, 1);
254   }
255 }
256
257 /*Returns the index of the given combination of _m elements chosen from a set
258    of size _n with associated sign bits.
259   _x:      The combination with elements sorted in ascending order.
260   _s:      The associated sign bits.*/
261 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){
262   celt_uint64_t i;
263   int           j;
264   int           k;
265   celt_uint64_t nc[_n+1];
266   for (j=0;j<_n+1;j++)
267     nc[j] = 1;
268   for (k=0;k<_m;k++)
269     next_ncwrs64(nc, _n+1, 0);
270   if (bound)
271      *bound = nc[_n];
272   i=0;
273   for(k=j=0;k<_m;k++){
274     celt_uint64_t pn;
275     celt_uint64_t p;
276     if (k<_m-1)
277       prev_ncwrs64(nc, _n+1, 0);
278     else
279       prev_ncwrs64(nc, _n+1, 1);
280     /*p=ncwrs64(_n-j,_m-k-1);
281     pn=ncwrs64(_n-j-1,_m-k-1);*/
282     p=nc[_n-j];
283     pn=nc[_n-j-1];
284     p+=pn;
285     if(k>0)p>>=1;
286     while(j<_x[k]){
287       i+=p;
288       j++;
289       p=pn;
290       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
291       pn=nc[_n-j-1];
292       p+=pn;
293     }
294     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
295   }
296   return i;
297 }
298
299 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
300    a pulse vector _y of length _n.
301   _y: Returns the vector of pulses.
302   _x: The combination with elements sorted in ascending order.
303   _s: The associated sign bits.*/
304 void comb2pulse(int _n,int _m,int *_y,const int *_x,const int *_s){
305   int j;
306   int k;
307   int n;
308   for(k=j=0;k<_m;k+=n){
309     for(n=1;k+n<_m&&_x[k+n]==_x[k];n++);
310     while(j<_x[k])_y[j++]=0;
311     _y[j++]=_s[k]?-n:n;
312   }
313   while(j<_n)_y[j++]=0;
314 }
315
316 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
317    pulses with associated sign bits _s.
318   _x: Returns the combination with elements sorted in ascending order.
319   _s: Returns the associated sign bits.
320   _y: The vector of pulses, whose sum of absolute values must be _m.*/
321 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
322   int j;
323   int k;
324   for(k=j=0;j<_n;j++){
325     if(_y[j]){
326       int n;
327       int s;
328       n=abs(_y[j]);
329       s=_y[j]<0;
330       for(;n-->0;k++){
331         _x[k]=j;
332         _s[k]=s;
333       }
334     }
335   }
336 }
337
338 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
339 {
340    int comb[K];
341    int signs[K];
342    pulse2comb(N, K, comb, signs, _y);
343    /* Go with 32-bit path if we're sure we can */
344    if (N<=13 && K<=13)
345    {
346       celt_uint32_t bound, id;
347       id = icwrs(N, K, comb, signs, &bound);
348       ec_enc_uint(enc,id,bound);
349    } else {
350       celt_uint64_t bound, id;
351       id = icwrs64(N, K, comb, signs, &bound);
352       ec_enc_uint64(enc,id,bound);
353    }
354 }
355
356 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
357 {
358    int comb[K];
359    int signs[K];   
360    if (N<=13 && K<=13)
361    {
362       cwrsi(N, K, ec_dec_uint(dec, ncwrs(N, K)), comb, signs);
363       comb2pulse(N, K, _y, comb, signs);
364    } else {
365       cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs);
366       comb2pulse(N, K, _y, comb, signs);
367    }
368 }
369