Switch the N=5 case of CWRS to also use a binary search.
[opus.git] / libcelt / cwrs.c
1 /* (C) 2007-2008 Timothy B. Terriberry
2    (C) 2008 Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "os_support.h"
37 #include <stdlib.h>
38 #include <string.h>
39 #include "cwrs.h"
40 #include "mathops.h"
41 #include "arch.h"
42
43 /*Guaranteed to return a conservatively large estimate of the binary logarithm
44    with frac bits of fractional precision.
45   Tested for all possible 32-bit inputs with frac=4, where the maximum
46    overestimation is 0.06254243 bits.*/
47 int log2_frac(ec_uint32 val, int frac)
48 {
49   int l;
50   l=EC_ILOG(val);
51   if(val&val-1){
52     /*This is (val>>l-16), but guaranteed to round up, even if adding a bias
53        before the shift would cause overflow (e.g., for 0xFFFFxxxx).*/
54     if(l>16)val=(val>>l-16)+((val&(1<<l-16)-1)+(1<<l-16)-1>>l-16);
55     else val<<=16-l;
56     l=l-1<<frac;
57     /*Note that we always need one iteration, since the rounding up above means
58        that we might need to adjust the integer part of the logarithm.*/
59     do{
60       int b;
61       b=(int)(val>>16);
62       l+=b<<frac;
63       val=val+b>>b;
64       val=val*val+0x7FFF>>15;
65     }
66     while(frac-->0);
67     /*If val is not exactly 0x8000, then we have to round up the remainder.*/
68     return l+(val>0x8000);
69   }
70   /*Exact powers of two require no rounding.*/
71   else return l-1<<frac;
72 }
73
74 #define MASK32 (0xFFFFFFFF)
75
76 /*INV_TABLE[i] holds the multiplicative inverse of (2*i+1) mod 2**32.*/
77 static const celt_uint32_t INV_TABLE[128]={
78   0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7,
79   0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF,
80   0xF0F0F0F1,0x286BCA1B,0x3CF3CF3D,0xE9BD37A7,
81   0xC28F5C29,0x684BDA13,0x4F72C235,0xBDEF7BDF,
82   0x3E0F83E1,0x8AF8AF8B,0x914C1BAD,0x96F96F97,
83   0xC18F9C19,0x2FA0BE83,0xA4FA4FA5,0x677D46CF,
84   0x1A1F58D1,0xFAFAFAFB,0x8C13521D,0x586FB587,
85   0xB823EE09,0xA08AD8F3,0xC10C9715,0xBEFBEFBF,
86   0xC0FC0FC1,0x07A44C6B,0xA33F128D,0xE327A977,
87   0xC7E3F1F9,0x962FC963,0x3F2B3885,0x613716AF,
88   0x781948B1,0x2B2E43DB,0xFCFCFCFD,0x6FD0EB67,
89   0xFA3F47E9,0xD2FD2FD3,0x3F4FD3F5,0xD4E25B9F,
90   0x5F02A3A1,0xBF5A814B,0x7C32B16D,0xD3431B57,
91   0xD8FD8FD9,0x8D28AC43,0xDA6C0965,0xDB195E8F,
92   0x0FDBC091,0x61F2A4BB,0xDCFDCFDD,0x46FDD947,
93   0x56BE69C9,0xEB2FDEB3,0x26E978D5,0xEFDFBF7F,
94   0x0FE03F81,0xC9484E2B,0xE133F84D,0xE1A8C537,
95   0x077975B9,0x70586723,0xCD29C245,0xFAA11E6F,
96   0x0FE3C071,0x08B51D9B,0x8CE2CABD,0xBF937F27,
97   0xA8FE53A9,0x592FE593,0x2C0685B5,0x2EB11B5F,
98   0xFCD1E361,0x451AB30B,0x72CFE72D,0xDB35A717,
99   0xFB74A399,0xE80BFA03,0x0D516325,0x1BCB564F,
100   0xE02E4851,0xD962AE7B,0x10F8ED9D,0x95AEDD07,
101   0xE9DC0589,0xA18A4473,0xEA53FA95,0xEE936F3F,
102   0x90948F41,0xEAFEAFEB,0x3D137E0D,0xEF46C0F7,
103   0x028C1979,0x791064E3,0xC04FEC05,0xE115062F,
104   0x32385831,0x6E68575B,0xA10D387D,0x6FECF2E7,
105   0x3FB47F69,0xED4BFB53,0x74FED775,0xDB43BB1F,
106   0x87654321,0x9BA144CB,0x478BBCED,0xBFB912D7,
107   0x1FDCD759,0x14B2A7C3,0xCB125CE5,0x437B2E0F,
108   0x10FEF011,0xD2B3183B,0x386CAB5D,0xEF6AC0C7,
109   0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF
110 };
111
112 /*Computes (_a*_b-_c)/(2*_d+1) when the quotient is known to be exact.
113   _a, _b, _c, and _d may be arbitrary so long as the arbitrary precision result
114    fits in 32 bits, but currently the table for multiplicative inverses is only
115    valid for _d<128.*/
116 static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b,
117  celt_uint32_t _c,int _d){
118   return (_a*_b-_c)*INV_TABLE[_d]&MASK32;
119 }
120
121 /*Computes (_a*_b-_c)/_d when the quotient is known to be exact.
122   _d does not actually have to be even, but imusdiv32odd will be faster when
123    it's odd, so you should use that instead.
124   _a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the
125    table for multiplicative inverses is only valid for _d<=256).
126   _b and _c may be arbitrary so long as the arbitrary precision reuslt fits in
127    32 bits.*/
128 static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b,
129  celt_uint32_t _c,int _d){
130   celt_uint32_t inv;
131   int           mask;
132   int           shift;
133   int           one;
134   celt_assert(_d>0);
135   shift=EC_ILOG(_d^_d-1);
136   celt_assert(_d<=256);
137   inv=INV_TABLE[_d-1>>shift];
138   shift--;
139   one=1<<shift;
140   mask=one-1;
141   return (_a*(_b>>shift)-(_c>>shift)+
142    (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32;
143 }
144
145 /*Compute floor(sqrt(_val)) with exact arithmetic.
146   This has been tested on all possible 32-bit inputs.*/
147 static unsigned isqrt32(celt_uint32_t _val){
148   unsigned b;
149   unsigned g;
150   int      bshift;
151   /*Uses the second method from
152      http://www.azillionmonkeys.com/qed/sqroot.html
153     The main idea is to search for the largest binary digit b such that
154      (g+b)*(g+b) <= _val, and add it to the solution g.*/
155   g=0;
156   bshift=EC_ILOG(_val)-1>>1;
157   b=1U<<bshift;
158   for(;bshift>=0;bshift--){
159     celt_uint32_t t;
160     t=((celt_uint32_t)g<<1)+b<<bshift;
161     if(t<=_val){
162       g+=b;
163       _val-=t;
164     }
165     b>>=1;
166   }
167   return g;
168 }
169
170 #if 0
171 /*Compute floor(sqrt(_val)) with exact arithmetic.
172   This has been tested on all possible 36-bit inputs.*/
173 static celt_uint32_t isqrt36(celt_uint64_t _val){
174   celt_uint32_t val32;
175   celt_uint32_t b;
176   celt_uint32_t g;
177   int           bshift;
178   g=0;
179   b=0x20000;
180   for(bshift=18;bshift-->13;){
181     celt_uint64_t t;
182     t=((celt_uint64_t)g<<1)+b<<bshift;
183     if(t<=_val){
184       g+=b;
185       _val-=t;
186     }
187     b>>=1;
188   }
189   val32=(celt_uint32_t)_val;
190   for(;bshift>=0;bshift--){
191     celt_uint32_t t;
192     t=(g<<1)+b<<bshift;
193     if(t<=val32){
194       g+=b;
195       val32-=t;
196     }
197     b>>=1;
198   }
199   return g;
200 }
201 #endif
202
203 /*Although derived separately, the pulse vector coding scheme is equivalent to
204    a Pyramid Vector Quantizer \cite{Fis86}.
205   Some additional notes about an early version appear at
206    http://people.xiph.org/~tterribe/notes/cwrs.html, but the codebook ordering
207    and the definitions of some terms have evolved since that was written.
208
209   The conversion from a pulse vector to an integer index (encoding) and back
210    (decoding) is governed by two related functions, V(N,K) and U(N,K).
211
212   V(N,K) = the number of combinations, with replacement, of N items, taken K
213    at a time, when a sign bit is added to each item taken at least once (i.e.,
214    the number of N-dimensional unit pulse vectors with K pulses).
215   One way to compute this is via
216     V(N,K) = K>0 ? sum(k=1...K,2**k*choose(N,k)*choose(K-1,k-1)) : 1,
217    where choose() is the binomial function.
218   A table of values for N<10 and K<10 looks like:
219   V[10][10] = {
220     {1,  0,   0,    0,    0,     0,     0,      0,      0,       0},
221     {1,  2,   2,    2,    2,     2,     2,      2,      2,       2},
222     {1,  4,   8,   12,   16,    20,    24,     28,     32,      36},
223     {1,  6,  18,   38,   66,   102,   146,    198,    258,     326},
224     {1,  8,  32,   88,  192,   360,   608,    952,   1408,    1992},
225     {1, 10,  50,  170,  450,  1002,  1970,   3530,   5890,    9290},
226     {1, 12,  72,  292,  912,  2364,  5336,  10836,  20256,   35436},
227     {1, 14,  98,  462, 1666,  4942, 12642,  28814,  59906,  115598},
228     {1, 16, 128,  688, 2816,  9424, 27008,  68464, 157184,  332688},
229     {1, 18, 162,  978, 4482, 16722, 53154, 148626, 374274,  864146}
230   };
231
232   U(N,K) = the number of such combinations wherein N-1 objects are taken at
233    most K-1 at a time.
234   This is given by
235     U(N,K) = sum(k=0...K-1,V(N-1,k))
236            = K>0 ? (V(N-1,K-1) + V(N,K-1))/2 : 0.
237   The latter expression also makes clear that U(N,K) is half the number of such
238    combinations wherein the first object is taken at least once.
239   Although it may not be clear from either of these definitions, U(N,K) is the
240    natural function to work with when enumerating the pulse vector codebooks,
241    not V(N,K).
242   U(N,K) is not well-defined for N=0, but with the extension
243     U(0,K) = K>0 ? 0 : 1,
244    the function becomes symmetric: U(N,K) = U(K,N), with a similar table:
245   U[10][10] = {
246     {1, 0,  0,   0,    0,    0,     0,     0,      0,      0},
247     {0, 1,  1,   1,    1,    1,     1,     1,      1,      1},
248     {0, 1,  3,   5,    7,    9,    11,    13,     15,     17},
249     {0, 1,  5,  13,   25,   41,    61,    85,    113,    145},
250     {0, 1,  7,  25,   63,  129,   231,   377,    575,    833},
251     {0, 1,  9,  41,  129,  321,   681,  1289,   2241,   3649},
252     {0, 1, 11,  61,  231,  681,  1683,  3653,   7183,  13073},
253     {0, 1, 13,  85,  377, 1289,  3653,  8989,  19825,  40081},
254     {0, 1, 15, 113,  575, 2241,  7183, 19825,  48639, 108545},
255     {0, 1, 17, 145,  833, 3649, 13073, 40081, 108545, 265729}
256   };
257
258   With this extension, V(N,K) may be written in terms of U(N,K):
259     V(N,K) = U(N,K) + U(N,K+1)
260    for all N>=0, K>=0.
261   Thus U(N,K+1) represents the number of combinations where the first element
262    is positive or zero, and U(N,K) represents the number of combinations where
263    it is negative.
264   With a large enough table of U(N,K) values, we could write O(N) encoding
265    and O(min(N*log(K),N+K)) decoding routines, but such a table would be
266    prohibitively large for small embedded devices (K may be as large as 32767
267    for small N, and N may be as large as 200).
268
269   Both functions obey the same recurrence relation:
270     V(N,K) = V(N-1,K) + V(N,K-1) + V(N-1,K-1),
271     U(N,K) = U(N-1,K) + U(N,K-1) + U(N-1,K-1),
272    for all N>0, K>0, with different initial conditions at N=0 or K=0.
273   This allows us to construct a row of one of the tables above given the
274    previous row or the next row.
275   Thus we can derive O(NK) encoding and decoding routines with O(K) memory
276    using only addition and subtraction.
277
278   When encoding, we build up from the U(2,K) row and work our way forwards.
279   When decoding, we need to start at the U(N,K) row and work our way backwards,
280    which requires a means of computing U(N,K).
281   U(N,K) may be computed from two previous values with the same N:
282     U(N,K) = ((2*N-1)*U(N,K-1) - U(N,K-2))/(K-1) + U(N,K-2)
283    for all N>1, and since U(N,K) is symmetric, a similar relation holds for two
284    previous values with the same K:
285     U(N,K>1) = ((2*K-1)*U(N-1,K) - U(N-2,K))/(N-1) + U(N-2,K)
286    for all K>1.
287   This allows us to construct an arbitrary row of the U(N,K) table by starting
288    with the first two values, which are constants.
289   This saves roughly 2/3 the work in our O(NK) decoding routine, but costs O(K)
290    multiplications.
291   Similar relations can be derived for V(N,K), but are not used here.
292
293   For N>0 and K>0, U(N,K) and V(N,K) take on the form of an (N-1)-degree
294    polynomial for fixed N.
295   The first few are
296     U(1,K) = 1,
297     U(2,K) = 2*K-1,
298     U(3,K) = (2*K-2)*K+1,
299     U(4,K) = (((4*K-6)*K+8)*K-3)/3,
300     U(5,K) = ((((2*K-4)*K+10)*K-8)*K+3)/3,
301    and
302     V(1,K) = 2,
303     V(2,K) = 4*K,
304     V(3,K) = 4*K*K+2,
305     V(4,K) = 8*(K*K+2)*K/3,
306     V(5,K) = ((4*K*K+20)*K*K+6)/3,
307    for all K>0.
308   This allows us to derive O(N) encoding and O(N*log(K)) decoding routines for
309    small N (and indeed decoding is also O(N) for N<3).
310
311   @ARTICLE{Fis86,
312     author="Thomas R. Fischer",
313     title="A Pyramid Vector Quantizer",
314     journal="IEEE Transactions on Information Theory",
315     volume="IT-32",
316     number=4,
317     pages="568--583",
318     month=Jul,
319     year=1986
320   }*/
321
322 /*Determines if V(N,K) fits in a 32-bit unsigned integer.
323   N and K are themselves limited to 15 bits.*/
324 int fits_in32(int _n, int _k)
325 {
326    static const celt_int16_t maxN[15] = {
327       32767, 32767, 32767, 1476, 283, 109,  60,  40,
328        29,  24,  20,  18,  16,  14,  13};
329    static const celt_int16_t maxK[15] = {
330       32767, 32767, 32767, 32767, 1172, 238,  95,  53,
331        36,  27,  22,  18,  16,  15,  13};
332    if (_n>=14)
333    {
334       if (_k>=14)
335          return 0;
336       else
337          return _n <= maxN[_k];
338    } else {
339       return _k <= maxK[_n];
340    }
341 }
342
343 /*Compute U(1,_k).*/
344 static inline unsigned ucwrs1(int _k){
345   return _k?1:0;
346 }
347
348 /*Compute V(1,_k).*/
349 static inline unsigned ncwrs1(int _k){
350   return _k?2:1;
351 }
352
353 /*Compute U(2,_k).
354   Note that this may be called with _k=32768 (maxK[2]+1).*/
355 static inline unsigned ucwrs2(unsigned _k){
356   return _k?_k+(_k-1):0;
357 }
358
359 /*Compute V(2,_k).*/
360 static inline celt_uint32_t ncwrs2(int _k){
361   return _k?4*(celt_uint32_t)_k:1;
362 }
363
364 /*Compute U(3,_k).
365   Note that this may be called with _k=32768 (maxK[3]+1).*/
366 static inline celt_uint32_t ucwrs3(unsigned _k){
367   return _k?(2*(celt_uint32_t)_k-2)*_k+1:0;
368 }
369
370 /*Compute V(3,_k).*/
371 static inline celt_uint32_t ncwrs3(int _k){
372   return _k?2*(2*(unsigned)_k*(celt_uint32_t)_k+1):1;
373 }
374
375 /*Compute U(4,_k).*/
376 static inline celt_uint32_t ucwrs4(int _k){
377   return _k?imusdiv32odd(2*_k,(2*_k-3)*(celt_uint32_t)_k+4,3,1):0;
378 }
379
380 /*Compute V(4,_k).*/
381 static inline celt_uint32_t ncwrs4(int _k){
382   return _k?((_k*(celt_uint32_t)_k+2)*_k)/3<<3:1;
383 }
384
385 /*Compute U(5,_k).*/
386 static inline celt_uint32_t ucwrs5(int _k){
387   return _k?(((((_k-2)*(unsigned)_k+5)*(celt_uint32_t)_k-4)*_k)/3<<1)+1:0;
388 }
389
390 /*Compute V(5,_k).*/
391 static inline celt_uint32_t ncwrs5(int _k){
392   return _k?(((_k*(unsigned)_k+5)*(celt_uint32_t)_k*_k)/3<<2)+2:1;
393 }
394
395 /*Computes the next row/column of any recurrence that obeys the relation
396    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
397   _ui0 is the base case for the new row/column.*/
398 static inline void unext(celt_uint32_t *_ui,unsigned _len,celt_uint32_t _ui0){
399   celt_uint32_t ui1;
400   unsigned      j;
401   /*This do-while will overrun the array if we don't have storage for at least
402      2 values.*/
403   j=1; do {
404     ui1=UADD32(UADD32(_ui[j],_ui[j-1]),_ui0);
405     _ui[j-1]=_ui0;
406     _ui0=ui1;
407   } while (++j<_len);
408   _ui[j-1]=_ui0;
409 }
410
411 /*Computes the previous row/column of any recurrence that obeys the relation
412    u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
413   _ui0 is the base case for the new row/column.*/
414 static inline void uprev(celt_uint32_t *_ui,unsigned _n,celt_uint32_t _ui0){
415   celt_uint32_t ui1;
416   unsigned      j;
417   /*This do-while will overrun the array if we don't have storage for at least
418      2 values.*/
419   j=1; do {
420     ui1=USUB32(USUB32(_ui[j],_ui[j-1]),_ui0);
421     _ui[j-1]=_ui0;
422     _ui0=ui1;
423   } while (++j<_n);
424   _ui[j-1]=_ui0;
425 }
426
427 /*Compute V(_n,_k), as well as U(_n,0..._k+1).
428   _u: On exit, _u[i] contains U(_n,i) for i in [0..._k+1].*/
429 static celt_uint32_t ncwrs_urow(unsigned _n,unsigned _k,celt_uint32_t *_u){
430   celt_uint32_t um2;
431   unsigned      len;
432   unsigned      k;
433   len=_k+2;
434   /*We require storage at least 3 values (e.g., _k>0).*/
435   celt_assert(len>=3);
436   _u[0]=0;
437   _u[1]=um2=1;
438   if(_n<=6 || _k>255){
439     /*If _n==0, _u[0] should be 1 and the rest should be 0.*/
440     /*If _n==1, _u[i] should be 1 for i>1.*/
441     celt_assert(_n>=2);
442     /*If _k==0, the following do-while loop will overflow the buffer.*/
443     celt_assert(_k>0);
444     k=2;
445     do _u[k]=(k<<1)-1;
446     while(++k<len);
447     for(k=2;k<_n;k++)unext(_u+1,_k+1,1);
448   }
449   else{
450     celt_uint32_t um1;
451     celt_uint32_t n2m1;
452     _u[2]=n2m1=um1=(_n<<1)-1;
453     for(k=3;k<len;k++){
454       /*U(N,K) = ((2*N-1)*U(N,K-1)-U(N,K-2))/(K-1) + U(N,K-2)*/
455       _u[k]=um2=imusdiv32even(n2m1,um1,um2,k-1)+um2;
456       if(++k>=len)break;
457       _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k-1>>1)+um1;
458     }
459   }
460   return _u[_k]+_u[_k+1];
461 }
462
463
464 /*Returns the _i'th combination of _k elements (at most 32767) chosen from a
465    set of size 1 with associated sign bits.
466   _y: Returns the vector of pulses.*/
467 static inline void cwrsi1(int _k,celt_uint32_t _i,int *_y){
468   int s;
469   s=-(int)_i;
470   _y[0]=_k+s^s;
471 }
472
473 /*Returns the _i'th combination of _k elements (at most 32767) chosen from a
474    set of size 2 with associated sign bits.
475   _y: Returns the vector of pulses.*/
476 static inline void cwrsi2(int _k,celt_uint32_t _i,int *_y){
477   celt_uint32_t p;
478   int           s;
479   int           yj;
480   p=ucwrs2(_k+1U);
481   s=-(_i>=p);
482   _i-=p&s;
483   yj=_k;
484   _k=_i+1>>1;
485   p=ucwrs2(_k);
486   _i-=p;
487   yj-=_k;
488   _y[0]=yj+s^s;
489   cwrsi1(_k,_i,_y+1);
490 }
491
492 /*Returns the _i'th combination of _k elements (at most 32767) chosen from a
493    set of size 3 with associated sign bits.
494   _y: Returns the vector of pulses.*/
495 static void cwrsi3(int _k,celt_uint32_t _i,int *_y){
496   celt_uint32_t p;
497   int           s;
498   int           yj;
499   p=ucwrs3(_k+1U);
500   s=-(_i>=p);
501   _i-=p&s;
502   yj=_k;
503   /*Finds the maximum _k such that ucwrs3(_k)<=_i (tested for all
504      _i<2147418113=U(3,32768)).*/
505   _k=_i>0?isqrt32(2*_i-1)+1>>1:0;
506   p=ucwrs3(_k);
507   _i-=p;
508   yj-=_k;
509   _y[0]=yj+s^s;
510   cwrsi2(_k,_i,_y+1);
511 }
512
513 /*Returns the _i'th combination of _k elements (at most 1172) chosen from a set
514    of size 4 with associated sign bits.
515   _y: Returns the vector of pulses.*/
516 static void cwrsi4(int _k,celt_uint32_t _i,int *_y){
517   celt_uint32_t p;
518   int           s;
519   int           yj;
520   int           kl;
521   int           kr;
522   p=ucwrs4(_k+1);
523   s=-(_i>=p);
524   _i-=p&s;
525   yj=_k;
526   /*We could solve a cubic for k here, but the form of the direct solution does
527      not lend itself well to exact integer arithmetic.
528     Instead we do a binary search on U(4,K).*/
529   kl=0;
530   kr=_k;
531   for(;;){
532     _k=kl+kr>>1;
533     p=ucwrs4(_k);
534     if(p<_i){
535       if(_k>=kr)break;
536       kl=_k+1;
537     }
538     else if(p>_i)kr=_k-1;
539     else break;
540   }
541   _i-=p;
542   yj-=_k;
543   _y[0]=yj+s^s;
544   cwrsi3(_k,_i,_y+1);
545 }
546
547 /*Returns the _i'th combination of _k elements (at most 238) chosen from a set
548    of size 5 with associated sign bits.
549   _y: Returns the vector of pulses.*/
550 static void cwrsi5(int _k,celt_uint32_t _i,int *_y){
551   celt_uint32_t p;
552   int           s;
553   int           yj;
554   p=ucwrs5(_k+1);
555   s=-(_i>=p);
556   _i-=p&s;
557   yj=_k;
558 #if 0
559   /*Finds the maximum _k such that ucwrs5(_k)<=_i (tested for all
560      _i<2157192969=U(5,239)).*/
561   if(_i>=0x2AAAAAA9UL)_k=isqrt32(2*isqrt36(10+6*(celt_uint64_t)_i)-7)+1>>1;
562   else _k=_i>0?isqrt32(2*(celt_uint32_t)isqrt32(10+6*_i)-7)+1>>1:0;
563   p=ucwrs5(_k);
564 #else 
565   /* A binary search on U(5,K) avoids the need for 64-bit arithmetic */
566   {
567     int kl=0;
568     int kr=_k;
569     for(;;){
570       _k=kl+kr>>1;
571       p=ucwrs5(_k);
572       if(p<_i){
573         if(_k>=kr)break;
574         kl=_k+1;
575       }
576       else if(p>_i)kr=_k-1;
577       else break;
578     }  
579   }
580 #endif
581   _i-=p;
582   yj-=_k;
583   _y[0]=yj+s^s;
584   cwrsi4(_k,_i,_y+1);
585 }
586
587 /*Returns the _i'th combination of _k elements chosen from a set of size _n
588    with associated sign bits.
589   _y: Returns the vector of pulses.
590   _u: Must contain entries [0..._k+1] of row _n of U() on input.
591       Its contents will be destructively modified.*/
592 static void cwrsi(int _n,int _k,celt_uint32_t _i,int *_y,celt_uint32_t *_u){
593   int j;
594   celt_assert(_n>0);
595   j=0;
596   do{
597     celt_uint32_t p;
598     int           s;
599     int           yj;
600     p=_u[_k+1];
601     s=-(_i>=p);
602     _i-=p&s;
603     yj=_k;
604     p=_u[_k];
605     while(p>_i)p=_u[--_k];
606     _i-=p;
607     yj-=_k;
608     _y[j]=yj+s^s;
609     uprev(_u,_k+2,0);
610   }
611   while(++j<_n);
612 }
613
614
615 /*Returns the index of the given combination of K elements chosen from a set
616    of size 1 with associated sign bits.
617   _y: The vector of pulses, whose sum of absolute values is K.
618   _k: Returns K.*/
619 static inline celt_uint32_t icwrs1(const int *_y,int *_k){
620   *_k=abs(_y[0]);
621   return _y[0]<0;
622 }
623
624 /*Returns the index of the given combination of K elements chosen from a set
625    of size 2 with associated sign bits.
626   _y: The vector of pulses, whose sum of absolute values is K.
627   _k: Returns K.*/
628 static inline celt_uint32_t icwrs2(const int *_y,int *_k){
629   celt_uint32_t i;
630   int           k;
631   i=icwrs1(_y+1,&k);
632   i+=ucwrs2(k);
633   k+=abs(_y[0]);
634   if(_y[0]<0)i+=ucwrs2(k+1U);
635   *_k=k;
636   return i;
637 }
638
639 /*Returns the index of the given combination of K elements chosen from a set
640    of size 3 with associated sign bits.
641   _y: The vector of pulses, whose sum of absolute values is K.
642   _k: Returns K.*/
643 static inline celt_uint32_t icwrs3(const int *_y,int *_k){
644   celt_uint32_t i;
645   int           k;
646   i=icwrs2(_y+1,&k);
647   i+=ucwrs3(k);
648   k+=abs(_y[0]);
649   if(_y[0]<0)i+=ucwrs3(k+1U);
650   *_k=k;
651   return i;
652 }
653
654 /*Returns the index of the given combination of K elements chosen from a set
655    of size 4 with associated sign bits.
656   _y: The vector of pulses, whose sum of absolute values is K.
657   _k: Returns K.*/
658 static inline celt_uint32_t icwrs4(const int *_y,int *_k){
659   celt_uint32_t i;
660   int           k;
661   i=icwrs3(_y+1,&k);
662   i+=ucwrs4(k);
663   k+=abs(_y[0]);
664   if(_y[0]<0)i+=ucwrs4(k+1);
665   *_k=k;
666   return i;
667 }
668
669 /*Returns the index of the given combination of K elements chosen from a set
670    of size 5 with associated sign bits.
671   _y: The vector of pulses, whose sum of absolute values is K.
672   _k: Returns K.*/
673 static inline celt_uint32_t icwrs5(const int *_y,int *_k){
674   celt_uint32_t i;
675   int           k;
676   i=icwrs4(_y+1,&k);
677   i+=ucwrs5(k);
678   k+=abs(_y[0]);
679   if(_y[0]<0)i+=ucwrs5(k+1);
680   *_k=k;
681   return i;
682 }
683
684 /*Returns the index of the given combination of K elements chosen from a set
685    of size _n with associated sign bits.
686   _y:  The vector of pulses, whose sum of absolute values must be _k.
687   _nc: Returns V(_n,_k).*/
688 celt_uint32_t icwrs(int _n,int _k,celt_uint32_t *_nc,const int *_y,
689  celt_uint32_t *_u){
690   celt_uint32_t i;
691   int           j;
692   int           k;
693   /*We can't unroll the first two iterations of the loop unless _n>=2.*/
694   celt_assert(_n>=2);
695   _u[0]=0;
696   for(k=1;k<=_k+1;k++)_u[k]=(k<<1)-1;
697   i=icwrs1(_y+_n-1,&k);
698   j=_n-2;
699   i+=_u[k];
700   k+=abs(_y[j]);
701   if(_y[j]<0)i+=_u[k+1];
702   while(j-->0){
703     unext(_u,_k+2,0);
704     i+=_u[k];
705     k+=abs(_y[j]);
706     if(_y[j]<0)i+=_u[k+1];
707   }
708   *_nc=_u[k]+_u[k+1];
709   return i;
710 }
711
712
713 /*Computes get_required_bits when splitting is required.
714   _left_bits and _right_bits must contain the required bits for the left and
715    right sides of the split, respectively (which themselves may require
716    splitting).*/
717 static void get_required_split_bits(celt_int16_t *_bits,
718  const celt_int16_t *_left_bits,const celt_int16_t *_right_bits,
719  int _n,int _maxk,int _frac){
720   int k;
721   for(k=_maxk;k-->0;){
722     /*If we've reached a k where everything fits in 32 bits, evaluate the
723        remaining required bits directly.*/
724     if(fits_in32(_n,k)){
725       get_required_bits(_bits,_n,k+1,_frac);
726       break;
727     }
728     else{
729       int worst_bits;
730       int i;
731       /*Due to potentially recursive splitting, it's difficult to derive an
732          analytic expression for the location of the worst-case split index.
733         We simply check them all.*/
734       worst_bits=0;
735       for(i=0;i<=k;i++){
736         int split_bits;
737         split_bits=_left_bits[i]+_right_bits[k-i];
738         if(split_bits>worst_bits)worst_bits=split_bits;
739       }
740       _bits[k]=log2_frac(k+1,_frac)+worst_bits;
741     }
742   }
743 }
744
745 /*Computes get_required_bits for a pair of N values.
746   _n1 and _n2 must either be equal or two consecutive integers.
747   Returns the buffer used to store the required bits for _n2, which is either
748    _bits1 if _n1==_n2 or _bits2 if _n1+1==_n2.*/
749 static celt_int16_t *get_required_bits_pair(celt_int16_t *_bits1,
750  celt_int16_t *_bits2,celt_int16_t *_tmp,int _n1,int _n2,int _maxk,int _frac){
751   celt_int16_t *tmp2;
752   /*If we only need a single set of required bits...*/
753   if(_n1==_n2){
754     /*Stop recursing if everything fits.*/
755     if(fits_in32(_n1,_maxk-1))get_required_bits(_bits1,_n1,_maxk,_frac);
756     else{
757       _tmp=get_required_bits_pair(_bits2,_tmp,_bits1,
758        _n1>>1,_n1+1>>1,_maxk,_frac);
759       get_required_split_bits(_bits1,_bits2,_tmp,_n1,_maxk,_frac);
760     }
761     return _bits1;
762   }
763   /*Otherwise we need two distinct sets...*/
764   celt_assert(_n1+1==_n2);
765   /*Stop recursing if everything fits.*/
766   if(fits_in32(_n2,_maxk-1)){
767     get_required_bits(_bits1,_n1,_maxk,_frac);
768     get_required_bits(_bits2,_n2,_maxk,_frac);
769   }
770   /*Otherwise choose an evaluation order that doesn't require extra buffers.*/
771   else if(_n1&1){
772     /*This special case isn't really needed, but can save some work.*/
773     if(fits_in32(_n1,_maxk-1)){
774       tmp2=get_required_bits_pair(_tmp,_bits1,_bits2,
775        _n2>>1,_n2>>1,_maxk,_frac);
776       get_required_split_bits(_bits2,_tmp,tmp2,_n2,_maxk,_frac);
777       get_required_bits(_bits1,_n1,_maxk,_frac);
778     }
779     else{
780       _tmp=get_required_bits_pair(_bits2,_tmp,_bits1,
781        _n1>>1,_n1+1>>1,_maxk,_frac);
782       get_required_split_bits(_bits1,_bits2,_tmp,_n1,_maxk,_frac);
783       get_required_split_bits(_bits2,_tmp,_tmp,_n2,_maxk,_frac);
784     }
785   }
786   else{
787     /*There's no need to special case _n1 fitting by itself, since _n2 requires
788        us to recurse for both values anyway.*/
789     tmp2=get_required_bits_pair(_tmp,_bits1,_bits2,
790      _n2>>1,_n2+1>>1,_maxk,_frac);
791     get_required_split_bits(_bits2,_tmp,tmp2,_n2,_maxk,_frac);
792     get_required_split_bits(_bits1,_tmp,_tmp,_n1,_maxk,_frac);
793   }
794   return _bits2;
795 }
796
797 void get_required_bits(celt_int16_t *_bits,int _n,int _maxk,int _frac){
798   int k;
799   /*_maxk==0 => there's nothing to do.*/
800   celt_assert(_maxk>0);
801   if(fits_in32(_n,_maxk-1)){
802     _bits[0]=0;
803     if(_maxk>1){
804       VARDECL(celt_uint32_t,u);
805       SAVE_STACK;
806       ALLOC(u,_maxk+1U,celt_uint32_t);
807       ncwrs_urow(_n,_maxk-1,u);
808       for(k=1;k<_maxk;k++)_bits[k]=log2_frac(u[k]+u[k+1],_frac);
809       RESTORE_STACK;
810     }
811   }
812   else{
813     VARDECL(celt_int16_t,n1bits);
814     VARDECL(celt_int16_t,n2bits_buf);
815     celt_int16_t *n2bits;
816     SAVE_STACK;
817     ALLOC(n1bits,_maxk,celt_int16_t);
818     ALLOC(n2bits_buf,_maxk,celt_int16_t);
819     n2bits=get_required_bits_pair(n1bits,n2bits_buf,_bits,
820      _n>>1,_n+1>>1,_maxk,_frac);
821     get_required_split_bits(_bits,n1bits,n2bits,_n,_maxk,_frac);
822     RESTORE_STACK;
823   }
824 }
825
826
827 static inline void encode_pulses32(int _n,int _k,const int *_y,ec_enc *_enc){
828   celt_uint32_t i;
829   switch(_n){
830     case 1:{
831       i=icwrs1(_y,&_k);
832       celt_assert(ncwrs1(_k)==2);
833       ec_enc_bits(_enc,i,1);
834     }break;
835     case 2:{
836       i=icwrs2(_y,&_k);
837       ec_enc_uint(_enc,i,ncwrs2(_k));
838     }break;
839     case 3:{
840       i=icwrs3(_y,&_k);
841       ec_enc_uint(_enc,i,ncwrs3(_k));
842     }break;
843     case 4:{
844       i=icwrs4(_y,&_k);
845       ec_enc_uint(_enc,i,ncwrs4(_k));
846     }break;
847     case 5:{
848       i=icwrs5(_y,&_k);
849       ec_enc_uint(_enc,i,ncwrs5(_k));
850     }break;
851     default:{
852       VARDECL(celt_uint32_t,u);
853       celt_uint32_t nc;
854       SAVE_STACK;
855       ALLOC(u,_k+2U,celt_uint32_t);
856       i=icwrs(_n,_k,&nc,_y,u);
857       ec_enc_uint(_enc,i,nc);
858       RESTORE_STACK;
859     }break;
860   }
861 }
862
863 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
864 {
865    if (K==0) {
866    } else if(fits_in32(N,K))
867    {
868       encode_pulses32(N, K, _y, enc);
869    } else {
870      int i;
871      int count=0;
872      int split;
873      split = (N+1)/2;
874      for (i=0;i<split;i++)
875         count += abs(_y[i]);
876      ec_enc_uint(enc,count,K+1);
877      encode_pulses(_y, split, count, enc);
878      encode_pulses(_y+split, N-split, K-count, enc);
879    }
880 }
881
882 static inline void decode_pulses32(int _n,int _k,int *_y,ec_dec *_dec){
883   switch(_n){
884     case 1:{
885       celt_assert(ncwrs1(_k)==2);
886       cwrsi1(_k,ec_dec_bits(_dec,1),_y);
887     }break;
888     case 2:cwrsi2(_k,ec_dec_uint(_dec,ncwrs2(_k)),_y);break;
889     case 3:cwrsi3(_k,ec_dec_uint(_dec,ncwrs3(_k)),_y);break;
890     case 4:cwrsi4(_k,ec_dec_uint(_dec,ncwrs4(_k)),_y);break;
891     case 5:cwrsi5(_k,ec_dec_uint(_dec,ncwrs5(_k)),_y);break;
892     default:{
893       VARDECL(celt_uint32_t,u);
894       SAVE_STACK;
895       ALLOC(u,_k+2U,celt_uint32_t);
896       cwrsi(_n,_k,ec_dec_uint(_dec,ncwrs_urow(_n,_k,u)),_y,u);
897       RESTORE_STACK;
898     }
899   }
900 }
901
902 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
903 {
904    if (K==0) {
905       int i;
906       for (i=0;i<N;i++)
907          _y[i] = 0;
908    } else if(fits_in32(N,K))
909    {
910       decode_pulses32(N, K, _y, dec);
911    } else {
912      int split;
913      int count = ec_dec_uint(dec,K+1);
914      split = (N+1)/2;
915      decode_pulses(_y, split, count, dec);
916      decode_pulses(_y+split, N-split, K-count, dec);
917    }
918 }