Initial support for a managed stack/scratchpad. Still needs some work.
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors. For more details, see:
33    http://people.xiph.org/~tterribe/notes/cwrs.html
34 */
35
36 #ifdef HAVE_CONFIG_H
37 #include "config.h"
38 #endif
39
40 #include <stdlib.h>
41 #include "cwrs.h"
42
43 /* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
44    compute ncwrs() for m+1, for all n. Could also be used when m and n are
45    swapped just by changing nc */
46 static void next_ncwrs32(celt_uint32_t *nc, int len, int nc0)
47 {
48    int i;
49    celt_uint32_t mem;
50    
51    mem = nc[0];
52    nc[0] = nc0;
53    for (i=1;i<len;i++)
54    {
55       celt_uint32_t tmp = nc[i]+nc[i-1]+mem;
56       mem = nc[i];
57       nc[i] = tmp;
58    }
59 }
60
61 /* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
62    compute ncwrs() for m-1, for all n. Could also be used when m and n are
63    swapped just by changing nc */
64 static void prev_ncwrs32(celt_uint32_t *nc, int len, int nc0)
65 {
66    int i;
67    celt_uint32_t mem;
68    
69    mem = nc[0];
70    nc[0] = nc0;
71    for (i=1;i<len;i++)
72    {
73       celt_uint32_t tmp = nc[i]-nc[i-1]-mem;
74       mem = nc[i];
75       nc[i] = tmp;
76    }
77 }
78
79 static void next_ncwrs64(celt_uint64_t *nc, int len, int nc0)
80 {
81    int i;
82    celt_uint64_t mem;
83    
84    mem = nc[0];
85    nc[0] = nc0;
86    for (i=1;i<len;i++)
87    {
88       celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
89       mem = nc[i];
90       nc[i] = tmp;
91    }
92 }
93
94 static void prev_ncwrs64(celt_uint64_t *nc, int len, int nc0)
95 {
96    int i;
97    celt_uint64_t mem;
98    
99    mem = nc[0];
100    nc[0] = nc0;
101    for (i=1;i<len;i++)
102    {
103       celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
104       mem = nc[i];
105       nc[i] = tmp;
106    }
107 }
108
109 /*Returns the numer of ways of choosing _m elements from a set of size _n with
110    replacement when a sign bit is needed for each unique element.*/
111 celt_uint32_t ncwrs(int _n,int _m)
112 {
113    int i;
114    celt_uint32_t ret;
115    VARDECL(celt_uint32_t *nc);
116    SAVE_STACK;
117    ALLOC(nc,_n+1, celt_uint32_t);
118    for (i=0;i<_n+1;i++)
119       nc[i] = 1;
120    for (i=0;i<_m;i++)
121       next_ncwrs32(nc, _n+1, 0);
122    ret = nc[_n];
123    RESTORE_STACK;
124    return ret;
125 }
126
127 /*Returns the numer of ways of choosing _m elements from a set of size _n with
128    replacement when a sign bit is needed for each unique element.*/
129 celt_uint64_t ncwrs64(int _n,int _m)
130 {
131    int i;
132    celt_uint64_t ret;
133    VARDECL(celt_uint64_t *nc);
134    SAVE_STACK;
135    ALLOC(nc,_n+1, celt_uint64_t);
136    for (i=0;i<_n+1;i++)
137       nc[i] = 1;
138    for (i=0;i<_m;i++)
139       next_ncwrs64(nc, _n+1, 0);
140    ret = nc[_n];
141    RESTORE_STACK;
142    return ret;
143 }
144
145
146 /*Returns the _i'th combination of _m elements chosen from a set of size _n
147    with associated sign bits.
148   _x:      Returns the combination with elements sorted in ascending order.
149   _s:      Returns the associated sign bits.*/
150 void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
151   int j;
152   int k;
153   VARDECL(celt_uint32_t *nc);
154   SAVE_STACK;
155   ALLOC(nc,_n+1, celt_uint32_t);
156   for (j=0;j<_n+1;j++)
157     nc[j] = 1;
158   for (k=0;k<_m-1;k++)
159     next_ncwrs32(nc, _n+1, 0);
160   for(k=j=0;k<_m;k++){
161     celt_uint32_t pn, p, t;
162     /*p=ncwrs(_n-j,_m-k-1);
163     pn=ncwrs(_n-j-1,_m-k-1);*/
164     p=nc[_n-j];
165     pn=nc[_n-j-1];
166     p+=pn;
167     if(k>0){
168       t=p>>1;
169       if(t<=_i||_s[k-1])_i+=t;
170     }
171     while(p<=_i){
172       _i-=p;
173       j++;
174       p=pn;
175       /*pn=ncwrs(_n-j-1,_m-k-1);*/
176       pn=nc[_n-j-1];
177       p+=pn;
178     }
179     t=p>>1;
180     _s[k]=_i>=t;
181     _x[k]=j;
182     if(_s[k])_i-=t;
183     if (k<_m-2)
184       prev_ncwrs32(nc, _n+1, 0);
185     else
186       prev_ncwrs32(nc, _n+1, 1);
187   }
188   RESTORE_STACK;
189 }
190
191 /*Returns the index of the given combination of _m elements chosen from a set
192    of size _n with associated sign bits.
193   _x:      The combination with elements sorted in ascending order.
194   _s:      The associated sign bits.*/
195 celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound){
196   celt_uint32_t i;
197   int      j;
198   int      k;
199   VARDECL(celt_uint32_t *nc);
200   SAVE_STACK;
201   ALLOC(nc,_n+1, celt_uint32_t);
202   for (j=0;j<_n+1;j++)
203     nc[j] = 1;
204   for (k=0;k<_m;k++)
205     next_ncwrs32(nc, _n+1, 0);
206   if (bound)
207     *bound = nc[_n];
208   i=0;
209   for(k=j=0;k<_m;k++){
210     celt_uint32_t pn;
211     celt_uint32_t p;
212     if (k<_m-1)
213       prev_ncwrs32(nc, _n+1, 0);
214     else
215       prev_ncwrs32(nc, _n+1, 1);
216     /*p=ncwrs(_n-j,_m-k-1);
217     pn=ncwrs(_n-j-1,_m-k-1);*/
218     p=nc[_n-j];
219     pn=nc[_n-j-1];
220     p+=pn;
221     if(k>0)p>>=1;
222     while(j<_x[k]){
223       i+=p;
224       j++;
225       p=pn;
226       /*pn=ncwrs(_n-j-1,_m-k-1);*/
227       pn=nc[_n-j-1];
228       p+=pn;
229     }
230     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
231   }
232   RESTORE_STACK;
233   return i;
234 }
235
236 /*Returns the _i'th combination of _m elements chosen from a set of size _n
237    with associated sign bits.
238   _x:      Returns the combination with elements sorted in ascending order.
239   _s:      Returns the associated sign bits.*/
240 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
241   int j;
242   int k;
243   VARDECL(celt_uint64_t *nc);
244   SAVE_STACK;
245   ALLOC(nc,_n+1, celt_uint64_t);
246   for (j=0;j<_n+1;j++)
247     nc[j] = 1;
248   for (k=0;k<_m-1;k++)
249     next_ncwrs64(nc, _n+1, 0);
250   for(k=j=0;k<_m;k++){
251     celt_uint64_t pn, p, t;
252     /*p=ncwrs64(_n-j,_m-k-1);
253     pn=ncwrs64(_n-j-1,_m-k-1);*/
254     p=nc[_n-j];
255     pn=nc[_n-j-1];
256     p+=pn;
257     if(k>0){
258       t=p>>1;
259       if(t<=_i||_s[k-1])_i+=t;
260     }
261     while(p<=_i){
262       _i-=p;
263       j++;
264       p=pn;
265       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
266       pn=nc[_n-j-1];
267       p+=pn;
268     }
269     t=p>>1;
270     _s[k]=_i>=t;
271     _x[k]=j;
272     if(_s[k])_i-=t;
273     if (k<_m-2)
274       prev_ncwrs64(nc, _n+1, 0);
275     else
276       prev_ncwrs64(nc, _n+1, 1);
277   }
278   RESTORE_STACK;
279 }
280
281 /*Returns the index of the given combination of _m elements chosen from a set
282    of size _n with associated sign bits.
283   _x:      The combination with elements sorted in ascending order.
284   _s:      The associated sign bits.*/
285 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){
286   celt_uint64_t i;
287   int           j;
288   int           k;
289   VARDECL(celt_uint64_t *nc);
290   SAVE_STACK;
291   ALLOC(nc,_n+1, celt_uint64_t);
292   for (j=0;j<_n+1;j++)
293     nc[j] = 1;
294   for (k=0;k<_m;k++)
295     next_ncwrs64(nc, _n+1, 0);
296   if (bound)
297      *bound = nc[_n];
298   i=0;
299   for(k=j=0;k<_m;k++){
300     celt_uint64_t pn;
301     celt_uint64_t p;
302     if (k<_m-1)
303       prev_ncwrs64(nc, _n+1, 0);
304     else
305       prev_ncwrs64(nc, _n+1, 1);
306     /*p=ncwrs64(_n-j,_m-k-1);
307     pn=ncwrs64(_n-j-1,_m-k-1);*/
308     p=nc[_n-j];
309     pn=nc[_n-j-1];
310     p+=pn;
311     if(k>0)p>>=1;
312     while(j<_x[k]){
313       i+=p;
314       j++;
315       p=pn;
316       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
317       pn=nc[_n-j-1];
318       p+=pn;
319     }
320     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
321   }
322   RESTORE_STACK;
323   return i;
324 }
325
326 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
327    a pulse vector _y of length _n.
328   _y: Returns the vector of pulses.
329   _x: The combination with elements sorted in ascending order.
330   _s: The associated sign bits.*/
331 void comb2pulse(int _n,int _m,int *_y,const int *_x,const int *_s){
332   int j;
333   int k;
334   int n;
335   for(k=j=0;k<_m;k+=n){
336     for(n=1;k+n<_m&&_x[k+n]==_x[k];n++);
337     while(j<_x[k])_y[j++]=0;
338     _y[j++]=_s[k]?-n:n;
339   }
340   while(j<_n)_y[j++]=0;
341 }
342
343 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
344    pulses with associated sign bits _s.
345   _x: Returns the combination with elements sorted in ascending order.
346   _s: Returns the associated sign bits.
347   _y: The vector of pulses, whose sum of absolute values must be _m.*/
348 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
349   int j;
350   int k;
351   for(k=j=0;j<_n;j++){
352     if(_y[j]){
353       int n;
354       int s;
355       n=abs(_y[j]);
356       s=_y[j]<0;
357       for(;n-->0;k++){
358         _x[k]=j;
359         _s[k]=s;
360       }
361     }
362   }
363 }
364
365 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
366 {
367    VARDECL(int *comb);
368    VARDECL(int *signs);
369    SAVE_STACK;
370    
371    ALLOC(comb, K, int);
372    ALLOC(signs, K, int);
373    
374    pulse2comb(N, K, comb, signs, _y);
375    /* Go with 32-bit path if we're sure we can */
376    if (N<=13 && K<=13)
377    {
378       celt_uint32_t bound, id;
379       id = icwrs(N, K, comb, signs, &bound);
380       ec_enc_uint(enc,id,bound);
381    } else {
382       celt_uint64_t bound, id;
383       id = icwrs64(N, K, comb, signs, &bound);
384       ec_enc_uint64(enc,id,bound);
385    }
386    RESTORE_STACK;
387 }
388
389 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
390 {
391    VARDECL(int *comb);
392    VARDECL(int *signs);
393    SAVE_STACK;
394    
395    ALLOC(comb, K, int);
396    ALLOC(signs, K, int);
397    if (N<=13 && K<=13)
398    {
399       cwrsi(N, K, ec_dec_uint(dec, ncwrs(N, K)), comb, signs);
400       comb2pulse(N, K, _y, comb, signs);
401    } else {
402       cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs);
403       comb2pulse(N, K, _y, comb, signs);
404    }
405    RESTORE_STACK;
406 }
407