cwrs.c links to derf's article on pulse vector encoding.
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors. For more details, see:
33    http://people.xiph.org/~tterribe/notes/cwrs.html
34 */
35 #include <stdlib.h>
36 #include "cwrs.h"
37
38 /* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
39    compute ncwrs() for m+1, for all n. Could also be used when m and n are
40    swapped just by changing nc */
41 static celt_uint32_t next_ncwrs32(celt_uint32_t *nc, int len, int nc0)
42 {
43    int i;
44    celt_uint32_t mem;
45    
46    mem = nc[0];
47    nc[0] = nc0;
48    for (i=1;i<len;i++)
49    {
50       celt_uint32_t tmp = nc[i]+nc[i-1]+mem;
51       mem = nc[i];
52       nc[i] = tmp;
53    }
54 }
55
56 /* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
57    compute ncwrs() for m-1, for all n. Could also be used when m and n are
58    swapped just by changing nc */
59 static celt_uint32_t prev_ncwrs32(celt_uint32_t *nc, int len, int nc0)
60 {
61    int i;
62    celt_uint32_t mem;
63    
64    mem = nc[0];
65    nc[0] = nc0;
66    for (i=1;i<len;i++)
67    {
68       celt_uint32_t tmp = nc[i]-nc[i-1]-mem;
69       mem = nc[i];
70       nc[i] = tmp;
71    }
72 }
73
74 static celt_uint64_t next_ncwrs64(celt_uint64_t *nc, int len, int nc0)
75 {
76    int i;
77    celt_uint64_t mem;
78    
79    mem = nc[0];
80    nc[0] = nc0;
81    for (i=1;i<len;i++)
82    {
83       celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
84       mem = nc[i];
85       nc[i] = tmp;
86    }
87 }
88
89 static celt_uint64_t prev_ncwrs64(celt_uint64_t *nc, int len, int nc0)
90 {
91    int i;
92    celt_uint64_t mem;
93    
94    mem = nc[0];
95    nc[0] = nc0;
96    for (i=1;i<len;i++)
97    {
98       celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
99       mem = nc[i];
100       nc[i] = tmp;
101    }
102 }
103
104 /*Returns the numer of ways of choosing _m elements from a set of size _n with
105    replacement when a sign bit is needed for each unique element.*/
106 celt_uint32_t ncwrs(int _n,int _m)
107 {
108    int i;
109    celt_uint32_t ret;
110    celt_uint32_t nc[_n+1];
111    for (i=0;i<_n+1;i++)
112       nc[i] = 1;
113    for (i=0;i<_m;i++)
114       next_ncwrs32(nc, _n+1, 0);
115    return nc[_n];
116 }
117
118 /*Returns the numer of ways of choosing _m elements from a set of size _n with
119    replacement when a sign bit is needed for each unique element.*/
120 celt_uint64_t ncwrs64(int _n,int _m)
121 {
122    int i;
123    celt_uint64_t ret;
124    celt_uint64_t nc[_n+1];
125    for (i=0;i<_n+1;i++)
126       nc[i] = 1;
127    for (i=0;i<_m;i++)
128       next_ncwrs64(nc, _n+1, 0);
129    return nc[_n];
130 }
131
132
133 /*Returns the _i'th combination of _m elements chosen from a set of size _n
134    with associated sign bits.
135   _x:      Returns the combination with elements sorted in ascending order.
136   _s:      Returns the associated sign bits.*/
137 void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
138   int j;
139   int k;
140   celt_uint32_t nc[_n+1];
141   for (j=0;j<_n+1;j++)
142     nc[j] = 1;
143   for (k=0;k<_m-1;k++)
144     next_ncwrs32(nc, _n+1, 0);
145   for(k=j=0;k<_m;k++){
146     celt_uint32_t pn, p, t;
147     /*p=ncwrs(_n-j,_m-k-1);
148     pn=ncwrs(_n-j-1,_m-k-1);*/
149     p=nc[_n-j];
150     pn=nc[_n-j-1];
151     p+=pn;
152     if(k>0){
153       t=p>>1;
154       if(t<=_i||_s[k-1])_i+=t;
155     }
156     while(p<=_i){
157       _i-=p;
158       j++;
159       p=pn;
160       /*pn=ncwrs(_n-j-1,_m-k-1);*/
161       pn=nc[_n-j-1];
162       p+=pn;
163     }
164     t=p>>1;
165     _s[k]=_i>=t;
166     _x[k]=j;
167     if(_s[k])_i-=t;
168     if (k<_m-2)
169       prev_ncwrs32(nc, _n+1, 0);
170     else
171       prev_ncwrs32(nc, _n+1, 1);
172   }
173 }
174
175 /*Returns the index of the given combination of _m elements chosen from a set
176    of size _n with associated sign bits.
177   _x:      The combination with elements sorted in ascending order.
178   _s:      The associated sign bits.*/
179 celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound){
180   celt_uint32_t i;
181   int      j;
182   int      k;
183   celt_uint32_t nc[_n+1];
184   for (j=0;j<_n+1;j++)
185     nc[j] = 1;
186   for (k=0;k<_m;k++)
187     next_ncwrs32(nc, _n+1, 0);
188   if (bound)
189     *bound = nc[_n];
190   i=0;
191   for(k=j=0;k<_m;k++){
192     celt_uint32_t pn;
193     celt_uint32_t p;
194     if (k<_m-1)
195       prev_ncwrs32(nc, _n+1, 0);
196     else
197       prev_ncwrs32(nc, _n+1, 1);
198     /*p=ncwrs(_n-j,_m-k-1);
199     pn=ncwrs(_n-j-1,_m-k-1);*/
200     p=nc[_n-j];
201     pn=nc[_n-j-1];
202     p+=pn;
203     if(k>0)p>>=1;
204     while(j<_x[k]){
205       i+=p;
206       j++;
207       p=pn;
208       /*pn=ncwrs(_n-j-1,_m-k-1);*/
209       pn=nc[_n-j-1];
210       p+=pn;
211     }
212     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
213   }
214   return i;
215 }
216
217 /*Returns the _i'th combination of _m elements chosen from a set of size _n
218    with associated sign bits.
219   _x:      Returns the combination with elements sorted in ascending order.
220   _s:      Returns the associated sign bits.*/
221 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
222   int j;
223   int k;
224   celt_uint64_t nc[_n+1];
225   for (j=0;j<_n+1;j++)
226     nc[j] = 1;
227   for (k=0;k<_m-1;k++)
228     next_ncwrs64(nc, _n+1, 0);
229   for(k=j=0;k<_m;k++){
230     celt_uint64_t pn, p, t;
231     /*p=ncwrs64(_n-j,_m-k-1);
232     pn=ncwrs64(_n-j-1,_m-k-1);*/
233     p=nc[_n-j];
234     pn=nc[_n-j-1];
235     p+=pn;
236     if(k>0){
237       t=p>>1;
238       if(t<=_i||_s[k-1])_i+=t;
239     }
240     while(p<=_i){
241       _i-=p;
242       j++;
243       p=pn;
244       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
245       pn=nc[_n-j-1];
246       p+=pn;
247     }
248     t=p>>1;
249     _s[k]=_i>=t;
250     _x[k]=j;
251     if(_s[k])_i-=t;
252     if (k<_m-2)
253       prev_ncwrs64(nc, _n+1, 0);
254     else
255       prev_ncwrs64(nc, _n+1, 1);
256   }
257 }
258
259 /*Returns the index of the given combination of _m elements chosen from a set
260    of size _n with associated sign bits.
261   _x:      The combination with elements sorted in ascending order.
262   _s:      The associated sign bits.*/
263 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){
264   celt_uint64_t i;
265   int           j;
266   int           k;
267   celt_uint64_t nc[_n+1];
268   for (j=0;j<_n+1;j++)
269     nc[j] = 1;
270   for (k=0;k<_m;k++)
271     next_ncwrs64(nc, _n+1, 0);
272   if (bound)
273      *bound = nc[_n];
274   i=0;
275   for(k=j=0;k<_m;k++){
276     celt_uint64_t pn;
277     celt_uint64_t p;
278     if (k<_m-1)
279       prev_ncwrs64(nc, _n+1, 0);
280     else
281       prev_ncwrs64(nc, _n+1, 1);
282     /*p=ncwrs64(_n-j,_m-k-1);
283     pn=ncwrs64(_n-j-1,_m-k-1);*/
284     p=nc[_n-j];
285     pn=nc[_n-j-1];
286     p+=pn;
287     if(k>0)p>>=1;
288     while(j<_x[k]){
289       i+=p;
290       j++;
291       p=pn;
292       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
293       pn=nc[_n-j-1];
294       p+=pn;
295     }
296     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
297   }
298   return i;
299 }
300
301 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
302    a pulse vector _y of length _n.
303   _y: Returns the vector of pulses.
304   _x: The combination with elements sorted in ascending order.
305   _s: The associated sign bits.*/
306 void comb2pulse(int _n,int _m,int *_y,const int *_x,const int *_s){
307   int j;
308   int k;
309   int n;
310   for(k=j=0;k<_m;k+=n){
311     for(n=1;k+n<_m&&_x[k+n]==_x[k];n++);
312     while(j<_x[k])_y[j++]=0;
313     _y[j++]=_s[k]?-n:n;
314   }
315   while(j<_n)_y[j++]=0;
316 }
317
318 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
319    pulses with associated sign bits _s.
320   _x: Returns the combination with elements sorted in ascending order.
321   _s: Returns the associated sign bits.
322   _y: The vector of pulses, whose sum of absolute values must be _m.*/
323 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
324   int j;
325   int k;
326   for(k=j=0;j<_n;j++){
327     if(_y[j]){
328       int n;
329       int s;
330       n=abs(_y[j]);
331       s=_y[j]<0;
332       for(;n-->0;k++){
333         _x[k]=j;
334         _s[k]=s;
335       }
336     }
337   }
338 }
339
340 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
341 {
342    int comb[K];
343    int signs[K];
344    pulse2comb(N, K, comb, signs, _y);
345    /* Go with 32-bit path if we're sure we can */
346    if (N<=13 && K<=13)
347    {
348       celt_uint32_t bound, id;
349       id = icwrs(N, K, comb, signs, &bound);
350       ec_enc_uint(enc,id,bound);
351    } else {
352       celt_uint64_t bound, id;
353       id = icwrs64(N, K, comb, signs, &bound);
354       ec_enc_uint64(enc,id,bound);
355    }
356 }
357
358 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
359 {
360    int comb[K];
361    int signs[K];   
362    if (N<=13 && K<=13)
363    {
364       cwrsi(N, K, ec_dec_uint(dec, ncwrs(N, K)), comb, signs);
365       comb2pulse(N, K, _y, comb, signs);
366    } else {
367       cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs);
368       comb2pulse(N, K, _y, comb, signs);
369    }
370 }
371