Now no divisions required in the cwrs code
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007 Timothy B. Terriberry */
2 /*
3    Redistribution and use in source and binary forms, with or without
4    modification, are permitted provided that the following conditions
5    are met:
6
7    - Redistributions of source code must retain the above copyright
8    notice, this list of conditions and the following disclaimer.
9
10    - Redistributions in binary form must reproduce the above copyright
11    notice, this list of conditions and the following disclaimer in the
12    documentation and/or other materials provided with the distribution.
13
14    - Neither the name of the Xiph.org Foundation nor the names of its
15    contributors may be used to endorse or promote products derived from
16    this software without specific prior written permission.
17
18    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
22    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
25    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
26    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
27    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 */
30 #include <stdlib.h>
31 #include "cwrs.h"
32
33 /* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
34    compute ncwrs() for m+1, for all n. Could also be used when m and n are
35    swapped just by changing nc */
36 static celt_uint32_t next_ncwrs32(celt_uint32_t *nc, int len, int nc0)
37 {
38    int i;
39    celt_uint32_t mem;
40    
41    mem = nc[0];
42    nc[0] = nc0;
43    for (i=1;i<len;i++)
44    {
45       celt_uint32_t tmp = nc[i]+nc[i-1]+mem;
46       mem = nc[i];
47       nc[i] = tmp;
48    }
49 }
50
51 /* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
52    compute ncwrs() for m-1, for all n. Could also be used when m and n are
53    swapped just by changing nc */
54 static celt_uint32_t prev_ncwrs32(celt_uint32_t *nc, int len, int nc0)
55 {
56    int i;
57    celt_uint32_t mem;
58    
59    mem = nc[0];
60    nc[0] = nc0;
61    for (i=1;i<len;i++)
62    {
63       celt_uint32_t tmp = nc[i]-nc[i-1]-mem;
64       mem = nc[i];
65       nc[i] = tmp;
66    }
67 }
68
69 static celt_uint64_t next_ncwrs64(celt_uint64_t *nc, int len, int nc0)
70 {
71    int i;
72    celt_uint64_t mem;
73    
74    mem = nc[0];
75    nc[0] = nc0;
76    for (i=1;i<len;i++)
77    {
78       celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
79       mem = nc[i];
80       nc[i] = tmp;
81    }
82 }
83
84 static celt_uint64_t prev_ncwrs64(celt_uint64_t *nc, int len, int nc0)
85 {
86    int i;
87    celt_uint64_t mem;
88    
89    mem = nc[0];
90    nc[0] = nc0;
91    for (i=1;i<len;i++)
92    {
93       celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
94       mem = nc[i];
95       nc[i] = tmp;
96    }
97 }
98
99 /*Returns the numer of ways of choosing _m elements from a set of size _n with
100    replacement when a sign bit is needed for each unique element.*/
101 celt_uint32_t ncwrs(int _n,int _m)
102 {
103    int i;
104    celt_uint32_t ret;
105    celt_uint32_t nc[_n+1];
106    for (i=0;i<_n+1;i++)
107       nc[i] = 1;
108    for (i=0;i<_m;i++)
109       next_ncwrs32(nc, _n+1, 0);
110    return nc[_n];
111 }
112
113 /*Returns the numer of ways of choosing _m elements from a set of size _n with
114    replacement when a sign bit is needed for each unique element.*/
115 celt_uint64_t ncwrs64(int _n,int _m)
116 {
117    int i;
118    celt_uint64_t ret;
119    celt_uint64_t nc[_n+1];
120    for (i=0;i<_n+1;i++)
121       nc[i] = 1;
122    for (i=0;i<_m;i++)
123       next_ncwrs64(nc, _n+1, 0);
124    return nc[_n];
125 }
126
127
128 /*Returns the _i'th combination of _m elements chosen from a set of size _n
129    with associated sign bits.
130   _x:      Returns the combination with elements sorted in ascending order.
131   _s:      Returns the associated sign bits.*/
132 void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
133   int j;
134   int k;
135   celt_uint32_t nc[_n+1];
136   for (j=0;j<_n+1;j++)
137     nc[j] = 1;
138   for (k=0;k<_m-1;k++)
139     next_ncwrs32(nc, _n+1, 0);
140   for(k=j=0;k<_m;k++){
141     celt_uint32_t pn, p, t;
142     /*p=ncwrs(_n-j,_m-k-1);
143     pn=ncwrs(_n-j-1,_m-k-1);*/
144     p=nc[_n-j];
145     pn=nc[_n-j-1];
146     p+=pn;
147     if(k>0){
148       t=p>>1;
149       if(t<=_i||_s[k-1])_i+=t;
150     }
151     while(p<=_i){
152       _i-=p;
153       j++;
154       p=pn;
155       /*pn=ncwrs(_n-j-1,_m-k-1);*/
156       pn=nc[_n-j-1];
157       p+=pn;
158     }
159     t=p>>1;
160     _s[k]=_i>=t;
161     _x[k]=j;
162     if(_s[k])_i-=t;
163     if (k<_m-2)
164       prev_ncwrs32(nc, _n+1, 0);
165     else
166       prev_ncwrs32(nc, _n+1, 1);
167   }
168 }
169
170 /*Returns the index of the given combination of _m elements chosen from a set
171    of size _n with associated sign bits.
172   _x:      The combination with elements sorted in ascending order.
173   _s:      The associated sign bits.*/
174 celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound){
175   celt_uint32_t i;
176   int      j;
177   int      k;
178   celt_uint32_t nc[_n+1];
179   for (j=0;j<_n+1;j++)
180     nc[j] = 1;
181   for (k=0;k<_m;k++)
182     next_ncwrs32(nc, _n+1, 0);
183   if (bound)
184     *bound = nc[_n];
185   i=0;
186   for(k=j=0;k<_m;k++){
187     celt_uint32_t pn;
188     celt_uint32_t p;
189     if (k<_m-1)
190       prev_ncwrs32(nc, _n+1, 0);
191     else
192       prev_ncwrs32(nc, _n+1, 1);
193     /*p=ncwrs(_n-j,_m-k-1);
194     pn=ncwrs(_n-j-1,_m-k-1);*/
195     p=nc[_n-j];
196     pn=nc[_n-j-1];
197     p+=pn;
198     if(k>0)p>>=1;
199     while(j<_x[k]){
200       i+=p;
201       j++;
202       p=pn;
203       /*pn=ncwrs(_n-j-1,_m-k-1);*/
204       pn=nc[_n-j-1];
205       p+=pn;
206     }
207     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
208   }
209   return i;
210 }
211
212 /*Returns the _i'th combination of _m elements chosen from a set of size _n
213    with associated sign bits.
214   _x:      Returns the combination with elements sorted in ascending order.
215   _s:      Returns the associated sign bits.*/
216 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
217   int j;
218   int k;
219   celt_uint64_t nc[_n+1];
220   for (j=0;j<_n+1;j++)
221     nc[j] = 1;
222   for (k=0;k<_m-1;k++)
223     next_ncwrs64(nc, _n+1, 0);
224   for(k=j=0;k<_m;k++){
225     celt_uint64_t pn, p, t;
226     /*p=ncwrs64(_n-j,_m-k-1);
227     pn=ncwrs64(_n-j-1,_m-k-1);*/
228     p=nc[_n-j];
229     pn=nc[_n-j-1];
230     p+=pn;
231     if(k>0){
232       t=p>>1;
233       if(t<=_i||_s[k-1])_i+=t;
234     }
235     while(p<=_i){
236       _i-=p;
237       j++;
238       p=pn;
239       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
240       pn=nc[_n-j-1];
241       p+=pn;
242     }
243     t=p>>1;
244     _s[k]=_i>=t;
245     _x[k]=j;
246     if(_s[k])_i-=t;
247     if (k<_m-2)
248       prev_ncwrs64(nc, _n+1, 0);
249     else
250       prev_ncwrs64(nc, _n+1, 1);
251   }
252 }
253
254 /*Returns the index of the given combination of _m elements chosen from a set
255    of size _n with associated sign bits.
256   _x:      The combination with elements sorted in ascending order.
257   _s:      The associated sign bits.*/
258 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){
259   celt_uint64_t i;
260   int           j;
261   int           k;
262   celt_uint64_t nc[_n+1];
263   for (j=0;j<_n+1;j++)
264     nc[j] = 1;
265   for (k=0;k<_m;k++)
266     next_ncwrs64(nc, _n+1, 0);
267   if (bound)
268      *bound = nc[_n];
269   i=0;
270   for(k=j=0;k<_m;k++){
271     celt_uint64_t pn;
272     celt_uint64_t p;
273     if (k<_m-1)
274       prev_ncwrs64(nc, _n+1, 0);
275     else
276       prev_ncwrs64(nc, _n+1, 1);
277     /*p=ncwrs64(_n-j,_m-k-1);
278     pn=ncwrs64(_n-j-1,_m-k-1);*/
279     p=nc[_n-j];
280     pn=nc[_n-j-1];
281     p+=pn;
282     if(k>0)p>>=1;
283     while(j<_x[k]){
284       i+=p;
285       j++;
286       p=pn;
287       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
288       pn=nc[_n-j-1];
289       p+=pn;
290     }
291     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
292   }
293   return i;
294 }
295
296 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
297    a pulse vector _y of length _n.
298   _y: Returns the vector of pulses.
299   _x: The combination with elements sorted in ascending order.
300   _s: The associated sign bits.*/
301 void comb2pulse(int _n,int _m,int *_y,const int *_x,const int *_s){
302   int j;
303   int k;
304   int n;
305   for(k=j=0;k<_m;k+=n){
306     for(n=1;k+n<_m&&_x[k+n]==_x[k];n++);
307     while(j<_x[k])_y[j++]=0;
308     _y[j++]=_s[k]?-n:n;
309   }
310   while(j<_n)_y[j++]=0;
311 }
312
313 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
314    pulses with associated sign bits _s.
315   _x: Returns the combination with elements sorted in ascending order.
316   _s: Returns the associated sign bits.
317   _y: The vector of pulses, whose sum of absolute values must be _m.*/
318 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
319   int j;
320   int k;
321   for(k=j=0;j<_n;j++){
322     if(_y[j]){
323       int n;
324       int s;
325       n=abs(_y[j]);
326       s=_y[j]<0;
327       for(;n-->0;k++){
328         _x[k]=j;
329         _s[k]=s;
330       }
331     }
332   }
333 }
334
335 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
336 {
337    int comb[K];
338    int signs[K];
339    pulse2comb(N, K, comb, signs, _y);
340    /* Go with 32-bit path if we're sure we can */
341    if (N<=13 && K<=13)
342    {
343       celt_uint32_t bound, id;
344       id = icwrs(N, K, comb, signs, &bound);
345       ec_enc_uint(enc,id,bound);
346    } else {
347       celt_uint64_t bound, id;
348       id = icwrs64(N, K, comb, signs, &bound);
349       ec_enc_uint64(enc,id,bound);
350    }
351 }
352
353 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
354 {
355    int comb[K];
356    int signs[K];   
357    if (N<=13 && K<=13)
358    {
359       cwrsi(N, K, ec_dec_uint(dec, ncwrs(N, K)), comb, signs);
360       comb2pulse(N, K, _y, comb, signs);
361    } else {
362       cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs);
363       comb2pulse(N, K, _y, comb, signs);
364    }
365 }
366