Reduced useless calls to ncwrs64() by half.
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007 Timothy B. Terriberry */
2 /*
3    Redistribution and use in source and binary forms, with or without
4    modification, are permitted provided that the following conditions
5    are met:
6
7    - Redistributions of source code must retain the above copyright
8    notice, this list of conditions and the following disclaimer.
9
10    - Redistributions in binary form must reproduce the above copyright
11    notice, this list of conditions and the following disclaimer in the
12    documentation and/or other materials provided with the distribution.
13
14    - Neither the name of the Xiph.org Foundation nor the names of its
15    contributors may be used to endorse or promote products derived from
16    this software without specific prior written permission.
17
18    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
22    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
25    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
26    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
27    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 */
30 #include <stdlib.h>
31 #include "cwrs.h"
32
33 static celt_uint64_t next_ncwrs64(celt_uint64_t *nc, int len, int nc0)
34 {
35    int i;
36    celt_uint64_t mem;
37    
38    mem = nc[0];
39    nc[0] = nc0;
40    for (i=1;i<len;i++)
41    {
42       celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
43       mem = nc[i];
44       nc[i] = tmp;
45    }
46 }
47
48 static celt_uint64_t prev_ncwrs64(celt_uint64_t *nc, int len, int nc0)
49 {
50    int i;
51    celt_uint64_t mem;
52    
53    mem = nc[0];
54    nc[0] = nc0;
55    for (i=1;i<len;i++)
56    {
57       celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
58       mem = nc[i];
59       nc[i] = tmp;
60    }
61 }
62
63 /* Optional implementation of ncwrs64 using update_ncwrs64(). It's slightly
64    slower than the standard ncwrs64(), but it could still be useful.
65 celt_uint64_t ncwrs64_opt(int _n,int _m)
66 {
67    int i;
68    celt_uint64_t ret;
69    celt_uint64_t nc[_n+1];
70    for (i=0;i<_n+1;i++)
71       nc[i] = 1;
72    for (i=0;i<_m;i++)
73       update_ncwrs64(nc, _n+1, 0);
74    return nc[_n];
75 }*/
76
77 /*Returns the numer of ways of choosing _m elements from a set of size _n with
78    replacement when a sign bit is needed for each unique element.*/
79 #if 0
80 static celt_uint32_t ncwrs(int _n,int _m){
81   static celt_uint32_t c[32][32];
82   if(_n<0||_m<0)return 0;
83   if(!c[_n][_m]){
84     if(_m<=0)c[_n][_m]=1;
85     else if(_n>0)c[_n][_m]=ncwrs(_n-1,_m)+ncwrs(_n,_m-1)+ncwrs(_n-1,_m-1);
86   }
87   return c[_n][_m];
88 }
89 #else
90 celt_uint32_t ncwrs(int _n,int _m){
91   celt_uint32_t ret;
92   celt_uint32_t f;
93   celt_uint32_t d;
94   int      i;
95   if(_n<0||_m<0)return 0;
96   if(_m==0)return 1;
97   if(_n==0)return 0;
98   ret=0;
99   f=_n;
100   d=1;
101   for(i=1;i<=_m;i++){
102     ret+=f*d<<i;
103     f=(f*(_n-i))/(i+1);
104     d=(d*(_m-i))/i;
105   }
106   return ret;
107 }
108 #endif
109
110 #if 0
111 celt_uint64_t ncwrs64(int _n,int _m){
112   static celt_uint64_t c[101][101];
113   if(_n<0||_m<0)return 0;
114   if(!c[_n][_m]){
115     if(_m<=0)c[_n][_m]=1;
116     else if(_n>0)c[_n][_m]=ncwrs64(_n-1,_m)+ncwrs64(_n,_m-1)+ncwrs64(_n-1,_m-1);
117 }
118   return c[_n][_m];
119 }
120 #else
121 celt_uint64_t ncwrs64(int _n,int _m){
122   celt_uint64_t ret;
123   celt_uint64_t f;
124   celt_uint64_t d;
125   int           i;
126   if(_n<0||_m<0)return 0;
127   if(_m==0)return 1;
128   if(_n==0)return 0;
129   ret=0;
130   f=_n;
131   d=1;
132   for(i=1;i<=_m;i++){
133     ret+=f*d<<i;
134     f=(f*(_n-i))/(i+1);
135     d=(d*(_m-i))/i;
136   }
137   return ret;
138 }
139 #endif
140
141 /*Returns the _i'th combination of _m elements chosen from a set of size _n
142    with associated sign bits.
143   _x:      Returns the combination with elements sorted in ascending order.
144   _s:      Returns the associated sign bits.*/
145 void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
146   int j;
147   int k;
148   for(k=j=0;k<_m;k++){
149     celt_uint32_t pn;
150     celt_uint32_t p;
151     celt_uint32_t t;
152     p=ncwrs(_n-j,_m-k-1);
153     pn=ncwrs(_n-j-1,_m-k-1);
154     p+=pn;
155     if(k>0){
156       t=p>>1;
157       if(t<=_i||_s[k-1])_i+=t;
158     }
159     while(p<=_i){
160       _i-=p;
161       j++;
162       p=pn;
163       pn=ncwrs(_n-j-1,_m-k-1);
164       p+=pn;
165     }
166     t=p>>1;
167     _s[k]=_i>=t;
168     _x[k]=j;
169     if(_s[k])_i-=t;
170   }
171 }
172
173 /*Returns the index of the given combination of _m elements chosen from a set
174    of size _n with associated sign bits.
175   _x:      The combination with elements sorted in ascending order.
176   _s:      The associated sign bits.*/
177 celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s){
178   celt_uint32_t i;
179   int      j;
180   int      k;
181   i=0;
182   for(k=j=0;k<_m;k++){
183     celt_uint32_t pn;
184     celt_uint32_t p;
185     p=ncwrs(_n-j,_m-k-1);
186     pn=ncwrs(_n-j-1,_m-k-1);
187     p+=pn;
188     if(k>0)p>>=1;
189     while(j<_x[k]){
190       i+=p;
191       j++;
192       p=pn;
193       pn=ncwrs(_n-j-1,_m-k-1);
194       p+=pn;
195     }
196     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
197   }
198   return i;
199 }
200
201 /*Returns the _i'th combination of _m elements chosen from a set of size _n
202    with associated sign bits.
203   _x:      Returns the combination with elements sorted in ascending order.
204   _s:      Returns the associated sign bits.*/
205 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
206   int j;
207   int k;
208   celt_uint64_t nc[_n+1];
209   for (j=0;j<_n+1;j++)
210     nc[j] = 1;
211   for (k=0;k<_m-1;k++)
212     next_ncwrs64(nc, _n+1, 0);
213   for(k=j=0;k<_m;k++){
214     celt_uint64_t pn;
215     celt_uint64_t p;
216     celt_uint64_t t;
217     /*p=ncwrs64(_n-j,_m-k-1);
218     pn=ncwrs64(_n-j-1,_m-k-1);*/
219     p=nc[_n-j];
220     pn=nc[_n-j-1];
221     p+=pn;
222     if(k>0){
223       t=p>>1;
224       if(t<=_i||_s[k-1])_i+=t;
225     }
226     while(p<=_i){
227       _i-=p;
228       j++;
229       p=pn;
230       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
231       pn=nc[_n-j-1];
232       p+=pn;
233     }
234     t=p>>1;
235     _s[k]=_i>=t;
236     _x[k]=j;
237     if(_s[k])_i-=t;
238     if (k<_m-2)
239       prev_ncwrs64(nc, _n+1, 0);
240     else
241       prev_ncwrs64(nc, _n+1, 1);
242   }
243 }
244
245 /*Returns the index of the given combination of _m elements chosen from a set
246    of size _n with associated sign bits.
247   _x:      The combination with elements sorted in ascending order.
248   _s:      The associated sign bits.*/
249 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){
250   celt_uint64_t i;
251   int           j;
252   int           k;
253   celt_uint64_t nc[_n+1];
254   for (j=0;j<_n+1;j++)
255     nc[j] = 1;
256   for (k=0;k<_m;k++)
257     next_ncwrs64(nc, _n+1, 0);
258   if (bound)
259      *bound = nc[_n];
260   i=0;
261   for(k=j=0;k<_m;k++){
262     celt_uint64_t pn;
263     celt_uint64_t p;
264     if (k<_m-1)
265       prev_ncwrs64(nc, _n+1, 0);
266     else
267       prev_ncwrs64(nc, _n+1, 1);
268     /*p=ncwrs64(_n-j,_m-k-1);
269     pn=ncwrs64(_n-j-1,_m-k-1);*/
270     p=nc[_n-j];
271     pn=nc[_n-j-1];
272     p+=pn;
273     if(k>0)p>>=1;
274     while(j<_x[k]){
275       i+=p;
276       j++;
277       p=pn;
278       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
279       pn=nc[_n-j-1];
280       p+=pn;
281     }
282     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
283   }
284   return i;
285 }
286
287 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
288    a pulse vector _y of length _n.
289   _y: Returns the vector of pulses.
290   _x: The combination with elements sorted in ascending order.
291   _s: The associated sign bits.*/
292 void comb2pulse(int _n,int _m,int *_y,const int *_x,const int *_s){
293   int j;
294   int k;
295   int n;
296   for(k=j=0;k<_m;k+=n){
297     for(n=1;k+n<_m&&_x[k+n]==_x[k];n++);
298     while(j<_x[k])_y[j++]=0;
299     _y[j++]=_s[k]?-n:n;
300   }
301   while(j<_n)_y[j++]=0;
302 }
303
304 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
305    pulses with associated sign bits _s.
306   _x: Returns the combination with elements sorted in ascending order.
307   _s: Returns the associated sign bits.
308   _y: The vector of pulses, whose sum of absolute values must be _m.*/
309 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
310   int j;
311   int k;
312   for(k=j=0;j<_n;j++){
313     if(_y[j]){
314       int n;
315       int s;
316       n=abs(_y[j]);
317       s=_y[j]<0;
318       for(;n-->0;k++){
319         _x[k]=j;
320         _s[k]=s;
321       }
322     }
323   }
324 }
325
326 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
327 {
328    int comb[K];
329    int signs[K];
330    pulse2comb(N, K, comb, signs, _y);
331    celt_uint64_t bound, id;
332    id = icwrs64(N, K, comb, signs, &bound);
333    ec_enc_uint64(enc,id,bound);
334 }
335
336 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
337 {
338    int comb[K];
339    int signs[K];   
340    cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs);
341    comb2pulse(N, K, _y, comb, signs);
342 }
343