Added support for codebooks up to 64 bits.
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007 Timothy Terriberry
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32
33 /*#include <stdio.h>*/
34 #include <stdlib.h>
35
36 #include "cwrs.h"
37
38 /*Returns the numer of ways of choosing _m elements from a set of size _n with
39    replacement when a sign bit is needed for each unique element.*/
40 #if 0
41 static unsigned ncwrs(int _n,int _m){
42   static unsigned c[32][32];
43   if(_n<0||_m<0)return 0;
44   if(!c[_n][_m]){
45     if(_m<=0)c[_n][_m]=1;
46     else if(_n>0)c[_n][_m]=ncwrs(_n-1,_m)+ncwrs(_n,_m-1)+ncwrs(_n-1,_m-1);
47   }
48   return c[_n][_m];
49 }
50
51 #else
52
53 /*Returns the greatest common divisor of _a and _b.*/
54 static unsigned gcd(unsigned _a,unsigned _b){
55   unsigned r;
56   while(_b){
57     r=_a%_b;
58     _a=_b;
59     _b=r;
60   }
61   return _a;
62 }
63
64 /*Returns _a*b/_d, under the assumption that the result is an integer, avoiding
65    overflow.
66   It is assumed, but not required, that _b is smaller than _a.*/
67 static unsigned umuldiv(unsigned _a,unsigned _b,unsigned _d){
68   unsigned d;
69   d=gcd(_b,_d);
70   return (_a/(_d/d))*(_b/d);
71 }
72
73 unsigned ncwrs(int _n,int _m){
74   unsigned ret;
75   unsigned f;
76   unsigned d;
77   int      i;
78   if(_n<0||_m<0)return 0;
79   if(_m==0)return 1;
80   if(_n==0)return 0;
81   ret=0;
82   f=_n;
83   d=1;
84   for(i=1;i<=_m;i++){
85     ret+=f*d<<i;
86 #if 0
87     f=umuldiv(f,_n-i,i+1);
88     d=umuldiv(d,_m-i,i);
89 #else
90     f=(f*(_n-i))/(i+1);
91     d=(d*(_m-i))/i;
92 #endif
93   }
94   return ret;
95 }
96 #endif
97
98 celt_uint64_t ncwrs64(int _n,int _m){
99    celt_uint64_t ret;
100    celt_uint64_t f;
101    celt_uint64_t d;
102    int      i;
103    if(_n<0||_m<0)return 0;
104    if(_m==0)return 1;
105    if(_n==0)return 0;
106    ret=0;
107    f=_n;
108    d=1;
109    for(i=1;i<=_m;i++){
110       ret+=f*d<<i;
111 #if 0
112       f=umuldiv(f,_n-i,i+1);
113       d=umuldiv(d,_m-i,i);
114 #else
115       f=(f*(_n-i))/(i+1);
116       d=(d*(_m-i))/i;
117 #endif
118    }
119    return ret;
120 }
121
122 /*Returns the _i'th combination of _m elements chosen from a set of size _n
123    with associated sign bits.
124   _x:      Returns the combination with elements sorted in ascending order.
125   _s:      Returns the associated sign bits.*/
126 void cwrsi(int _n,int _m,unsigned _i,int *_x,int *_s){
127   unsigned pn;
128   int      j;
129   int      k;
130   pn=ncwrs(_n-1,_m);
131   for(k=j=0;k<_m;k++){
132     unsigned pp;
133     unsigned p;
134     unsigned t;
135     pp=0;
136     p=ncwrs(_n-j,_m-k)-pn;
137     if(k>0){
138       t=p>>1;
139       if(t<=_i||_s[k-1])_i+=t;
140     }
141     pn=ncwrs(_n-j-1,_m-k-1);
142     while(p<=_i){
143       pp=p;
144       j++;
145       p+=pn;
146       pn=ncwrs(_n-j-1,_m-k-1);
147       p+=pn;
148     }
149     t=p-pp>>1;
150     _s[k]=_i-pp>=t;
151     _x[k]=j;
152     _i-=pp;
153     if(_s[k])_i-=t;
154   }
155 }
156
157 /*Returns the index of the given combination of _m elements chosen from a set
158    of size _n with associated sign bits.
159   _x:      The combination with elements sorted in ascending order.
160   _s:      The associated sign bits.*/
161 unsigned icwrs(int _n,int _m,const int *_x,const int *_s){
162   unsigned pn;
163   unsigned i;
164   int      j;
165   int      k;
166   i=0;
167   pn=ncwrs(_n-1,_m);
168   for(k=j=0;k<_m;k++){
169     unsigned pp;
170     unsigned p;
171     pp=0;
172     p=ncwrs(_n-j,_m-k)-pn;
173     if(k>0)p>>=1;
174     pn=ncwrs(_n-j-1,_m-k-1);
175     while(j<_x[k]){
176       pp=p;
177       j++;
178       p+=pn;
179       pn=ncwrs(_n-j-1,_m-k-1);
180       p+=pn;
181     }
182     i+=pp;
183     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p-pp>>1;
184   }
185   return i;
186 }
187
188
189 /*Returns the _i'th combination of _m elements chosen from a set of size _n
190    with associated sign bits.
191   _x:      Returns the combination with elements sorted in ascending order.
192   _s:      Returns the associated sign bits.*/
193 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
194    celt_uint64_t pn;
195    int      j;
196    int      k;
197    pn=ncwrs64(_n-1,_m);
198    for(k=j=0;k<_m;k++){
199       celt_uint64_t pp;
200       celt_uint64_t p;
201       celt_uint64_t t;
202       pp=0;
203       p=ncwrs64(_n-j,_m-k)-pn;
204       if(k>0){
205          t=p>>1;
206          if(t<=_i||_s[k-1])_i+=t;
207       }
208       pn=ncwrs64(_n-j-1,_m-k-1);
209       while(p<=_i){
210          pp=p;
211          j++;
212          p+=pn;
213          pn=ncwrs64(_n-j-1,_m-k-1);
214          p+=pn;
215       }
216       t=p-pp>>1;
217       _s[k]=_i-pp>=t;
218       _x[k]=j;
219       _i-=pp;
220       if(_s[k])_i-=t;
221    }
222 }
223
224 /*Returns the index of the given combination of _m elements chosen from a set
225    of size _n with associated sign bits.
226   _x:      The combination with elements sorted in ascending order.
227   _s:      The associated sign bits.*/
228 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s){
229    celt_uint64_t pn;
230    celt_uint64_t i;
231    int      j;
232    int      k;
233    i=0;
234    pn=ncwrs64(_n-1,_m);
235    for(k=j=0;k<_m;k++){
236       celt_uint64_t pp;
237       celt_uint64_t p;
238       pp=0;
239       p=ncwrs64(_n-j,_m-k)-pn;
240       if(k>0)p>>=1;
241       pn=ncwrs64(_n-j-1,_m-k-1);
242       while(j<_x[k]){
243          pp=p;
244          j++;
245          p+=pn;
246          pn=ncwrs64(_n-j-1,_m-k-1);
247          p+=pn;
248       }
249       i+=pp;
250       if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p-pp>>1;
251    }
252    return i;
253 }
254
255
256 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
257    a pulse vector _y of length _n.
258   _y: Returns the vector of pulses.
259   _x: The combination with elements sorted in ascending order.
260   _s: The associated sign bits.*/
261 void comb2pulse(int _n,int _m,int *_y,const int *_x,const int *_s){
262   int j;
263   int k;
264   int n;
265   for(k=j=0;k<_m;k+=n){
266     for(n=1;k+n<_m&&_x[k+n]==_x[k];n++);
267     while(j<_x[k])_y[j++]=0;
268     _y[j++]=_s[k]?-n:n;
269   }
270   while(j<_n)_y[j++]=0;
271 }
272
273 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
274    pulses with associated sign bits _s.
275   _x: Returns the combination with elements sorted in ascending order.
276   _s: Returns the associated sign bits.
277   _y: The vector of pulses, whose sum of absolute values must be _m.*/
278 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
279   int j;
280   int k;
281   for(k=j=0;j<_n;j++){
282     if(_y[j]){
283       int n;
284       int s;
285       n=abs(_y[j]);
286       s=_y[j]<0;
287       for(;n-->0;k++){
288         _x[k]=j;
289         _s[k]=s;
290       }
291     }
292   }
293 }
294
295 /*
296 #define NMAX (10)
297 #define MMAX (9)
298
299 int main(int _argc,char **_argv){
300   int n;
301   for(n=0;n<=NMAX;n++){
302     int m;
303     for(m=0;m<=MMAX;m++){
304       unsigned nc;
305       unsigned i;
306       nc=ncwrs(n,m);
307       for(i=0;i<nc;i++){
308         int x[MMAX];
309         int s[MMAX];
310         int x2[MMAX];
311         int s2[MMAX];
312         int y[NMAX];
313         int j;
314         int k;
315         cwrsi(n,m,i,x,s);
316         printf("%6u of %u:",i,nc);
317         for(k=0;k<m;k++){
318           printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
319         }
320         printf(" ->");
321         if(icwrs(n,m,x,s)!=i){
322           fprintf(stderr,"Combination-index mismatch.\n");
323         }
324         comb2pulse(n,m,y,x,s);
325         for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
326         printf("\n");
327         pulse2comb(n,m,x2,s2,y);
328         for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
329           fprintf(stderr,"Pulse-combination mismatch.\n");
330           break;
331         }
332       }
333       printf("\n");
334     }
335   }
336   return 0;
337 }
338 */
339
340 /*
341 #include <stdio.h>
342 #define NMAX (32)
343 #define MMAX (16)
344
345 int main(int _argc,char **_argv){
346    int n;
347    for(n=0;n<=NMAX;n+=3){
348       int m;
349       for(m=0;m<=MMAX;m++){
350          celt_uint64_t nc;
351          celt_uint64_t i;
352          nc=ncwrs64(n,m);
353          printf("%d/%d: %llu",n,m, nc);
354          for(i=0;i<nc;i+=100000){
355             int x[MMAX];
356             int s[MMAX];
357             int x2[MMAX];
358             int s2[MMAX];
359             int y[NMAX];
360             int j;
361             int k;
362             cwrsi64(n,m,i,x,s);
363             //printf("%llu of %llu:",i,nc);
364             for(k=0;k<m;k++){
365                //printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
366             }
367             //printf(" ->");
368             if(icwrs64(n,m,x,s)!=i){
369                fprintf(stderr,"Combination-index mismatch.\n");
370             }
371             comb2pulse(n,m,y,x,s);
372             //for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
373             //printf("\n");
374             pulse2comb(n,m,x2,s2,y);
375             for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
376                fprintf(stderr,"Pulse-combination mismatch.\n");
377                break;
378             }
379          }
380          printf("\n");
381       }
382    }
383    return 0;
384 }
385 */