Speeded up cwrsi and icwrs by at least an order of magnitude. Now using
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007 Timothy B. Terriberry */
2 /*
3    Redistribution and use in source and binary forms, with or without
4    modification, are permitted provided that the following conditions
5    are met:
6
7    - Redistributions of source code must retain the above copyright
8    notice, this list of conditions and the following disclaimer.
9
10    - Redistributions in binary form must reproduce the above copyright
11    notice, this list of conditions and the following disclaimer in the
12    documentation and/or other materials provided with the distribution.
13
14    - Neither the name of the Xiph.org Foundation nor the names of its
15    contributors may be used to endorse or promote products derived from
16    this software without specific prior written permission.
17
18    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
22    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
25    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
26    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
27    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 */
30 #include <stdlib.h>
31 #include "cwrs.h"
32
33 static celt_uint64_t update_ncwrs64(celt_uint64_t *nc, int len, int nc0)
34 {
35    int i;
36    celt_uint64_t mem;
37    
38    mem = nc[0];
39    nc[0] = nc0;
40    for (i=1;i<len;i++)
41    {
42       celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
43       mem = nc[i];
44       nc[i] = tmp;
45    }
46 }
47
48 static celt_uint64_t reverse_ncwrs64(celt_uint64_t *nc, int len, int nc0)
49 {
50    int i;
51    celt_uint64_t mem;
52    
53    mem = nc[0];
54    nc[0] = nc0;
55    for (i=1;i<len;i++)
56    {
57       celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
58       mem = nc[i];
59       nc[i] = tmp;
60    }
61 }
62
63 /* Optional implementation of ncwrs64 using update_ncwrs64(). It's slightly
64    slower than the standard ncwrs64(), but it could still be useful.
65 celt_uint64_t ncwrs64_opt(int _n,int _m)
66 {
67    int i;
68    celt_uint64_t ret;
69    celt_uint64_t nc[_n+1];
70    for (i=0;i<_n+1;i++)
71       nc[i] = 1;
72    for (i=0;i<_m;i++)
73       update_ncwrs64(nc, _n+1, 0);
74    return nc[_n];
75 }*/
76
77 /*Returns the numer of ways of choosing _m elements from a set of size _n with
78    replacement when a sign bit is needed for each unique element.*/
79 #if 0
80 static celt_uint32_t ncwrs(int _n,int _m){
81   static celt_uint32_t c[32][32];
82   if(_n<0||_m<0)return 0;
83   if(!c[_n][_m]){
84     if(_m<=0)c[_n][_m]=1;
85     else if(_n>0)c[_n][_m]=ncwrs(_n-1,_m)+ncwrs(_n,_m-1)+ncwrs(_n-1,_m-1);
86   }
87   return c[_n][_m];
88 }
89 #else
90 celt_uint32_t ncwrs(int _n,int _m){
91   celt_uint32_t ret;
92   celt_uint32_t f;
93   celt_uint32_t d;
94   int      i;
95   if(_n<0||_m<0)return 0;
96   if(_m==0)return 1;
97   if(_n==0)return 0;
98   ret=0;
99   f=_n;
100   d=1;
101   for(i=1;i<=_m;i++){
102     ret+=f*d<<i;
103     f=(f*(_n-i))/(i+1);
104     d=(d*(_m-i))/i;
105   }
106   return ret;
107 }
108 #endif
109
110 #if 0
111 celt_uint64_t ncwrs64(int _n,int _m){
112   static celt_uint64_t c[101][101];
113   if(_n<0||_m<0)return 0;
114   if(!c[_n][_m]){
115     if(_m<=0)c[_n][_m]=1;
116     else if(_n>0)c[_n][_m]=ncwrs64(_n-1,_m)+ncwrs64(_n,_m-1)+ncwrs64(_n-1,_m-1);
117 }
118   return c[_n][_m];
119 }
120 #else
121 celt_uint64_t ncwrs64(int _n,int _m){
122   celt_uint64_t ret;
123   celt_uint64_t f;
124   celt_uint64_t d;
125   int           i;
126   if(_n<0||_m<0)return 0;
127   if(_m==0)return 1;
128   if(_n==0)return 0;
129   ret=0;
130   f=_n;
131   d=1;
132   for(i=1;i<=_m;i++){
133     ret+=f*d<<i;
134     f=(f*(_n-i))/(i+1);
135     d=(d*(_m-i))/i;
136   }
137   return ret;
138 }
139 #endif
140
141 /*Returns the _i'th combination of _m elements chosen from a set of size _n
142    with associated sign bits.
143   _x:      Returns the combination with elements sorted in ascending order.
144   _s:      Returns the associated sign bits.*/
145 void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
146   int j;
147   int k;
148   for(k=j=0;k<_m;k++){
149     celt_uint32_t pn;
150     celt_uint32_t p;
151     celt_uint32_t t;
152     p=ncwrs(_n-j,_m-k-1);
153     pn=ncwrs(_n-j-1,_m-k-1);
154     p+=pn;
155     if(k>0){
156       t=p>>1;
157       if(t<=_i||_s[k-1])_i+=t;
158     }
159     while(p<=_i){
160       _i-=p;
161       j++;
162       p=pn;
163       pn=ncwrs(_n-j-1,_m-k-1);
164       p+=pn;
165     }
166     t=p>>1;
167     _s[k]=_i>=t;
168     _x[k]=j;
169     if(_s[k])_i-=t;
170   }
171 }
172
173 /*Returns the index of the given combination of _m elements chosen from a set
174    of size _n with associated sign bits.
175   _x:      The combination with elements sorted in ascending order.
176   _s:      The associated sign bits.*/
177 celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s){
178   celt_uint32_t i;
179   int      j;
180   int      k;
181   i=0;
182   for(k=j=0;k<_m;k++){
183     celt_uint32_t pn;
184     celt_uint32_t p;
185     p=ncwrs(_n-j,_m-k-1);
186     pn=ncwrs(_n-j-1,_m-k-1);
187     p+=pn;
188     if(k>0)p>>=1;
189     while(j<_x[k]){
190       i+=p;
191       j++;
192       p=pn;
193       pn=ncwrs(_n-j-1,_m-k-1);
194       p+=pn;
195     }
196     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
197   }
198   return i;
199 }
200
201 /*Returns the _i'th combination of _m elements chosen from a set of size _n
202    with associated sign bits.
203   _x:      Returns the combination with elements sorted in ascending order.
204   _s:      Returns the associated sign bits.*/
205 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
206   int j;
207   int k;
208   celt_uint64_t nc[_n+1];
209   for (j=0;j<_n+1;j++)
210     nc[j] = 1;
211   for (k=0;k<_m-1;k++)
212     update_ncwrs64(nc, _n+1, 0);
213   for(k=j=0;k<_m;k++){
214     celt_uint64_t pn;
215     celt_uint64_t p;
216     celt_uint64_t t;
217     /*p=ncwrs64(_n-j,_m-k-1);
218     pn=ncwrs64(_n-j-1,_m-k-1);*/
219     p=nc[_n-j];
220     pn=nc[_n-j-1];
221     p+=pn;
222     if(k>0){
223       t=p>>1;
224       if(t<=_i||_s[k-1])_i+=t;
225     }
226     while(p<=_i){
227       _i-=p;
228       j++;
229       p=pn;
230       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
231       pn=nc[_n-j-1];
232       p+=pn;
233     }
234     t=p>>1;
235     _s[k]=_i>=t;
236     _x[k]=j;
237     if(_s[k])_i-=t;
238     if (k<_m-2)
239       reverse_ncwrs64(nc, _n+1, 0);
240     else
241       reverse_ncwrs64(nc, _n+1, 1);
242   }
243 }
244
245 /*Returns the index of the given combination of _m elements chosen from a set
246    of size _n with associated sign bits.
247   _x:      The combination with elements sorted in ascending order.
248   _s:      The associated sign bits.*/
249 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s){
250   celt_uint64_t i;
251   int           j;
252   int           k;
253   celt_uint64_t nc[_n+1];
254   for (j=0;j<_n+1;j++)
255     nc[j] = 1;
256   for (k=0;k<_m-1;k++)
257     update_ncwrs64(nc, _n+1, 0);
258   i=0;
259   for(k=j=0;k<_m;k++){
260     celt_uint64_t pn;
261     celt_uint64_t p;
262     /*p=ncwrs64(_n-j,_m-k-1);
263     pn=ncwrs64(_n-j-1,_m-k-1);*/
264     p=nc[_n-j];
265     pn=nc[_n-j-1];
266     p+=pn;
267     if(k>0)p>>=1;
268     while(j<_x[k]){
269       i+=p;
270       j++;
271       p=pn;
272       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
273       pn=nc[_n-j-1];
274       p+=pn;
275     }
276     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
277     if (k<_m-2)
278       reverse_ncwrs64(nc, _n+1, 0);
279     else
280       reverse_ncwrs64(nc, _n+1, 1);
281   }
282   return i;
283 }
284
285 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
286    a pulse vector _y of length _n.
287   _y: Returns the vector of pulses.
288   _x: The combination with elements sorted in ascending order.
289   _s: The associated sign bits.*/
290 void comb2pulse(int _n,int _m,int *_y,const int *_x,const int *_s){
291   int j;
292   int k;
293   int n;
294   for(k=j=0;k<_m;k+=n){
295     for(n=1;k+n<_m&&_x[k+n]==_x[k];n++);
296     while(j<_x[k])_y[j++]=0;
297     _y[j++]=_s[k]?-n:n;
298   }
299   while(j<_n)_y[j++]=0;
300 }
301
302 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
303    pulses with associated sign bits _s.
304   _x: Returns the combination with elements sorted in ascending order.
305   _s: Returns the associated sign bits.
306   _y: The vector of pulses, whose sum of absolute values must be _m.*/
307 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
308   int j;
309   int k;
310   for(k=j=0;j<_n;j++){
311     if(_y[j]){
312       int n;
313       int s;
314       n=abs(_y[j]);
315       s=_y[j]<0;
316       for(;n-->0;k++){
317         _x[k]=j;
318         _s[k]=s;
319       }
320     }
321   }
322 }
323