Better computation of the VBR rate upper bound and reducing the coarse energy
[opus.git] / libcelt / cwrs.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2007-2009 Timothy B. Terriberry
4    Written by Timothy B. Terriberry and Jean-Marc Valin */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include "os_support.h"
39 #include <stdlib.h>
40 #include <string.h>
41 #include "cwrs.h"
42 #include "mathops.h"
43 #include "arch.h"
44
45 /*Guaranteed to return a conservatively large estimate of the binary logarithm
46    with frac bits of fractional precision.
47   Tested for all possible 32-bit inputs with frac=4, where the maximum
48    overestimation is 0.06254243 bits.*/
49 int log2_frac(ec_uint32 val, int frac)
50 {
51   int l;
52   l=EC_ILOG(val);
53   if(val&val-1){
54     /*This is (val>>l-16), but guaranteed to round up, even if adding a bias
55        before the shift would cause overflow (e.g., for 0xFFFFxxxx).*/
56     if(l>16)val=(val>>l-16)+((val&(1<<l-16)-1)+(1<<l-16)-1>>l-16);
57     else val<<=16-l;
58     l=l-1<<frac;
59     /*Note that we always need one iteration, since the rounding up above means
60        that we might need to adjust the integer part of the logarithm.*/
61     do{
62       int b;
63       b=(int)(val>>16);
64       l+=b<<frac;
65       val=val+b>>b;
66       val=val*val+0x7FFF>>15;
67     }
68     while(frac-->0);
69     /*If val is not exactly 0x8000, then we have to round up the remainder.*/
70     return l+(val>0x8000);
71   }
72   /*Exact powers of two require no rounding.*/
73   else return l-1<<frac;
74 }
75
76 #define MASK32 (0xFFFFFFFF)
77
78 /*INV_TABLE[i] holds the multiplicative inverse of (2*i+1) mod 2**32.*/
79 static const celt_uint32 INV_TABLE[128]={
80   0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7,
81   0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF,
82   0xF0F0F0F1,0x286BCA1B,0x3CF3CF3D,0xE9BD37A7,
83   0xC28F5C29,0x684BDA13,0x4F72C235,0xBDEF7BDF,
84   0x3E0F83E1,0x8AF8AF8B,0x914C1BAD,0x96F96F97,
85   0xC18F9C19,0x2FA0BE83,0xA4FA4FA5,0x677D46CF,
86   0x1A1F58D1,0xFAFAFAFB,0x8C13521D,0x586FB587,
87   0xB823EE09,0xA08AD8F3,0xC10C9715,0xBEFBEFBF,
88   0xC0FC0FC1,0x07A44C6B,0xA33F128D,0xE327A977,
89   0xC7E3F1F9,0x962FC963,0x3F2B3885,0x613716AF,
90   0x781948B1,0x2B2E43DB,0xFCFCFCFD,0x6FD0EB67,
91   0xFA3F47E9,0xD2FD2FD3,0x3F4FD3F5,0xD4E25B9F,
92   0x5F02A3A1,0xBF5A814B,0x7C32B16D,0xD3431B57,
93   0xD8FD8FD9,0x8D28AC43,0xDA6C0965,0xDB195E8F,
94   0x0FDBC091,0x61F2A4BB,0xDCFDCFDD,0x46FDD947,
95   0x56BE69C9,0xEB2FDEB3,0x26E978D5,0xEFDFBF7F,
96   0x0FE03F81,0xC9484E2B,0xE133F84D,0xE1A8C537,
97   0x077975B9,0x70586723,0xCD29C245,0xFAA11E6F,
98   0x0FE3C071,0x08B51D9B,0x8CE2CABD,0xBF937F27,
99   0xA8FE53A9,0x592FE593,0x2C0685B5,0x2EB11B5F,
100   0xFCD1E361,0x451AB30B,0x72CFE72D,0xDB35A717,
101   0xFB74A399,0xE80BFA03,0x0D516325,0x1BCB564F,
102   0xE02E4851,0xD962AE7B,0x10F8ED9D,0x95AEDD07,
103   0xE9DC0589,0xA18A4473,0xEA53FA95,0xEE936F3F,
104   0x90948F41,0xEAFEAFEB,0x3D137E0D,0xEF46C0F7,
105   0x028C1979,0x791064E3,0xC04FEC05,0xE115062F,
106   0x32385831,0x6E68575B,0xA10D387D,0x6FECF2E7,
107   0x3FB47F69,0xED4BFB53,0x74FED775,0xDB43BB1F,
108   0x87654321,0x9BA144CB,0x478BBCED,0xBFB912D7,
109   0x1FDCD759,0x14B2A7C3,0xCB125CE5,0x437B2E0F,
110   0x10FEF011,0xD2B3183B,0x386CAB5D,0xEF6AC0C7,
111   0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF
112 };
113
114 /*Computes (_a*_b-_c)/(2*_d+1) when the quotient is known to be exact.
115   _a, _b, _c, and _d may be arbitrary so long as the arbitrary precision result
116    fits in 32 bits, but currently the table for multiplicative inverses is only
117    valid for _d<128.*/
118 static inline celt_uint32 imusdiv32odd(celt_uint32 _a,celt_uint32 _b,
119  celt_uint32 _c,int _d){
120   return (_a*_b-_c)*INV_TABLE[_d]&MASK32;
121 }
122
123 /*Computes (_a*_b-_c)/_d when the quotient is known to be exact.
124   _d does not actually have to be even, but imusdiv32odd will be faster when
125    it's odd, so you should use that instead.
126   _a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the
127    table for multiplicative inverses is only valid for _d<=256).
128   _b and _c may be arbitrary so long as the arbitrary precision reuslt fits in
129    32 bits.*/
130 static inline celt_uint32 imusdiv32even(celt_uint32 _a,celt_uint32 _b,
131  celt_uint32 _c,int _d){
132   celt_uint32 inv;
133   int           mask;
134   int           shift;
135   int           one;
136   celt_assert(_d>0);
137   shift=EC_ILOG(_d^_d-1);
138   celt_assert(_d<=256);
139   inv=INV_TABLE[_d-1>>shift];
140   shift--;
141   one=1<<shift;
142   mask=one-1;
143   return (_a*(_b>>shift)-(_c>>shift)+
144    (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32;
145 }
146
147 /*Compute floor(sqrt(_val)) with exact arithmetic.
148   This has been tested on all possible 32-bit inputs.*/
149 static unsigned isqrt32(celt_uint32 _val){
150   unsigned b;
151   unsigned g;
152   int      bshift;
153   /*Uses the second method from
154      http://www.azillionmonkeys.com/qed/sqroot.html
155     The main idea is to search for the largest binary digit b such that
156      (g+b)*(g+b) <= _val, and add it to the solution g.*/
157   g=0;
158   bshift=EC_ILOG(_val)-1>>1;
159   b=1U<<bshift;
160   do{
161     celt_uint32 t;
162     t=((celt_uint32)g<<1)+b<<bshift;
163     if(t<=_val){
164       g+=b;
165       _val-=t;
166     }
167     b>>=1;
168     bshift--;
169   }
170   while(bshift>=0);
171   return g;
172 }
173
174 /*Although derived separately, the pulse vector coding scheme is equivalent to
175    a Pyramid Vector Quantizer \cite{Fis86}.
176   Some additional notes about an early version appear at
177    http://people.xiph.org/~tterribe/notes/cwrs.html, but the codebook ordering
178    and the definitions of some terms have evolved since that was written.
179
180   The conversion from a pulse vector to an integer index (encoding) and back
181    (decoding) is governed by two related functions, V(N,K) and U(N,K).
182
183   V(N,K) = the number of combinations, with replacement, of N items, taken K
184    at a time, when a sign bit is added to each item taken at least once (i.e.,
185    the number of N-dimensional unit pulse vectors with K pulses).
186   One way to compute this is via
187     V(N,K) = K>0 ? sum(k=1...K,2**k*choose(N,k)*choose(K-1,k-1)) : 1,
188    where choose() is the binomial function.
189   A table of values for N<10 and K<10 looks like:
190   V[10][10] = {
191     {1,  0,   0,    0,    0,     0,     0,      0,      0,       0},
192     {1,  2,   2,    2,    2,     2,     2,      2,      2,       2},
193     {1,  4,   8,   12,   16,    20,    24,     28,     32,      36},
194     {1,  6,  18,   38,   66,   102,   146,    198,    258,     326},
195     {1,  8,  32,   88,  192,   360,   608,    952,   1408,    1992},
196     {1, 10,  50,  170,  450,  1002,  1970,   3530,   5890,    9290},
197     {1, 12,  72,  292,  912,  2364,  5336,  10836,  20256,   35436},
198     {1, 14,  98,  462, 1666,  4942, 12642,  28814,  59906,  115598},
199     {1, 16, 128,  688, 2816,  9424, 27008,  68464, 157184,  332688},
200     {1, 18, 162,  978, 4482, 16722, 53154, 148626, 374274,  864146}
201   };
202
203   U(N,K) = the number of such combinations wherein N-1 objects are taken at
204    most K-1 at a time.
205   This is given by
206     U(N,K) = sum(k=0...K-1,V(N-1,k))
207            = K>0 ? (V(N-1,K-1) + V(N,K-1))/2 : 0.
208   The latter expression also makes clear that U(N,K) is half the number of such
209    combinations wherein the first object is taken at least once.
210   Although it may not be clear from either of these definitions, U(N,K) is the
211    natural function to work with when enumerating the pulse vector codebooks,
212    not V(N,K).
213   U(N,K) is not well-defined for N=0, but with the extension
214     U(0,K) = K>0 ? 0 : 1,
215    the function becomes symmetric: U(N,K) = U(K,N), with a similar table:
216   U[10][10] = {
217     {1, 0,  0,   0,    0,    0,     0,     0,      0,      0},
218     {0, 1,  1,   1,    1,    1,     1,     1,      1,      1},
219     {0, 1,  3,   5,    7,    9,    11,    13,     15,     17},
220     {0, 1,  5,  13,   25,   41,    61,    85,    113,    145},
221     {0, 1,  7,  25,   63,  129,   231,   377,    575,    833},
222     {0, 1,  9,  41,  129,  321,   681,  1289,   2241,   3649},
223     {0, 1, 11,  61,  231,  681,  1683,  3653,   7183,  13073},
224     {0, 1, 13,  85,  377, 1289,  3653,  8989,  19825,  40081},
225     {0, 1, 15, 113,  575, 2241,  7183, 19825,  48639, 108545},
226     {0, 1, 17, 145,  833, 3649, 13073, 40081, 108545, 265729}
227   };
228
229   With this extension, V(N,K) may be written in terms of U(N,K):
230     V(N,K) = U(N,K) + U(N,K+1)
231    for all N>=0, K>=0.
232   Thus U(N,K+1) represents the number of combinations where the first element
233    is positive or zero, and U(N,K) represents the number of combinations where
234    it is negative.
235   With a large enough table of U(N,K) values, we could write O(N) encoding
236    and O(min(N*log(K),N+K)) decoding routines, but such a table would be
237    prohibitively large for small embedded devices (K may be as large as 32767
238    for small N, and N may be as large as 200).
239
240   Both functions obey the same recurrence relation:
241     V(N,K) = V(N-1,K) + V(N,K-1) + V(N-1,K-1),
242     U(N,K) = U(N-1,K) + U(N,K-1) + U(N-1,K-1),
243    for all N>0, K>0, with different initial conditions at N=0 or K=0.
244   This allows us to construct a row of one of the tables above given the
245    previous row or the next row.
246   Thus we can derive O(NK) encoding and decoding routines with O(K) memory
247    using only addition and subtraction.
248
249   When encoding, we build up from the U(2,K) row and work our way forwards.
250   When decoding, we need to start at the U(N,K) row and work our way backwards,
251    which requires a means of computing U(N,K).
252   U(N,K) may be computed from two previous values with the same N:
253     U(N,K) = ((2*N-1)*U(N,K-1) - U(N,K-2))/(K-1) + U(N,K-2)
254    for all N>1, and since U(N,K) is symmetric, a similar relation holds for two
255    previous values with the same K:
256     U(N,K>1) = ((2*K-1)*U(N-1,K) - U(N-2,K))/(N-1) + U(N-2,K)
257    for all K>1.
258   This allows us to construct an arbitrary row of the U(N,K) table by starting
259    with the first two values, which are constants.
260   This saves roughly 2/3 the work in our O(NK) decoding routine, but costs O(K)
261    multiplications.
262   Similar relations can be derived for V(N,K), but are not used here.
263
264   For N>0 and K>0, U(N,K) and V(N,K) take on the form of an (N-1)-degree
265    polynomial for fixed N.
266   The first few are
267     U(1,K) = 1,
268     U(2,K) = 2*K-1,
269     U(3,K) = (2*K-2)*K+1,
270     U(4,K) = (((4*K-6)*K+8)*K-3)/3,
271     U(5,K) = ((((2*K-4)*K+10)*K-8)*K+3)/3,
272    and
273     V(1,K) = 2,
274     V(2,K) = 4*K,
275     V(3,K) = 4*K*K+2,
276     V(4,K) = 8*(K*K+2)*K/3,
277     V(5,K) = ((4*K*K+20)*K*K+6)/3,
278    for all K>0.
279   This allows us to derive O(N) encoding and O(N*log(K)) decoding routines for
280    small N (and indeed decoding is also O(N) for N<3).
281
282   @ARTICLE{Fis86,
283     author="Thomas R. Fischer",
284     title="A Pyramid Vector Quantizer",
285     journal="IEEE Transactions on Information Theory",
286     volume="IT-32",
287     number=4,
288     pages="568--583",
289     month=Jul,
290     year=1986
291   }*/
292
293 /*Determines if V(N,K) fits in a 32-bit unsigned integer.
294   N and K are themselves limited to 15 bits.*/
295 int fits_in32(int _n, int _k)
296 {
297    static const celt_int16 maxN[15] = {
298       32767, 32767, 32767, 1476, 283, 109,  60,  40,
299        29,  24,  20,  18,  16,  14,  13};
300    static const celt_int16 maxK[15] = {
301       32767, 32767, 32767, 32767, 1172, 238,  95,  53,
302        36,  27,  22,  18,  16,  15,  13};
303    if (_n>=14)
304    {
305       if (_k>=14)
306          return 0;
307       else
308          return _n <= maxN[_k];
309    } else {
310       return _k <= maxK[_n];
311    }
312 }
313
314 /*Compute U(1,_k).*/
315 static inline unsigned ucwrs1(int _k){
316   return _k?1:0;
317 }
318
319 /*Compute V(1,_k).*/
320 static inline unsigned ncwrs1(int _k){
321   return _k?2:1;
322 }
323
324 /*Compute U(2,_k).
325   Note that this may be called with _k=32768 (maxK[2]+1).*/
326 static inline unsigned ucwrs2(unsigned _k){
327   return _k?_k+(_k-1):0;
328 }
329
330 /*Compute V(2,_k).*/
331 static inline celt_uint32 ncwrs2(int _k){
332   return _k?4*(celt_uint32)_k:1;
333 }
334
335 /*Compute U(3,_k).
336   Note that this may be called with _k=32768 (maxK[3]+1).*/
337 static inline celt_uint32 ucwrs3(unsigned _k){
338   return _k?(2*(celt_uint32)_k-2)*_k+1:0;
339 }
340
341 /*Compute V(3,_k).*/
342 static inline celt_uint32 ncwrs3(int _k){
343   return _k?2*(2*(unsigned)_k*(celt_uint32)_k+1):1;
344 }
345
346 /*Compute U(4,_k).*/
347 static inline celt_uint32 ucwrs4(int _k){
348   return _k?imusdiv32odd(2*_k,(2*_k-3)*(celt_uint32)_k+4,3,1):0;
349 }
350
351 /*Compute V(4,_k).*/
352 static inline celt_uint32 ncwrs4(int _k){
353   return _k?((_k*(celt_uint32)_k+2)*_k)/3<<3:1;
354 }
355
356 /*Compute U(5,_k).*/
357 static inline celt_uint32 ucwrs5(int _k){
358   return _k?(((((_k-2)*(unsigned)_k+5)*(celt_uint32)_k-4)*_k)/3<<1)+1:0;
359 }
360
361 /*Compute V(5,_k).*/
362 static inline celt_uint32 ncwrs5(int _k){
363   return _k?(((_k*(unsigned)_k+5)*(celt_uint32)_k*_k)/3<<2)+2:1;
364 }
365
366 /*Computes the next row/column of any recurrence that obeys the relation
367    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
368   _ui0 is the base case for the new row/column.*/
369 static inline void unext(celt_uint32 *_ui,unsigned _len,celt_uint32 _ui0){
370   celt_uint32 ui1;
371   unsigned      j;
372   /*This do-while will overrun the array if we don't have storage for at least
373      2 values.*/
374   j=1; do {
375     ui1=UADD32(UADD32(_ui[j],_ui[j-1]),_ui0);
376     _ui[j-1]=_ui0;
377     _ui0=ui1;
378   } while (++j<_len);
379   _ui[j-1]=_ui0;
380 }
381
382 /*Computes the previous row/column of any recurrence that obeys the relation
383    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
384   _ui0 is the base case for the new row/column.*/
385 static inline void uprev(celt_uint32 *_ui,unsigned _n,celt_uint32 _ui0){
386   celt_uint32 ui1;
387   unsigned      j;
388   /*This do-while will overrun the array if we don't have storage for at least
389      2 values.*/
390   j=1; do {
391     ui1=USUB32(USUB32(_ui[j],_ui[j-1]),_ui0);
392     _ui[j-1]=_ui0;
393     _ui0=ui1;
394   } while (++j<_n);
395   _ui[j-1]=_ui0;
396 }
397
398 /*Compute V(_n,_k), as well as U(_n,0..._k+1).
399   _u: On exit, _u[i] contains U(_n,i) for i in [0..._k+1].*/
400 static celt_uint32 ncwrs_urow(unsigned _n,unsigned _k,celt_uint32 *_u){
401   celt_uint32 um2;
402   unsigned      len;
403   unsigned      k;
404   len=_k+2;
405   /*We require storage at least 3 values (e.g., _k>0).*/
406   celt_assert(len>=3);
407   _u[0]=0;
408   _u[1]=um2=1;
409   if(_n<=6 || _k>255){
410     /*If _n==0, _u[0] should be 1 and the rest should be 0.*/
411     /*If _n==1, _u[i] should be 1 for i>1.*/
412     celt_assert(_n>=2);
413     /*If _k==0, the following do-while loop will overflow the buffer.*/
414     celt_assert(_k>0);
415     k=2;
416     do _u[k]=(k<<1)-1;
417     while(++k<len);
418     for(k=2;k<_n;k++)unext(_u+1,_k+1,1);
419   }
420   else{
421     celt_uint32 um1;
422     celt_uint32 n2m1;
423     _u[2]=n2m1=um1=(_n<<1)-1;
424     for(k=3;k<len;k++){
425       /*U(N,K) = ((2*N-1)*U(N,K-1)-U(N,K-2))/(K-1) + U(N,K-2)*/
426       _u[k]=um2=imusdiv32even(n2m1,um1,um2,k-1)+um2;
427       if(++k>=len)break;
428       _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k-1>>1)+um1;
429     }
430   }
431   return _u[_k]+_u[_k+1];
432 }
433
434
435 /*Returns the _i'th combination of _k elements (at most 32767) chosen from a
436    set of size 1 with associated sign bits.
437   _y: Returns the vector of pulses.*/
438 static inline void cwrsi1(int _k,celt_uint32 _i,int *_y){
439   int s;
440   s=-(int)_i;
441   _y[0]=_k+s^s;
442 }
443
444 /*Returns the _i'th combination of _k elements (at most 32767) chosen from a
445    set of size 2 with associated sign bits.
446   _y: Returns the vector of pulses.*/
447 static inline void cwrsi2(int _k,celt_uint32 _i,int *_y){
448   celt_uint32 p;
449   int           s;
450   int           yj;
451   p=ucwrs2(_k+1U);
452   s=-(_i>=p);
453   _i-=p&s;
454   yj=_k;
455   _k=_i+1>>1;
456   p=ucwrs2(_k);
457   _i-=p;
458   yj-=_k;
459   _y[0]=yj+s^s;
460   cwrsi1(_k,_i,_y+1);
461 }
462
463 /*Returns the _i'th combination of _k elements (at most 32767) chosen from a
464    set of size 3 with associated sign bits.
465   _y: Returns the vector of pulses.*/
466 static void cwrsi3(int _k,celt_uint32 _i,int *_y){
467   celt_uint32 p;
468   int           s;
469   int           yj;
470   p=ucwrs3(_k+1U);
471   s=-(_i>=p);
472   _i-=p&s;
473   yj=_k;
474   /*Finds the maximum _k such that ucwrs3(_k)<=_i (tested for all
475      _i<2147418113=U(3,32768)).*/
476   _k=_i>0?isqrt32(2*_i-1)+1>>1:0;
477   p=ucwrs3(_k);
478   _i-=p;
479   yj-=_k;
480   _y[0]=yj+s^s;
481   cwrsi2(_k,_i,_y+1);
482 }
483
484 /*Returns the _i'th combination of _k elements (at most 1172) chosen from a set
485    of size 4 with associated sign bits.
486   _y: Returns the vector of pulses.*/
487 static void cwrsi4(int _k,celt_uint32 _i,int *_y){
488   celt_uint32 p;
489   int           s;
490   int           yj;
491   int           kl;
492   int           kr;
493   p=ucwrs4(_k+1);
494   s=-(_i>=p);
495   _i-=p&s;
496   yj=_k;
497   /*We could solve a cubic for k here, but the form of the direct solution does
498      not lend itself well to exact integer arithmetic.
499     Instead we do a binary search on U(4,K).*/
500   kl=0;
501   kr=_k;
502   for(;;){
503     _k=kl+kr>>1;
504     p=ucwrs4(_k);
505     if(p<_i){
506       if(_k>=kr)break;
507       kl=_k+1;
508     }
509     else if(p>_i)kr=_k-1;
510     else break;
511   }
512   _i-=p;
513   yj-=_k;
514   _y[0]=yj+s^s;
515   cwrsi3(_k,_i,_y+1);
516 }
517
518 /*Returns the _i'th combination of _k elements (at most 238) chosen from a set
519    of size 5 with associated sign bits.
520   _y: Returns the vector of pulses.*/
521 static void cwrsi5(int _k,celt_uint32 _i,int *_y){
522   celt_uint32 p;
523   int           s;
524   int           yj;
525   p=ucwrs5(_k+1);
526   s=-(_i>=p);
527   _i-=p&s;
528   yj=_k;
529   /* A binary search on U(5,K) avoids the need for 64-bit arithmetic */
530   {
531     int kl=0;
532     int kr=_k;
533     for(;;){
534       _k=kl+kr>>1;
535       p=ucwrs5(_k);
536       if(p<_i){
537         if(_k>=kr)break;
538         kl=_k+1;
539       }
540       else if(p>_i)kr=_k-1;
541       else break;
542     }  
543   }
544   _i-=p;
545   yj-=_k;
546   _y[0]=yj+s^s;
547   cwrsi4(_k,_i,_y+1);
548 }
549
550 /*Returns the _i'th combination of _k elements chosen from a set of size _n
551    with associated sign bits.
552   _y: Returns the vector of pulses.
553   _u: Must contain entries [0..._k+1] of row _n of U() on input.
554       Its contents will be destructively modified.*/
555 static void cwrsi(int _n,int _k,celt_uint32 _i,int *_y,celt_uint32 *_u){
556   int j;
557   celt_assert(_n>0);
558   j=0;
559   do{
560     celt_uint32 p;
561     int           s;
562     int           yj;
563     p=_u[_k+1];
564     s=-(_i>=p);
565     _i-=p&s;
566     yj=_k;
567     p=_u[_k];
568     while(p>_i)p=_u[--_k];
569     _i-=p;
570     yj-=_k;
571     _y[j]=yj+s^s;
572     uprev(_u,_k+2,0);
573   }
574   while(++j<_n);
575 }
576
577
578 /*Returns the index of the given combination of K elements chosen from a set
579    of size 1 with associated sign bits.
580   _y: The vector of pulses, whose sum of absolute values is K.
581   _k: Returns K.*/
582 static inline celt_uint32 icwrs1(const int *_y,int *_k){
583   *_k=abs(_y[0]);
584   return _y[0]<0;
585 }
586
587 /*Returns the index of the given combination of K elements chosen from a set
588    of size 2 with associated sign bits.
589   _y: The vector of pulses, whose sum of absolute values is K.
590   _k: Returns K.*/
591 static inline celt_uint32 icwrs2(const int *_y,int *_k){
592   celt_uint32 i;
593   int           k;
594   i=icwrs1(_y+1,&k);
595   i+=ucwrs2(k);
596   k+=abs(_y[0]);
597   if(_y[0]<0)i+=ucwrs2(k+1U);
598   *_k=k;
599   return i;
600 }
601
602 /*Returns the index of the given combination of K elements chosen from a set
603    of size 3 with associated sign bits.
604   _y: The vector of pulses, whose sum of absolute values is K.
605   _k: Returns K.*/
606 static inline celt_uint32 icwrs3(const int *_y,int *_k){
607   celt_uint32 i;
608   int           k;
609   i=icwrs2(_y+1,&k);
610   i+=ucwrs3(k);
611   k+=abs(_y[0]);
612   if(_y[0]<0)i+=ucwrs3(k+1U);
613   *_k=k;
614   return i;
615 }
616
617 /*Returns the index of the given combination of K elements chosen from a set
618    of size 4 with associated sign bits.
619   _y: The vector of pulses, whose sum of absolute values is K.
620   _k: Returns K.*/
621 static inline celt_uint32 icwrs4(const int *_y,int *_k){
622   celt_uint32 i;
623   int           k;
624   i=icwrs3(_y+1,&k);
625   i+=ucwrs4(k);
626   k+=abs(_y[0]);
627   if(_y[0]<0)i+=ucwrs4(k+1);
628   *_k=k;
629   return i;
630 }
631
632 /*Returns the index of the given combination of K elements chosen from a set
633    of size 5 with associated sign bits.
634   _y: The vector of pulses, whose sum of absolute values is K.
635   _k: Returns K.*/
636 static inline celt_uint32 icwrs5(const int *_y,int *_k){
637   celt_uint32 i;
638   int           k;
639   i=icwrs4(_y+1,&k);
640   i+=ucwrs5(k);
641   k+=abs(_y[0]);
642   if(_y[0]<0)i+=ucwrs5(k+1);
643   *_k=k;
644   return i;
645 }
646
647 /*Returns the index of the given combination of K elements chosen from a set
648    of size _n with associated sign bits.
649   _y:  The vector of pulses, whose sum of absolute values must be _k.
650   _nc: Returns V(_n,_k).*/
651 celt_uint32 icwrs(int _n,int _k,celt_uint32 *_nc,const int *_y,
652  celt_uint32 *_u){
653   celt_uint32 i;
654   int           j;
655   int           k;
656   /*We can't unroll the first two iterations of the loop unless _n>=2.*/
657   celt_assert(_n>=2);
658   _u[0]=0;
659   for(k=1;k<=_k+1;k++)_u[k]=(k<<1)-1;
660   i=icwrs1(_y+_n-1,&k);
661   j=_n-2;
662   i+=_u[k];
663   k+=abs(_y[j]);
664   if(_y[j]<0)i+=_u[k+1];
665   while(j-->0){
666     unext(_u,_k+2,0);
667     i+=_u[k];
668     k+=abs(_y[j]);
669     if(_y[j]<0)i+=_u[k+1];
670   }
671   *_nc=_u[k]+_u[k+1];
672   return i;
673 }
674
675
676 /*Computes get_required_bits when splitting is required.
677   _left_bits and _right_bits must contain the required bits for the left and
678    right sides of the split, respectively (which themselves may require
679    splitting).*/
680 static void get_required_split_bits(celt_int16 *_bits,
681  const celt_int16 *_left_bits,const celt_int16 *_right_bits,
682  int _n,int _maxk,int _frac){
683   int k;
684   for(k=_maxk;k-->0;){
685     /*If we've reached a k where everything fits in 32 bits, evaluate the
686        remaining required bits directly.*/
687     if(fits_in32(_n,k)){
688       get_required_bits(_bits,_n,k+1,_frac);
689       break;
690     }
691     else{
692       int worst_bits;
693       int i;
694       /*Due to potentially recursive splitting, it's difficult to derive an
695          analytic expression for the location of the worst-case split index.
696         We simply check them all.*/
697       worst_bits=0;
698       for(i=0;i<=k;i++){
699         int split_bits;
700         split_bits=_left_bits[i]+_right_bits[k-i];
701         if(split_bits>worst_bits)worst_bits=split_bits;
702       }
703       _bits[k]=log2_frac(k+1,_frac)+worst_bits;
704     }
705   }
706 }
707
708 /*Computes get_required_bits for a pair of N values.
709   _n1 and _n2 must either be equal or two consecutive integers.
710   Returns the buffer used to store the required bits for _n2, which is either
711    _bits1 if _n1==_n2 or _bits2 if _n1+1==_n2.*/
712 static celt_int16 *get_required_bits_pair(celt_int16 *_bits1,
713  celt_int16 *_bits2,celt_int16 *_tmp,int _n1,int _n2,int _maxk,int _frac){
714   celt_int16 *tmp2;
715   /*If we only need a single set of required bits...*/
716   if(_n1==_n2){
717     /*Stop recursing if everything fits.*/
718     if(fits_in32(_n1,_maxk-1))get_required_bits(_bits1,_n1,_maxk,_frac);
719     else{
720       _tmp=get_required_bits_pair(_bits2,_tmp,_bits1,
721        _n1>>1,_n1+1>>1,_maxk,_frac);
722       get_required_split_bits(_bits1,_bits2,_tmp,_n1,_maxk,_frac);
723     }
724     return _bits1;
725   }
726   /*Otherwise we need two distinct sets...*/
727   celt_assert(_n1+1==_n2);
728   /*Stop recursing if everything fits.*/
729   if(fits_in32(_n2,_maxk-1)){
730     get_required_bits(_bits1,_n1,_maxk,_frac);
731     get_required_bits(_bits2,_n2,_maxk,_frac);
732   }
733   /*Otherwise choose an evaluation order that doesn't require extra buffers.*/
734   else if(_n1&1){
735     /*This special case isn't really needed, but can save some work.*/
736     if(fits_in32(_n1,_maxk-1)){
737       tmp2=get_required_bits_pair(_tmp,_bits1,_bits2,
738        _n2>>1,_n2>>1,_maxk,_frac);
739       get_required_split_bits(_bits2,_tmp,tmp2,_n2,_maxk,_frac);
740       get_required_bits(_bits1,_n1,_maxk,_frac);
741     }
742     else{
743       _tmp=get_required_bits_pair(_bits2,_tmp,_bits1,
744        _n1>>1,_n1+1>>1,_maxk,_frac);
745       get_required_split_bits(_bits1,_bits2,_tmp,_n1,_maxk,_frac);
746       get_required_split_bits(_bits2,_tmp,_tmp,_n2,_maxk,_frac);
747     }
748   }
749   else{
750     /*There's no need to special case _n1 fitting by itself, since _n2 requires
751        us to recurse for both values anyway.*/
752     tmp2=get_required_bits_pair(_tmp,_bits1,_bits2,
753      _n2>>1,_n2+1>>1,_maxk,_frac);
754     get_required_split_bits(_bits2,_tmp,tmp2,_n2,_maxk,_frac);
755     get_required_split_bits(_bits1,_tmp,_tmp,_n1,_maxk,_frac);
756   }
757   return _bits2;
758 }
759
760 void get_required_bits(celt_int16 *_bits,int _n,int _maxk,int _frac){
761   int k;
762   /*_maxk==0 => there's nothing to do.*/
763   celt_assert(_maxk>0);
764   if(fits_in32(_n,_maxk-1)){
765     _bits[0]=0;
766     if(_maxk>1){
767       VARDECL(celt_uint32,u);
768       SAVE_STACK;
769       ALLOC(u,_maxk+1U,celt_uint32);
770       ncwrs_urow(_n,_maxk-1,u);
771       for(k=1;k<_maxk;k++)_bits[k]=log2_frac(u[k]+u[k+1],_frac);
772       RESTORE_STACK;
773     }
774   }
775   else{
776     VARDECL(celt_int16,n1bits);
777     VARDECL(celt_int16,n2bits_buf);
778     celt_int16 *n2bits;
779     SAVE_STACK;
780     ALLOC(n1bits,_maxk,celt_int16);
781     ALLOC(n2bits_buf,_maxk,celt_int16);
782     n2bits=get_required_bits_pair(n1bits,n2bits_buf,_bits,
783      _n>>1,_n+1>>1,_maxk,_frac);
784     get_required_split_bits(_bits,n1bits,n2bits,_n,_maxk,_frac);
785     RESTORE_STACK;
786   }
787 }
788
789
790 static inline void encode_pulses32(int _n,int _k,const int *_y,ec_enc *_enc){
791   celt_uint32 i;
792   switch(_n){
793     case 1:{
794       i=icwrs1(_y,&_k);
795       celt_assert(ncwrs1(_k)==2);
796       ec_enc_bits(_enc,i,1);
797     }break;
798     case 2:{
799       i=icwrs2(_y,&_k);
800       ec_enc_uint(_enc,i,ncwrs2(_k));
801     }break;
802     case 3:{
803       i=icwrs3(_y,&_k);
804       ec_enc_uint(_enc,i,ncwrs3(_k));
805     }break;
806     case 4:{
807       i=icwrs4(_y,&_k);
808       ec_enc_uint(_enc,i,ncwrs4(_k));
809     }break;
810     case 5:{
811       i=icwrs5(_y,&_k);
812       ec_enc_uint(_enc,i,ncwrs5(_k));
813     }break;
814     default:{
815       VARDECL(celt_uint32,u);
816       celt_uint32 nc;
817       SAVE_STACK;
818       ALLOC(u,_k+2U,celt_uint32);
819       i=icwrs(_n,_k,&nc,_y,u);
820       ec_enc_uint(_enc,i,nc);
821       RESTORE_STACK;
822     }break;
823   }
824 }
825
826 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
827 {
828    if (K==0) {
829    } else if(fits_in32(N,K))
830    {
831       encode_pulses32(N, K, _y, enc);
832    } else {
833      int i;
834      int count=0;
835      int split;
836      split = (N+1)/2;
837      for (i=0;i<split;i++)
838         count += abs(_y[i]);
839      ec_enc_uint(enc,count,K+1);
840      encode_pulses(_y, split, count, enc);
841      encode_pulses(_y+split, N-split, K-count, enc);
842    }
843 }
844
845 static inline void decode_pulses32(int _n,int _k,int *_y,ec_dec *_dec){
846   switch(_n){
847     case 1:{
848       celt_assert(ncwrs1(_k)==2);
849       cwrsi1(_k,ec_dec_bits(_dec,1),_y);
850     }break;
851     case 2:cwrsi2(_k,ec_dec_uint(_dec,ncwrs2(_k)),_y);break;
852     case 3:cwrsi3(_k,ec_dec_uint(_dec,ncwrs3(_k)),_y);break;
853     case 4:cwrsi4(_k,ec_dec_uint(_dec,ncwrs4(_k)),_y);break;
854     case 5:cwrsi5(_k,ec_dec_uint(_dec,ncwrs5(_k)),_y);break;
855     default:{
856       VARDECL(celt_uint32,u);
857       SAVE_STACK;
858       ALLOC(u,_k+2U,celt_uint32);
859       cwrsi(_n,_k,ec_dec_uint(_dec,ncwrs_urow(_n,_k,u)),_y,u);
860       RESTORE_STACK;
861     }
862   }
863 }
864
865 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
866 {
867    if (K==0) {
868       int i;
869       for (i=0;i<N;i++)
870          _y[i] = 0;
871    } else if(fits_in32(N,K))
872    {
873       decode_pulses32(N, K, _y, dec);
874    } else {
875      int split;
876      int count = ec_dec_uint(dec,K+1);
877      split = (N+1)/2;
878      decode_pulses(_y, split, count, dec);
879      decode_pulses(_y+split, N-split, K-count, dec);
880    }
881 }