Apply band caps to the band allocation table.
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 celt_uint32 lcg_rand(celt_uint32 seed)
49 {
50    return 1664525 * seed + 1013904223;
51 }
52
53 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
54    with this approximation is important because it has an impact on the bit allocation */
55 static celt_int16 bitexact_cos(celt_int16 x)
56 {
57    celt_int32 tmp;
58    celt_int16 x2;
59    tmp = (4096+((celt_int32)(x)*(x)))>>13;
60    if (tmp > 32767)
61       tmp = 32767;
62    x2 = tmp;
63    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
64    if (x2 > 32766)
65       x2 = 32766;
66    return 1+x2;
67 }
68
69 static int bitexact_log2tan(int isin,int icos)
70 {
71    int lc;
72    int ls;
73    lc=EC_ILOG(icos);
74    ls=EC_ILOG(isin);
75    icos<<=15-lc;
76    isin<<=15-ls;
77    return (ls-lc<<11)
78          +FRAC_MUL16(isin, FRAC_MUL16(isin, -2597) + 7932)
79          -FRAC_MUL16(icos, FRAC_MUL16(icos, -2597) + 7932);
80 }
81
82 #ifdef FIXED_POINT
83 /* Compute the amplitude (sqrt energy) in each of the bands */
84 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
85 {
86    int i, c, N;
87    const celt_int16 *eBands = m->eBands;
88    const int C = CHANNELS(_C);
89    N = M*m->shortMdctSize;
90    c=0; do {
91       for (i=0;i<end;i++)
92       {
93          int j;
94          celt_word32 maxval=0;
95          celt_word32 sum = 0;
96          
97          j=M*eBands[i]; do {
98             maxval = MAX32(maxval, X[j+c*N]);
99             maxval = MAX32(maxval, -X[j+c*N]);
100          } while (++j<M*eBands[i+1]);
101          
102          if (maxval > 0)
103          {
104             int shift = celt_ilog2(maxval)-10;
105             j=M*eBands[i]; do {
106                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
107                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
108             } while (++j<M*eBands[i+1]);
109             /* We're adding one here to make damn sure we never end up with a pitch vector that's
110                larger than unity norm */
111             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
112          } else {
113             bank[i+c*m->nbEBands] = EPSILON;
114          }
115          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
116       }
117    } while (++c<C);
118    /*printf ("\n");*/
119 }
120
121 /* Normalise each band such that the energy is one. */
122 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
123 {
124    int i, c, N;
125    const celt_int16 *eBands = m->eBands;
126    const int C = CHANNELS(_C);
127    N = M*m->shortMdctSize;
128    c=0; do {
129       i=0; do {
130          celt_word16 g;
131          int j,shift;
132          celt_word16 E;
133          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
134          E = VSHR32(bank[i+c*m->nbEBands], shift);
135          g = EXTRACT16(celt_rcp(SHL32(E,3)));
136          j=M*eBands[i]; do {
137             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
138          } while (++j<M*eBands[i+1]);
139       } while (++i<end);
140    } while (++c<C);
141 }
142
143 #else /* FIXED_POINT */
144 /* Compute the amplitude (sqrt energy) in each of the bands */
145 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
146 {
147    int i, c, N;
148    const celt_int16 *eBands = m->eBands;
149    const int C = CHANNELS(_C);
150    N = M*m->shortMdctSize;
151    c=0; do {
152       for (i=0;i<end;i++)
153       {
154          int j;
155          celt_word32 sum = 1e-27f;
156          for (j=M*eBands[i];j<M*eBands[i+1];j++)
157             sum += X[j+c*N]*X[j+c*N];
158          bank[i+c*m->nbEBands] = celt_sqrt(sum);
159          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
160       }
161    } while (++c<C);
162    /*printf ("\n");*/
163 }
164
165 /* Normalise each band such that the energy is one. */
166 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
167 {
168    int i, c, N;
169    const celt_int16 *eBands = m->eBands;
170    const int C = CHANNELS(_C);
171    N = M*m->shortMdctSize;
172    c=0; do {
173       for (i=0;i<end;i++)
174       {
175          int j;
176          celt_word16 g = 1.f/(1e-27f+bank[i+c*m->nbEBands]);
177          for (j=M*eBands[i];j<M*eBands[i+1];j++)
178             X[j+c*N] = freq[j+c*N]*g;
179       }
180    } while (++c<C);
181 }
182
183 #endif /* FIXED_POINT */
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    c=0; do {
194       celt_sig * restrict f;
195       const celt_norm * restrict x;
196       f = freq+c*N;
197       x = X+c*N;
198       for (i=0;i<end;i++)
199       {
200          int j, band_end;
201          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
202          j=M*eBands[i];
203          band_end = M*eBands[i+1];
204          do {
205             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
206             x++;
207          } while (++j<band_end);
208       }
209       for (i=M*eBands[m->nbEBands];i<N;i++)
210          *f++ = 0;
211    } while (++c<C);
212 }
213
214 /* This prevents energy collapse for transients with multiple short MDCTs */
215 void anti_collapse(const CELTMode *m, celt_norm *_X, unsigned char *collapse_masks, int LM, int C, int size,
216       int start, int end, celt_word16 *logE, celt_word16 *prev1logE,
217       celt_word16 *prev2logE, int *pulses, celt_uint32 seed)
218 {
219    int c, i, j, k;
220    for (i=start;i<end;i++)
221    {
222       int N0;
223       celt_word16 thresh, sqrt_1;
224       int depth;
225 #ifdef FIXED_POINT
226       int shift;
227 #endif
228
229       N0 = m->eBands[i+1]-m->eBands[i];
230       /* depth in 1/8 bits */
231       depth = (1+pulses[i])/(m->eBands[i+1]-m->eBands[i]<<LM);
232
233 #ifdef FIXED_POINT
234       thresh = MULT16_32_Q15(QCONST16(0.5f, 15), MIN32(32767,SHR32(celt_exp2(-SHL16(depth, 10-BITRES)),1) ));
235       {
236          celt_word32 t;
237          t = N0<<LM;
238          shift = celt_ilog2(t)>>1;
239          t = SHL32(t, (7-shift)<<1);
240          sqrt_1 = celt_rsqrt_norm(t);
241       }
242 #else
243       thresh = .5f*celt_exp2(-.125f*depth);
244       sqrt_1 = celt_rsqrt(N0<<LM);
245 #endif
246
247       c=0; do
248       {
249          celt_norm *X;
250          celt_word16 Ediff;
251          celt_word16 r;
252          Ediff = logE[c*m->nbEBands+i]-MIN16(prev1logE[c*m->nbEBands+i],prev2logE[c*m->nbEBands+i]);
253          Ediff = MAX16(0, Ediff);
254
255 #ifdef FIXED_POINT
256          if (Ediff < 16384)
257             r = 2*MIN16(16383,SHR32(celt_exp2(-Ediff),1));
258          else
259             r = 0;
260          if (LM==3)
261             r = MULT16_16_Q14(23170, MIN32(23169, r));
262          r = SHR16(MIN16(thresh, r),1);
263          r = SHR32(MULT16_16_Q15(sqrt_1, r),shift);
264 #else
265          /* r needs to be multiplied by 2 or 2*sqrt(2) depending on LM because
266             short blocks don't have the same energy as long */
267          r = 2.f*celt_exp2(-Ediff);
268          if (LM==3)
269             r *= 1.41421356f;
270          r = MIN16(thresh, r);
271          r = r*sqrt_1;
272 #endif
273          X = _X+c*size+(m->eBands[i]<<LM);
274          for (k=0;k<1<<LM;k++)
275          {
276             /* Detect collapse */
277             if (!(collapse_masks[i*C+c]&1<<k))
278             {
279                /* Fill with noise */
280                for (j=0;j<N0;j++)
281                {
282                   seed = lcg_rand(seed);
283                   X[(j<<LM)+k] = (seed&0x8000 ? r : -r);
284                }
285             }
286          }
287          /* We just added some energy, so we need to renormalise */
288          renormalise_vector(X, N0<<LM, Q15ONE);
289       } while (++c<C);
290    }
291
292 }
293
294
295 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
296 {
297    int i = bandID;
298    int j;
299    celt_word16 a1, a2;
300    celt_word16 left, right;
301    celt_word16 norm;
302 #ifdef FIXED_POINT
303    int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
304 #endif
305    left = VSHR32(bank[i],shift);
306    right = VSHR32(bank[i+m->nbEBands],shift);
307    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
308    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
309    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
310    for (j=0;j<N;j++)
311    {
312       celt_norm r, l;
313       l = X[j];
314       r = Y[j];
315       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
316       /* Side is not encoded, no need to calculate */
317    }
318 }
319
320 static void stereo_split(celt_norm *X, celt_norm *Y, int N)
321 {
322    int j;
323    for (j=0;j<N;j++)
324    {
325       celt_norm r, l;
326       l = MULT16_16_Q15(QCONST16(.70710678f,15), X[j]);
327       r = MULT16_16_Q15(QCONST16(.70710678f,15), Y[j]);
328       X[j] = l+r;
329       Y[j] = r-l;
330    }
331 }
332
333 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, int N)
334 {
335    int j;
336    celt_word32 xp=0, side=0;
337    celt_word32 El, Er;
338    celt_word16 mid2;
339 #ifdef FIXED_POINT
340    int kl, kr;
341 #endif
342    celt_word32 t, lgain, rgain;
343
344    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
345    for (j=0;j<N;j++)
346    {
347       xp = MAC16_16(xp, X[j], Y[j]);
348       side = MAC16_16(side, Y[j], Y[j]);
349    }
350    /* Compensating for the mid normalization */
351    xp = MULT16_32_Q15(mid, xp);
352    /* mid and side are in Q15, not Q14 like X and Y */
353    mid2 = SHR32(mid, 1);
354    El = MULT16_16(mid2, mid2) + side - 2*xp;
355    Er = MULT16_16(mid2, mid2) + side + 2*xp;
356    if (Er < QCONST32(6e-4f, 28) || El < QCONST32(6e-4f, 28))
357    {
358       for (j=0;j<N;j++)
359          Y[j] = X[j];
360       return;
361    }
362
363 #ifdef FIXED_POINT
364    kl = celt_ilog2(El)>>1;
365    kr = celt_ilog2(Er)>>1;
366 #endif
367    t = VSHR32(El, (kl-7)<<1);
368    lgain = celt_rsqrt_norm(t);
369    t = VSHR32(Er, (kr-7)<<1);
370    rgain = celt_rsqrt_norm(t);
371
372 #ifdef FIXED_POINT
373    if (kl < 7)
374       kl = 7;
375    if (kr < 7)
376       kr = 7;
377 #endif
378
379    for (j=0;j<N;j++)
380    {
381       celt_norm r, l;
382       /* Apply mid scaling (side is already scaled) */
383       l = MULT16_16_Q15(mid, X[j]);
384       r = Y[j];
385       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, SUB16(l,r)), kl+1));
386       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, ADD16(l,r)), kr+1));
387    }
388 }
389
390 /* Decide whether we should spread the pulses in the current frame */
391 int spreading_decision(const CELTMode *m, celt_norm *X, int *average,
392       int last_decision, int *hf_average, int *tapset_decision, int update_hf,
393       int end, int _C, int M)
394 {
395    int i, c, N0;
396    int sum = 0, nbBands=0;
397    const int C = CHANNELS(_C);
398    const celt_int16 * restrict eBands = m->eBands;
399    int decision;
400    int hf_sum=0;
401    
402    N0 = M*m->shortMdctSize;
403
404    if (M*(eBands[end]-eBands[end-1]) <= 8)
405       return SPREAD_NONE;
406    c=0; do {
407       for (i=0;i<end;i++)
408       {
409          int j, N, tmp=0;
410          int tcount[3] = {0};
411          celt_norm * restrict x = X+M*eBands[i]+c*N0;
412          N = M*(eBands[i+1]-eBands[i]);
413          if (N<=8)
414             continue;
415          /* Compute rough CDF of |x[j]| */
416          for (j=0;j<N;j++)
417          {
418             celt_word32 x2N; /* Q13 */
419
420             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
421             if (x2N < QCONST16(0.25f,13))
422                tcount[0]++;
423             if (x2N < QCONST16(0.0625f,13))
424                tcount[1]++;
425             if (x2N < QCONST16(0.015625f,13))
426                tcount[2]++;
427          }
428
429          /* Only include four last bands (8 kHz and up) */
430          if (i>m->nbEBands-4)
431             hf_sum += 32*(tcount[1]+tcount[0])/N;
432          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
433          sum += tmp*256;
434          nbBands++;
435       }
436    } while (++c<C);
437
438    if (update_hf)
439    {
440       if (hf_sum)
441          hf_sum /= C*(4-m->nbEBands+end);
442       *hf_average = (*hf_average+hf_sum)>>1;
443       hf_sum = *hf_average;
444       if (*tapset_decision==2)
445          hf_sum += 4;
446       else if (*tapset_decision==0)
447          hf_sum -= 4;
448       if (hf_sum > 22)
449          *tapset_decision=2;
450       else if (hf_sum > 18)
451          *tapset_decision=1;
452       else
453          *tapset_decision=0;
454    }
455    /*printf("%d %d %d\n", hf_sum, *hf_average, *tapset_decision);*/
456    sum /= nbBands;
457    /* Recursive averaging */
458    sum = (sum+*average)>>1;
459    *average = sum;
460    /* Hysteresis */
461    sum = (3*sum + (((3-last_decision)<<7) + 64) + 2)>>2;
462    if (sum < 80)
463    {
464       decision = SPREAD_AGGRESSIVE;
465    } else if (sum < 256)
466    {
467       decision = SPREAD_NORMAL;
468    } else if (sum < 384)
469    {
470       decision = SPREAD_LIGHT;
471    } else {
472       decision = SPREAD_NONE;
473    }
474    return decision;
475 }
476
477 #ifdef MEASURE_NORM_MSE
478
479 float MSE[30] = {0};
480 int nbMSEBands = 0;
481 int MSECount[30] = {0};
482
483 void dump_norm_mse(void)
484 {
485    int i;
486    for (i=0;i<nbMSEBands;i++)
487    {
488       printf ("%g ", MSE[i]/MSECount[i]);
489    }
490    printf ("\n");
491 }
492
493 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
494 {
495    static int init = 0;
496    int i;
497    if (!init)
498    {
499       atexit(dump_norm_mse);
500       init = 1;
501    }
502    for (i=0;i<m->nbEBands;i++)
503    {
504       int j;
505       int c;
506       float g;
507       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
508          continue;
509       c=0; do {
510          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
511          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
512             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
513       } while (++c<C);
514       MSECount[i]+=C;
515    }
516    nbMSEBands = m->nbEBands;
517 }
518
519 #endif
520
521 /* Indexing table for converting from natural Hadamard to ordery Hadamard
522    This is essentially a bit-reversed Gray, on top of which we've added
523    an inversion of the order because we want the DC at the end rather than
524    the beginning. The lines are for N=2, 4, 8, 16 */
525 static const int ordery_table[] = {
526        1,  0,
527        3,  0,  2,  1,
528        7,  0,  4,  3,  6,  1,  5,  2,
529       15,  0,  8,  7, 12,  3, 11,  4, 14,  1,  9,  6, 13,  2, 10,  5,
530 };
531
532 static void deinterleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
533 {
534    int i,j;
535    VARDECL(celt_norm, tmp);
536    int N;
537    SAVE_STACK;
538    N = N0*stride;
539    ALLOC(tmp, N, celt_norm);
540    if (hadamard)
541    {
542       const int *ordery = ordery_table+stride-2;
543       for (i=0;i<stride;i++)
544       {
545          for (j=0;j<N0;j++)
546             tmp[ordery[i]*N0+j] = X[j*stride+i];
547       }
548    } else {
549       for (i=0;i<stride;i++)
550          for (j=0;j<N0;j++)
551             tmp[i*N0+j] = X[j*stride+i];
552    }
553    for (j=0;j<N;j++)
554       X[j] = tmp[j];
555    RESTORE_STACK;
556 }
557
558 static void interleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
559 {
560    int i,j;
561    VARDECL(celt_norm, tmp);
562    int N;
563    SAVE_STACK;
564    N = N0*stride;
565    ALLOC(tmp, N, celt_norm);
566    if (hadamard)
567    {
568       const int *ordery = ordery_table+stride-2;
569       for (i=0;i<stride;i++)
570          for (j=0;j<N0;j++)
571             tmp[j*stride+i] = X[ordery[i]*N0+j];
572    } else {
573       for (i=0;i<stride;i++)
574          for (j=0;j<N0;j++)
575             tmp[j*stride+i] = X[i*N0+j];
576    }
577    for (j=0;j<N;j++)
578       X[j] = tmp[j];
579    RESTORE_STACK;
580 }
581
582 void haar1(celt_norm *X, int N0, int stride)
583 {
584    int i, j;
585    N0 >>= 1;
586    for (i=0;i<stride;i++)
587       for (j=0;j<N0;j++)
588       {
589          celt_norm tmp1, tmp2;
590          tmp1 = MULT16_16_Q15(QCONST16(.70710678f,15), X[stride*2*j+i]);
591          tmp2 = MULT16_16_Q15(QCONST16(.70710678f,15), X[stride*(2*j+1)+i]);
592          X[stride*2*j+i] = tmp1 + tmp2;
593          X[stride*(2*j+1)+i] = tmp1 - tmp2;
594       }
595 }
596
597 static int compute_qn(int N, int b, int offset, int pulse_cap, int stereo)
598 {
599    static const celt_int16 exp2_table8[8] =
600       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
601    int qn, qb;
602    int N2 = 2*N-1;
603    if (stereo && N==2)
604       N2--;
605    /* The upper limit ensures that in a stereo split with itheta==16384, we'll
606        always have enough bits left over to code at least one pulse in the
607        side; otherwise it would collapse, since it doesn't get folded. */
608    qb = IMIN(b-pulse_cap-(4<<BITRES), (b+N2*offset)/N2);
609
610    qb = IMIN(8<<BITRES, qb);
611
612    if (qb<(1<<BITRES>>1)) {
613       qn = 1;
614    } else {
615       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
616       qn = (qn+1)>>1<<1;
617    }
618    celt_assert(qn <= 256);
619    return qn;
620 }
621
622 /* This function is responsible for encoding and decoding a band for both
623    the mono and stereo case. Even in the mono case, it can split the band
624    in two and transmit the energy difference with the two half-bands. It
625    can be called recursively so bands can end up being split in 8 parts. */
626 static unsigned quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
627       int N, int b, int spread, int B, int intensity, int tf_change, celt_norm *lowband, int resynth, void *ec,
628       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
629       celt_uint32 *seed, celt_word16 gain, celt_norm *lowband_scratch, int fill)
630 {
631    int q;
632    int curr_bits;
633    int stereo, split;
634    int imid=0, iside=0;
635    int N0=N;
636    int N_B=N;
637    int N_B0;
638    int B0=B;
639    int time_divide=0;
640    int recombine=0;
641    int inv = 0;
642    celt_word16 mid=0, side=0;
643    int longBlocks;
644    unsigned cm=0;
645
646    longBlocks = B0==1;
647
648    N_B /= B;
649    N_B0 = N_B;
650
651    split = stereo = Y != NULL;
652
653    /* Special case for one sample */
654    if (N==1)
655    {
656       int c;
657       celt_norm *x = X;
658       c=0; do {
659          int sign=0;
660          if (*remaining_bits>=1<<BITRES)
661          {
662             if (encode)
663             {
664                sign = x[0]<0;
665                ec_enc_bits((ec_enc*)ec, sign, 1);
666             } else {
667                sign = ec_dec_bits((ec_dec*)ec, 1);
668             }
669             *remaining_bits -= 1<<BITRES;
670             b-=1<<BITRES;
671          }
672          if (resynth)
673             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
674          x = Y;
675       } while (++c<1+stereo);
676       if (lowband_out)
677          lowband_out[0] = SHR16(X[0],4);
678       return 1;
679    }
680
681    if (!stereo && level == 0)
682    {
683       int k;
684       if (tf_change>0)
685          recombine = tf_change;
686       /* Band recombining to increase frequency resolution */
687
688       if (lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
689       {
690          int j;
691          for (j=0;j<N;j++)
692             lowband_scratch[j] = lowband[j];
693          lowband = lowband_scratch;
694       }
695
696       for (k=0;k<recombine;k++)
697       {
698          if (encode)
699             haar1(X, N>>k, 1<<k);
700          if (lowband)
701             haar1(lowband, N>>k, 1<<k);
702          fill |= fill<<(1<<k);
703       }
704       B>>=recombine;
705       N_B<<=recombine;
706
707       /* Increasing the time resolution */
708       while ((N_B&1) == 0 && tf_change<0)
709       {
710          if (encode)
711             haar1(X, N_B, B);
712          if (lowband)
713             haar1(lowband, N_B, B);
714          fill |= fill<<B;
715          B <<= 1;
716          N_B >>= 1;
717          time_divide++;
718          tf_change++;
719       }
720       B0=B;
721       N_B0 = N_B;
722
723       /* Reorganize the samples in time order instead of frequency order */
724       if (B0>1)
725       {
726          if (encode)
727             deinterleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
728          if (lowband)
729             deinterleave_hadamard(lowband, N_B>>recombine, B0<<recombine, longBlocks);
730       }
731    }
732
733    /* If we need more than 32 bits, try splitting the band in two. */
734    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
735    {
736       if (LM>0 || (N&1)==0)
737       {
738          N >>= 1;
739          Y = X+N;
740          split = 1;
741          LM -= 1;
742          if (B==1)
743             fill = fill&1|fill<<1;
744          B = (B+1)>>1;
745       }
746    }
747
748    if (split)
749    {
750       int qn;
751       int itheta=0;
752       int mbits, sbits, delta;
753       int qalloc;
754       int pulse_cap;
755       int offset;
756       int orig_fill;
757       celt_int32 tell;
758
759       /* Decide on the resolution to give to the split parameter theta */
760       pulse_cap = m->logN[i]+(LM<<BITRES);
761       offset = (pulse_cap>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
762       qn = compute_qn(N, b, offset, pulse_cap, stereo);
763       if (stereo && i>=intensity)
764          qn = 1;
765       if (encode)
766       {
767          /* theta is the atan() of the ratio between the (normalized)
768             side and mid. With just that parameter, we can re-scale both
769             mid and side because we know that 1) they have unit norm and
770             2) they are orthogonal. */
771          itheta = stereo_itheta(X, Y, stereo, N);
772       }
773       tell = encode ? ec_enc_tell(ec, BITRES) : ec_dec_tell(ec, BITRES);
774       if (qn!=1)
775       {
776          if (encode)
777             itheta = (itheta*qn+8192)>>14;
778
779          /* Entropy coding of the angle. We use a uniform pdf for the
780             time split, a step for stereo, and a triangular one for the rest. */
781          if (stereo && N>2)
782          {
783             int p0 = 3;
784             int x = itheta;
785             int x0 = qn/2;
786             int ft = p0*(x0+1) + x0;
787             /* Use a probability of p0 up to itheta=8192 and then use 1 after */
788             if (encode)
789             {
790                ec_encode((ec_enc*)ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
791             } else {
792                int fs;
793                fs=ec_decode(ec,ft);
794                if (fs<(x0+1)*p0)
795                   x=fs/p0;
796                else
797                   x=x0+1+(fs-(x0+1)*p0);
798                ec_dec_update(ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
799                itheta = x;
800             }
801          } else if (B0>1 || stereo) {
802             /* Uniform pdf */
803             if (encode)
804                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
805             else
806                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
807          } else {
808             int fs=1, ft;
809             ft = ((qn>>1)+1)*((qn>>1)+1);
810             if (encode)
811             {
812                int fl;
813
814                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
815                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
816                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
817
818                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
819             } else {
820                /* Triangular pdf */
821                int fl=0;
822                int fm;
823                fm = ec_decode((ec_dec*)ec, ft);
824
825                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
826                {
827                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
828                   fs = itheta + 1;
829                   fl = itheta*(itheta + 1)>>1;
830                }
831                else
832                {
833                   itheta = (2*(qn + 1)
834                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
835                   fs = qn + 1 - itheta;
836                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
837                }
838
839                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
840             }
841          }
842          itheta = (celt_int32)itheta*16384/qn;
843          if (encode && stereo)
844          {
845             if (itheta==0)
846                intensity_stereo(m, X, Y, bandE, i, N);
847             else
848                stereo_split(X, Y, N);
849          }
850          /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
851                   Let's do that at higher complexity */
852       } else if (stereo) {
853          if (encode)
854          {
855             inv = itheta > 8192;
856             if (inv)
857             {
858                int j;
859                for (j=0;j<N;j++)
860                   Y[j] = -Y[j];
861             }
862             intensity_stereo(m, X, Y, bandE, i, N);
863          }
864          if (b>2<<BITRES && *remaining_bits > 2<<BITRES)
865          {
866             if (encode)
867                ec_enc_bit_logp(ec, inv, 2);
868             else
869                inv = ec_dec_bit_logp(ec, 2);
870          } else
871             inv = 0;
872          itheta = 0;
873       }
874       qalloc = (encode ? ec_enc_tell(ec, BITRES) : ec_dec_tell(ec, BITRES))
875                - tell;
876       b -= qalloc;
877
878       orig_fill = fill;
879       if (itheta == 0)
880       {
881          imid = 32767;
882          iside = 0;
883          fill &= (1<<B)-1;
884          delta = -16384;
885       } else if (itheta == 16384)
886       {
887          imid = 0;
888          iside = 32767;
889          fill &= (1<<B)-1<<B;
890          delta = 16384;
891       } else {
892          imid = bitexact_cos(itheta);
893          iside = bitexact_cos(16384-itheta);
894          /* This is the mid vs side allocation that minimizes squared error
895             in that band. */
896          delta = FRAC_MUL16(N-1<<7,bitexact_log2tan(iside,imid));
897       }
898
899 #ifdef FIXED_POINT
900       mid = imid;
901       side = iside;
902 #else
903       mid = (1.f/32768)*imid;
904       side = (1.f/32768)*iside;
905 #endif
906
907       /* This is a special case for N=2 that only works for stereo and takes
908          advantage of the fact that mid and side are orthogonal to encode
909          the side with just one bit. */
910       if (N==2 && stereo)
911       {
912          int c;
913          int sign=0;
914          celt_norm *x2, *y2;
915          mbits = b;
916          sbits = 0;
917          /* Only need one bit for the side */
918          if (itheta != 0 && itheta != 16384)
919             sbits = 1<<BITRES;
920          mbits -= sbits;
921          c = itheta > 8192;
922          *remaining_bits -= qalloc+sbits;
923
924          x2 = c ? Y : X;
925          y2 = c ? X : Y;
926          if (sbits)
927          {
928             if (encode)
929             {
930                /* Here we only need to encode a sign for the side */
931                sign = x2[0]*y2[1] - x2[1]*y2[0] < 0;
932                ec_enc_bits((ec_enc*)ec, sign, 1);
933             } else {
934                sign = ec_dec_bits((ec_dec*)ec, 1);
935             }
936          }
937          sign = 1-2*sign;
938          /* We use orig_fill here because we want to fold the side, but if
939              itheta==16384, we'll have cleared the low bits of fill. */
940          cm = quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, intensity, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level, seed, gain, lowband_scratch, orig_fill);
941          /* We don't split N=2 bands, so cm is either 1 or 0 (for a fold-collapse),
942              and there's no need to worry about mixing with the other channel. */
943          y2[0] = -sign*x2[1];
944          y2[1] = sign*x2[0];
945          if (resynth)
946          {
947             celt_norm tmp;
948             X[0] = MULT16_16_Q15(mid, X[0]);
949             X[1] = MULT16_16_Q15(mid, X[1]);
950             Y[0] = MULT16_16_Q15(side, Y[0]);
951             Y[1] = MULT16_16_Q15(side, Y[1]);
952             tmp = X[0];
953             X[0] = SUB16(tmp,Y[0]);
954             Y[0] = ADD16(tmp,Y[0]);
955             tmp = X[1];
956             X[1] = SUB16(tmp,Y[1]);
957             Y[1] = ADD16(tmp,Y[1]);
958          }
959       } else {
960          /* "Normal" split code */
961          celt_norm *next_lowband2=NULL;
962          celt_norm *next_lowband_out1=NULL;
963          int next_level=0;
964          celt_int32 rebalance;
965
966          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
967          if (B0>1 && !stereo && (itheta&0x3fff))
968          {
969             if (itheta > 8192)
970                /* Rough approximation for pre-echo masking */
971                delta -= delta>>(4-LM);
972             else
973                /* Corresponds to a forward-masking slope of 1.5 dB per 10 ms */
974                delta = IMIN(0, delta + (N<<BITRES>>(5-LM)));
975          }
976          mbits = IMAX(0, IMIN(b, (b-delta)/2));
977          sbits = b-mbits;
978          *remaining_bits -= qalloc;
979
980          if (lowband && !stereo)
981             next_lowband2 = lowband+N; /* >32-bit split case */
982
983          /* Only stereo needs to pass on lowband_out. Otherwise, it's
984             handled at the end */
985          if (stereo)
986             next_lowband_out1 = lowband_out;
987          else
988             next_level = level+1;
989
990          rebalance = *remaining_bits;
991          if (mbits >= sbits)
992          {
993             /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
994                mid for folding later */
995             cm = quant_band(encode, m, i, X, NULL, N, mbits, spread, B, intensity, tf_change,
996                   lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
997                   NULL, next_level, seed, stereo ? Q15ONE : MULT16_16_P15(gain,mid), lowband_scratch, fill);
998             rebalance = mbits - (rebalance-*remaining_bits);
999             if (rebalance > 3<<BITRES && itheta!=0)
1000                sbits += rebalance - (3<<BITRES);
1001
1002             /* For a stereo split, the high bits of fill are always zero, so no
1003                folding will be done to the side. */
1004             cm |= quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, intensity, tf_change,
1005                   next_lowband2, resynth, ec, remaining_bits, LM, NULL,
1006                   NULL, next_level, seed, MULT16_16_P15(gain,side), NULL, fill>>B)<<(B0>>1&stereo-1);
1007          } else {
1008             /* For a stereo split, the high bits of fill are always zero, so no
1009                folding will be done to the side. */
1010             cm = quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, intensity, tf_change,
1011                   next_lowband2, resynth, ec, remaining_bits, LM, NULL,
1012                   NULL, next_level, seed, MULT16_16_P15(gain,side), NULL, fill>>B)<<(B0>>1&stereo-1);
1013             rebalance = sbits - (rebalance-*remaining_bits);
1014             if (rebalance > 3<<BITRES && itheta!=16384)
1015                mbits += rebalance - (3<<BITRES);
1016             /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
1017                mid for folding later */
1018             cm |= quant_band(encode, m, i, X, NULL, N, mbits, spread, B, intensity, tf_change,
1019                   lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
1020                   NULL, next_level, seed, stereo ? Q15ONE : MULT16_16_P15(gain,mid), lowband_scratch, fill);
1021          }
1022       }
1023
1024    } else {
1025       /* This is the basic no-split case */
1026       q = bits2pulses(m, i, LM, b);
1027       curr_bits = pulses2bits(m, i, LM, q);
1028       *remaining_bits -= curr_bits;
1029
1030       /* Ensures we can never bust the budget */
1031       while (*remaining_bits < 0 && q > 0)
1032       {
1033          *remaining_bits += curr_bits;
1034          q--;
1035          curr_bits = pulses2bits(m, i, LM, q);
1036          *remaining_bits -= curr_bits;
1037       }
1038
1039       if (q!=0)
1040       {
1041          int K = get_pulses(q);
1042
1043          /* Finally do the actual quantization */
1044          if (encode)
1045             cm = alg_quant(X, N, K, spread, B, resynth, (ec_enc*)ec, gain);
1046          else
1047             cm = alg_unquant(X, N, K, spread, B, (ec_dec*)ec, gain);
1048       } else {
1049          /* If there's no pulse, fill the band anyway */
1050          int j;
1051          if (resynth)
1052          {
1053             unsigned cm_mask;
1054             /*B can be as large as 16, so this shift might overflow an int on a
1055                16-bit platform; use a long to get defined behavior.*/
1056             cm_mask = (unsigned)(1UL<<B)-1;
1057             fill &= cm_mask;
1058             if (!fill)
1059             {
1060                for (j=0;j<N;j++)
1061                   X[j] = 0;
1062             } else {
1063                if (lowband == NULL)
1064                {
1065                   /* Noise */
1066                   for (j=0;j<N;j++)
1067                   {
1068                      *seed = lcg_rand(*seed);
1069                      X[j] = (celt_int32)(*seed)>>20;
1070                   }
1071                   cm = cm_mask;
1072                } else {
1073                   /* Folded spectrum */
1074                   for (j=0;j<N;j++)
1075                      X[j] = lowband[j];
1076                   cm = fill;
1077                }
1078                renormalise_vector(X, N, gain);
1079             }
1080          }
1081       }
1082    }
1083
1084    /* This code is used by the decoder and by the resynthesis-enabled encoder */
1085    if (resynth)
1086    {
1087       if (stereo)
1088       {
1089          if (N!=2)
1090             stereo_merge(X, Y, mid, N);
1091          if (inv)
1092          {
1093             int j;
1094             for (j=0;j<N;j++)
1095                Y[j] = -Y[j];
1096          }
1097       } else if (level == 0)
1098       {
1099          int k;
1100
1101          /* Undo the sample reorganization going from time order to frequency order */
1102          if (B0>1)
1103             interleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
1104
1105          /* Undo time-freq changes that we did earlier */
1106          N_B = N_B0;
1107          B = B0;
1108          for (k=0;k<time_divide;k++)
1109          {
1110             B >>= 1;
1111             N_B <<= 1;
1112             cm |= cm>>B;
1113             haar1(X, N_B, B);
1114          }
1115
1116          for (k=0;k<recombine;k++)
1117          {
1118             cm |= cm<<(1<<k);
1119             haar1(X, N0>>k, 1<<k);
1120          }
1121          B<<=recombine;
1122          N_B>>=recombine;
1123
1124          /* Scale output for later folding */
1125          if (lowband_out)
1126          {
1127             int j;
1128             celt_word16 n;
1129             n = celt_sqrt(SHL32(EXTEND32(N0),22));
1130             for (j=0;j<N0;j++)
1131                lowband_out[j] = MULT16_16_Q15(n,X[j]);
1132          }
1133          cm &= (1<<B)-1;
1134       }
1135    }
1136    return cm;
1137 }
1138
1139 void quant_all_bands(int encode, const CELTMode *m, int start, int end,
1140       celt_norm *_X, celt_norm *_Y, unsigned char *collapse_masks, const celt_ener *bandE, int *pulses,
1141       int shortBlocks, int spread, int dual_stereo, int intensity, int *tf_res, int resynth,
1142       celt_int32 total_bits, void *ec, int LM, int codedBands, ec_uint32 *seed)
1143 {
1144    int i;
1145    celt_int32 balance;
1146    celt_int32 remaining_bits;
1147    const celt_int16 * restrict eBands = m->eBands;
1148    celt_norm * restrict norm, * restrict norm2;
1149    VARDECL(celt_norm, _norm);
1150    VARDECL(celt_norm, lowband_scratch);
1151    int B;
1152    int M;
1153    int lowband_offset;
1154    int update_lowband = 1;
1155    int C = _Y != NULL ? 2 : 1;
1156    SAVE_STACK;
1157
1158    M = 1<<LM;
1159    B = shortBlocks ? M : 1;
1160    ALLOC(_norm, C*M*eBands[m->nbEBands], celt_norm);
1161    ALLOC(lowband_scratch, M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]), celt_norm);
1162    norm = _norm;
1163    norm2 = norm + M*eBands[m->nbEBands];
1164
1165    balance = 0;
1166    lowband_offset = 0;
1167    for (i=start;i<end;i++)
1168    {
1169       celt_int32 tell;
1170       int b;
1171       int N;
1172       celt_int32 curr_balance;
1173       int effective_lowband=-1;
1174       celt_norm * restrict X, * restrict Y;
1175       int tf_change=0;
1176       unsigned x_cm;
1177       unsigned y_cm;
1178
1179       X = _X+M*eBands[i];
1180       if (_Y!=NULL)
1181          Y = _Y+M*eBands[i];
1182       else
1183          Y = NULL;
1184       N = M*eBands[i+1]-M*eBands[i];
1185       if (encode)
1186          tell = ec_enc_tell((ec_enc*)ec, BITRES);
1187       else
1188          tell = ec_dec_tell((ec_dec*)ec, BITRES);
1189
1190       /* Compute how many bits we want to allocate to this band */
1191       if (i != start)
1192          balance -= tell;
1193       remaining_bits = total_bits-tell-1;
1194       if (i <= codedBands-1)
1195       {
1196          curr_balance = balance / IMIN(3, codedBands-i);
1197          b = IMAX(0, IMIN(16383, IMIN(remaining_bits+1,pulses[i]+curr_balance)));
1198       } else {
1199          b = 0;
1200       }
1201
1202       if (resynth && M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband_offset==0))
1203             lowband_offset = i;
1204
1205       tf_change = tf_res[i];
1206       if (i>=m->effEBands)
1207       {
1208          X=norm;
1209          if (_Y!=NULL)
1210             Y = norm;
1211       }
1212
1213       /* Get a conservative estimate of the collapse_mask's for the bands we're
1214           going to be folding from. */
1215       if (lowband_offset != 0 && (spread!=SPREAD_AGGRESSIVE || B>1 || tf_change<0))
1216       {
1217          int fold_start;
1218          int fold_end;
1219          int fold_i;
1220          /* This ensures we never repeat spectral content within one band */
1221          effective_lowband = IMAX(M*eBands[start], M*eBands[lowband_offset]-N);
1222          fold_start = lowband_offset;
1223          while(M*eBands[--fold_start] > effective_lowband);
1224          fold_end = lowband_offset-1;
1225          while(M*eBands[++fold_end] < effective_lowband+N);
1226          x_cm = y_cm = 0;
1227          fold_i = fold_start; do {
1228            x_cm |= collapse_masks[fold_i*C+0];
1229            y_cm |= collapse_masks[fold_i*C+C-1];
1230          } while (++fold_i<fold_end);
1231       }
1232       /* Otherwise, we'll be using the LCG to fold, so all blocks will (almost
1233           always) be non-zero.*/
1234       else
1235          x_cm = y_cm = (1<<B)-1;
1236
1237       if (dual_stereo && i==intensity)
1238       {
1239          int j;
1240
1241          /* Switch off dual stereo to do intensity */
1242          dual_stereo = 0;
1243          for (j=M*eBands[start];j<M*eBands[i];j++)
1244             norm[j] = HALF32(norm[j]+norm2[j]);
1245       }
1246       if (dual_stereo)
1247       {
1248          x_cm = quant_band(encode, m, i, X, NULL, N, b/2, spread, B, intensity, tf_change,
1249                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1250                norm+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, x_cm);
1251          y_cm = quant_band(encode, m, i, Y, NULL, N, b/2, spread, B, intensity, tf_change,
1252                effective_lowband != -1 ? norm2+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1253                norm2+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, y_cm);
1254       } else {
1255          x_cm = quant_band(encode, m, i, X, Y, N, b, spread, B, intensity, tf_change,
1256                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1257                norm+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, x_cm|y_cm);
1258          y_cm = x_cm;
1259       }
1260       collapse_masks[i*C+0] = (unsigned char)x_cm;
1261       collapse_masks[i*C+C-1] = (unsigned char)y_cm;
1262       balance += pulses[i] + tell;
1263
1264       /* Update the folding position only as long as we have 1 bit/sample depth */
1265       update_lowband = b>(N<<BITRES);
1266    }
1267    RESTORE_STACK;
1268 }
1269