fixed-point: unquant_energy_mono() has received the fixed-point code from
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /* Functions for encoding and decoding pulse vectors. For more details, see:
33    http://people.xiph.org/~tterribe/notes/cwrs.html
34 */
35
36 #ifdef HAVE_CONFIG_H
37 #include "config.h"
38 #endif
39
40 #include <stdlib.h>
41 #include "cwrs.h"
42
43 /* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
44    compute ncwrs() for m+1, for all n. Could also be used when m and n are
45    swapped just by changing nc */
46 static void next_ncwrs32(celt_uint32_t *nc, int len, int nc0)
47 {
48    int i;
49    celt_uint32_t mem;
50    
51    mem = nc[0];
52    nc[0] = nc0;
53    for (i=1;i<len;i++)
54    {
55       celt_uint32_t tmp = nc[i]+nc[i-1]+mem;
56       mem = nc[i];
57       nc[i] = tmp;
58    }
59 }
60
61 /* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
62    compute ncwrs() for m-1, for all n. Could also be used when m and n are
63    swapped just by changing nc */
64 static void prev_ncwrs32(celt_uint32_t *nc, int len, int nc0)
65 {
66    int i;
67    celt_uint32_t mem;
68    
69    mem = nc[0];
70    nc[0] = nc0;
71    for (i=1;i<len;i++)
72    {
73       celt_uint32_t tmp = nc[i]-nc[i-1]-mem;
74       mem = nc[i];
75       nc[i] = tmp;
76    }
77 }
78
79 static void next_ncwrs64(celt_uint64_t *nc, int len, int nc0)
80 {
81    int i;
82    celt_uint64_t mem;
83    
84    mem = nc[0];
85    nc[0] = nc0;
86    for (i=1;i<len;i++)
87    {
88       celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
89       mem = nc[i];
90       nc[i] = tmp;
91    }
92 }
93
94 static void prev_ncwrs64(celt_uint64_t *nc, int len, int nc0)
95 {
96    int i;
97    celt_uint64_t mem;
98    
99    mem = nc[0];
100    nc[0] = nc0;
101    for (i=1;i<len;i++)
102    {
103       celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
104       mem = nc[i];
105       nc[i] = tmp;
106    }
107 }
108
109 /*Returns the numer of ways of choosing _m elements from a set of size _n with
110    replacement when a sign bit is needed for each unique element.*/
111 celt_uint32_t ncwrs(int _n,int _m)
112 {
113    int i;
114    VARDECL(celt_uint32_t *nc);
115    ALLOC(nc,_n+1, celt_uint32_t);
116    for (i=0;i<_n+1;i++)
117       nc[i] = 1;
118    for (i=0;i<_m;i++)
119       next_ncwrs32(nc, _n+1, 0);
120    return nc[_n];
121 }
122
123 /*Returns the numer of ways of choosing _m elements from a set of size _n with
124    replacement when a sign bit is needed for each unique element.*/
125 celt_uint64_t ncwrs64(int _n,int _m)
126 {
127    int i;
128    VARDECL(celt_uint64_t *nc);
129    ALLOC(nc,_n+1, celt_uint64_t);
130    for (i=0;i<_n+1;i++)
131       nc[i] = 1;
132    for (i=0;i<_m;i++)
133       next_ncwrs64(nc, _n+1, 0);
134    return nc[_n];
135 }
136
137
138 /*Returns the _i'th combination of _m elements chosen from a set of size _n
139    with associated sign bits.
140   _x:      Returns the combination with elements sorted in ascending order.
141   _s:      Returns the associated sign bits.*/
142 void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
143   int j;
144   int k;
145   VARDECL(celt_uint32_t *nc);
146   ALLOC(nc,_n+1, celt_uint32_t);
147   for (j=0;j<_n+1;j++)
148     nc[j] = 1;
149   for (k=0;k<_m-1;k++)
150     next_ncwrs32(nc, _n+1, 0);
151   for(k=j=0;k<_m;k++){
152     celt_uint32_t pn, p, t;
153     /*p=ncwrs(_n-j,_m-k-1);
154     pn=ncwrs(_n-j-1,_m-k-1);*/
155     p=nc[_n-j];
156     pn=nc[_n-j-1];
157     p+=pn;
158     if(k>0){
159       t=p>>1;
160       if(t<=_i||_s[k-1])_i+=t;
161     }
162     while(p<=_i){
163       _i-=p;
164       j++;
165       p=pn;
166       /*pn=ncwrs(_n-j-1,_m-k-1);*/
167       pn=nc[_n-j-1];
168       p+=pn;
169     }
170     t=p>>1;
171     _s[k]=_i>=t;
172     _x[k]=j;
173     if(_s[k])_i-=t;
174     if (k<_m-2)
175       prev_ncwrs32(nc, _n+1, 0);
176     else
177       prev_ncwrs32(nc, _n+1, 1);
178   }
179 }
180
181 /*Returns the index of the given combination of _m elements chosen from a set
182    of size _n with associated sign bits.
183   _x:      The combination with elements sorted in ascending order.
184   _s:      The associated sign bits.*/
185 celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound){
186   celt_uint32_t i;
187   int      j;
188   int      k;
189   VARDECL(celt_uint32_t *nc);
190   ALLOC(nc,_n+1, celt_uint32_t);
191   for (j=0;j<_n+1;j++)
192     nc[j] = 1;
193   for (k=0;k<_m;k++)
194     next_ncwrs32(nc, _n+1, 0);
195   if (bound)
196     *bound = nc[_n];
197   i=0;
198   for(k=j=0;k<_m;k++){
199     celt_uint32_t pn;
200     celt_uint32_t p;
201     if (k<_m-1)
202       prev_ncwrs32(nc, _n+1, 0);
203     else
204       prev_ncwrs32(nc, _n+1, 1);
205     /*p=ncwrs(_n-j,_m-k-1);
206     pn=ncwrs(_n-j-1,_m-k-1);*/
207     p=nc[_n-j];
208     pn=nc[_n-j-1];
209     p+=pn;
210     if(k>0)p>>=1;
211     while(j<_x[k]){
212       i+=p;
213       j++;
214       p=pn;
215       /*pn=ncwrs(_n-j-1,_m-k-1);*/
216       pn=nc[_n-j-1];
217       p+=pn;
218     }
219     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
220   }
221   return i;
222 }
223
224 /*Returns the _i'th combination of _m elements chosen from a set of size _n
225    with associated sign bits.
226   _x:      Returns the combination with elements sorted in ascending order.
227   _s:      Returns the associated sign bits.*/
228 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
229   int j;
230   int k;
231   VARDECL(celt_uint64_t *nc);
232   ALLOC(nc,_n+1, celt_uint64_t);
233   for (j=0;j<_n+1;j++)
234     nc[j] = 1;
235   for (k=0;k<_m-1;k++)
236     next_ncwrs64(nc, _n+1, 0);
237   for(k=j=0;k<_m;k++){
238     celt_uint64_t pn, p, t;
239     /*p=ncwrs64(_n-j,_m-k-1);
240     pn=ncwrs64(_n-j-1,_m-k-1);*/
241     p=nc[_n-j];
242     pn=nc[_n-j-1];
243     p+=pn;
244     if(k>0){
245       t=p>>1;
246       if(t<=_i||_s[k-1])_i+=t;
247     }
248     while(p<=_i){
249       _i-=p;
250       j++;
251       p=pn;
252       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
253       pn=nc[_n-j-1];
254       p+=pn;
255     }
256     t=p>>1;
257     _s[k]=_i>=t;
258     _x[k]=j;
259     if(_s[k])_i-=t;
260     if (k<_m-2)
261       prev_ncwrs64(nc, _n+1, 0);
262     else
263       prev_ncwrs64(nc, _n+1, 1);
264   }
265 }
266
267 /*Returns the index of the given combination of _m elements chosen from a set
268    of size _n with associated sign bits.
269   _x:      The combination with elements sorted in ascending order.
270   _s:      The associated sign bits.*/
271 celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){
272   celt_uint64_t i;
273   int           j;
274   int           k;
275   VARDECL(celt_uint64_t *nc);
276   ALLOC(nc,_n+1, celt_uint64_t);
277   for (j=0;j<_n+1;j++)
278     nc[j] = 1;
279   for (k=0;k<_m;k++)
280     next_ncwrs64(nc, _n+1, 0);
281   if (bound)
282      *bound = nc[_n];
283   i=0;
284   for(k=j=0;k<_m;k++){
285     celt_uint64_t pn;
286     celt_uint64_t p;
287     if (k<_m-1)
288       prev_ncwrs64(nc, _n+1, 0);
289     else
290       prev_ncwrs64(nc, _n+1, 1);
291     /*p=ncwrs64(_n-j,_m-k-1);
292     pn=ncwrs64(_n-j-1,_m-k-1);*/
293     p=nc[_n-j];
294     pn=nc[_n-j-1];
295     p+=pn;
296     if(k>0)p>>=1;
297     while(j<_x[k]){
298       i+=p;
299       j++;
300       p=pn;
301       /*pn=ncwrs64(_n-j-1,_m-k-1);*/
302       pn=nc[_n-j-1];
303       p+=pn;
304     }
305     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
306   }
307   return i;
308 }
309
310 /*Converts a combination _x of _m unit pulses with associated sign bits _s into
311    a pulse vector _y of length _n.
312   _y: Returns the vector of pulses.
313   _x: The combination with elements sorted in ascending order.
314   _s: The associated sign bits.*/
315 void comb2pulse(int _n,int _m,int *_y,const int *_x,const int *_s){
316   int j;
317   int k;
318   int n;
319   for(k=j=0;k<_m;k+=n){
320     for(n=1;k+n<_m&&_x[k+n]==_x[k];n++);
321     while(j<_x[k])_y[j++]=0;
322     _y[j++]=_s[k]?-n:n;
323   }
324   while(j<_n)_y[j++]=0;
325 }
326
327 /*Converts a pulse vector vector _y of length _n into a combination of _m unit
328    pulses with associated sign bits _s.
329   _x: Returns the combination with elements sorted in ascending order.
330   _s: Returns the associated sign bits.
331   _y: The vector of pulses, whose sum of absolute values must be _m.*/
332 void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
333   int j;
334   int k;
335   for(k=j=0;j<_n;j++){
336     if(_y[j]){
337       int n;
338       int s;
339       n=abs(_y[j]);
340       s=_y[j]<0;
341       for(;n-->0;k++){
342         _x[k]=j;
343         _s[k]=s;
344       }
345     }
346   }
347 }
348
349 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
350 {
351    VARDECL(int *comb);
352    VARDECL(int *signs);
353    
354    ALLOC(comb, K, int);
355    ALLOC(signs, K, int);
356    
357    pulse2comb(N, K, comb, signs, _y);
358    /* Go with 32-bit path if we're sure we can */
359    if (N<=13 && K<=13)
360    {
361       celt_uint32_t bound, id;
362       id = icwrs(N, K, comb, signs, &bound);
363       ec_enc_uint(enc,id,bound);
364    } else {
365       celt_uint64_t bound, id;
366       id = icwrs64(N, K, comb, signs, &bound);
367       ec_enc_uint64(enc,id,bound);
368    }
369 }
370
371 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
372 {
373    VARDECL(int *comb);
374    VARDECL(int *signs);
375    
376    ALLOC(comb, K, int);
377    ALLOC(signs, K, int);
378    if (N<=13 && K<=13)
379    {
380       cwrsi(N, K, ec_dec_uint(dec, ncwrs(N, K)), comb, signs);
381       comb2pulse(N, K, _y, comb, signs);
382    } else {
383       cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs);
384       comb2pulse(N, K, _y, comb, signs);
385    }
386 }
387