Separates stereo_band_mix() into the intensity and MS stereo cases
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    for (c=0;c<C;c++)
74    {
75       for (i=0;i<end;i++)
76       {
77          int j;
78          celt_word32 maxval=0;
79          celt_word32 sum = 0;
80          
81          j=M*eBands[i]; do {
82             maxval = MAX32(maxval, X[j+c*N]);
83             maxval = MAX32(maxval, -X[j+c*N]);
84          } while (++j<M*eBands[i+1]);
85          
86          if (maxval > 0)
87          {
88             int shift = celt_ilog2(maxval)-10;
89             j=M*eBands[i]; do {
90                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
91                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
92             } while (++j<M*eBands[i+1]);
93             /* We're adding one here to make damn sure we never end up with a pitch vector that's
94                larger than unity norm */
95             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
96          } else {
97             bank[i+c*m->nbEBands] = EPSILON;
98          }
99          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
100       }
101    }
102    /*printf ("\n");*/
103 }
104
105 /* Normalise each band such that the energy is one. */
106 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
107 {
108    int i, c, N;
109    const celt_int16 *eBands = m->eBands;
110    const int C = CHANNELS(_C);
111    N = M*m->shortMdctSize;
112    for (c=0;c<C;c++)
113    {
114       i=0; do {
115          celt_word16 g;
116          int j,shift;
117          celt_word16 E;
118          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
119          E = VSHR32(bank[i+c*m->nbEBands], shift);
120          g = EXTRACT16(celt_rcp(SHL32(E,3)));
121          j=M*eBands[i]; do {
122             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
123          } while (++j<M*eBands[i+1]);
124       } while (++i<end);
125    }
126 }
127
128 #else /* FIXED_POINT */
129 /* Compute the amplitude (sqrt energy) in each of the bands */
130 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
131 {
132    int i, c, N;
133    const celt_int16 *eBands = m->eBands;
134    const int C = CHANNELS(_C);
135    N = M*m->shortMdctSize;
136    for (c=0;c<C;c++)
137    {
138       for (i=0;i<end;i++)
139       {
140          int j;
141          celt_word32 sum = 1e-10f;
142          for (j=M*eBands[i];j<M*eBands[i+1];j++)
143             sum += X[j+c*N]*X[j+c*N];
144          bank[i+c*m->nbEBands] = celt_sqrt(sum);
145          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
146       }
147    }
148    /*printf ("\n");*/
149 }
150
151 /* Normalise each band such that the energy is one. */
152 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
153 {
154    int i, c, N;
155    const celt_int16 *eBands = m->eBands;
156    const int C = CHANNELS(_C);
157    N = M*m->shortMdctSize;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<end;i++)
161       {
162          int j;
163          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
164          for (j=M*eBands[i];j<M*eBands[i+1];j++)
165             X[j+c*N] = freq[j+c*N]*g;
166       }
167    }
168 }
169
170 #endif /* FIXED_POINT */
171
172 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
173 {
174    int i, c;
175    const celt_int16 *eBands = m->eBands;
176    const int C = CHANNELS(_C);
177    for (c=0;c<C;c++)
178    {
179       i=0; do {
180          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, M*eBands[i+1]-M*eBands[i], Q15ONE);
181       } while (++i<end);
182    }
183 }
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    for (c=0;c<C;c++)
194    {
195       celt_sig * restrict f;
196       const celt_norm * restrict x;
197       f = freq+c*N;
198       x = X+c*N;
199       for (i=0;i<end;i++)
200       {
201          int j, band_end;
202          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
203          j=M*eBands[i];
204          band_end = M*eBands[i+1];
205          do {
206             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
207             x++;
208          } while (++j<band_end);
209       }
210       for (i=M*eBands[m->nbEBands];i<N;i++)
211          *f++ = 0;
212    }
213 }
214
215 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
216 {
217    int i = bandID;
218    int j;
219    celt_word16 a1, a2;
220    celt_word16 left, right;
221    celt_word16 norm;
222 #ifdef FIXED_POINT
223    int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
224 #endif
225    left = VSHR32(bank[i],shift);
226    right = VSHR32(bank[i+m->nbEBands],shift);
227    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
228    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
229    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
230    for (j=0;j<N;j++)
231    {
232       celt_norm r, l;
233       l = X[j];
234       r = Y[j];
235       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
236       /* Side is not encoded, no need to calculate */
237    }
238 }
239
240 static void stereo_split(celt_norm *X, celt_norm *Y, int N)
241 {
242    int j;
243    for (j=0;j<N;j++)
244    {
245       celt_norm r, l;
246       l = MULT16_16_Q15(QCONST16(.70711f,15), X[j]);
247       r = MULT16_16_Q15(QCONST16(.70711f,15), Y[j]);
248       X[j] = l+r;
249       Y[j] = r-l;
250    }
251 }
252
253 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, celt_word16 side, int N)
254 {
255    int j;
256    celt_word32 xp=0;
257    celt_word32 El, Er;
258 #ifdef FIXED_POINT
259    int kl, kr;
260 #endif
261    celt_word32 t, lgain, rgain;
262
263    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
264    for (j=0;j<N;j++)
265       xp = MAC16_16(xp, X[j], Y[j]);
266    /* mid and side are in Q15, not Q14 like X and Y */
267    El = MULT16_16(mid, mid) + MULT16_16(side, side) - 2*SHL32(xp,2);
268    Er = MULT16_16(mid, mid) + MULT16_16(side, side) + 2*SHL32(xp,2);
269    if (Er < EPSILON)
270       Er = EPSILON;
271    if (El < EPSILON)
272       El = EPSILON;
273
274 #ifdef FIXED_POINT
275    kl = celt_ilog2(El)>>1;
276    kr = celt_ilog2(Er)>>1;
277 #endif
278    t = VSHR32(El, (kl-7)<<1);
279    lgain = celt_rsqrt_norm(t);
280    t = VSHR32(Er, (kr-7)<<1);
281    rgain = celt_rsqrt_norm(t);
282
283    for (j=0;j<N;j++)
284    {
285       celt_norm r, l;
286       l = X[j];
287       r = Y[j];
288       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, l-r), kl));
289       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, l+r), kr));
290    }
291 }
292
293 /* Decide whether we should spread the pulses in the current frame */
294 int folding_decision(const CELTMode *m, celt_norm *X, int *average, int *last_decision, int end, int _C, int M)
295 {
296    int i, c, N0;
297    int sum = 0, nbBands=0;
298    const int C = CHANNELS(_C);
299    const celt_int16 * restrict eBands = m->eBands;
300    int decision;
301    
302    N0 = M*m->shortMdctSize;
303
304    if (M*(eBands[end]-eBands[end-1]) <= 8)
305       return 0;
306    for (c=0;c<C;c++)
307    {
308       for (i=0;i<end;i++)
309       {
310          int j, N, tmp=0;
311          int tcount[3] = {0};
312          celt_norm * restrict x = X+M*eBands[i]+c*N0;
313          N = M*(eBands[i+1]-eBands[i]);
314          if (N<=8)
315             continue;
316          /* Compute rough CDF of |x[j]| */
317          for (j=0;j<N;j++)
318          {
319             celt_word32 x2N; /* Q13 */
320
321             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
322             if (x2N < QCONST16(0.25f,13))
323                tcount[0]++;
324             if (x2N < QCONST16(0.0625f,13))
325                tcount[1]++;
326             if (x2N < QCONST16(0.015625f,13))
327                tcount[2]++;
328          }
329
330          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
331          sum += tmp*256;
332          nbBands++;
333       }
334    }
335    sum /= nbBands;
336    /* Recursive averaging */
337    sum = (sum+*average)>>1;
338    *average = sum;
339    /* Hysteresis */
340    sum = (3*sum + ((*last_decision<<7) + 64) + 2)>>2;
341    /* decision and last_decision do not use the same encoding */
342    if (sum < 128)
343    {
344       decision = 2;
345       *last_decision = 0;
346    } else if (sum < 256)
347    {
348       decision = 1;
349       *last_decision = 1;
350    } else if (sum < 384)
351    {
352       decision = 3;
353       *last_decision = 2;
354    } else {
355       decision = 0;
356       *last_decision = 3;
357    }
358    return decision;
359 }
360
361 #ifdef MEASURE_NORM_MSE
362
363 float MSE[30] = {0};
364 int nbMSEBands = 0;
365 int MSECount[30] = {0};
366
367 void dump_norm_mse(void)
368 {
369    int i;
370    for (i=0;i<nbMSEBands;i++)
371    {
372       printf ("%g ", MSE[i]/MSECount[i]);
373    }
374    printf ("\n");
375 }
376
377 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
378 {
379    static int init = 0;
380    int i;
381    if (!init)
382    {
383       atexit(dump_norm_mse);
384       init = 1;
385    }
386    for (i=0;i<m->nbEBands;i++)
387    {
388       int j;
389       int c;
390       float g;
391       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
392          continue;
393       for (c=0;c<C;c++)
394       {
395          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
396          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
397             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
398       }
399       MSECount[i]+=C;
400    }
401    nbMSEBands = m->nbEBands;
402 }
403
404 #endif
405
406 static void interleave_vector(celt_norm *X, int N0, int stride)
407 {
408    int i,j;
409    VARDECL(celt_norm, tmp);
410    int N;
411    SAVE_STACK;
412    N = N0*stride;
413    ALLOC(tmp, N, celt_norm);
414    for (i=0;i<stride;i++)
415       for (j=0;j<N0;j++)
416          tmp[j*stride+i] = X[i*N0+j];
417    for (j=0;j<N;j++)
418       X[j] = tmp[j];
419    RESTORE_STACK;
420 }
421
422 static void deinterleave_vector(celt_norm *X, int N0, int stride)
423 {
424    int i,j;
425    VARDECL(celt_norm, tmp);
426    int N;
427    SAVE_STACK;
428    N = N0*stride;
429    ALLOC(tmp, N, celt_norm);
430    for (i=0;i<stride;i++)
431       for (j=0;j<N0;j++)
432          tmp[i*N0+j] = X[j*stride+i];
433    for (j=0;j<N;j++)
434       X[j] = tmp[j];
435    RESTORE_STACK;
436 }
437
438 static void haar1(celt_norm *X, int N0, int stride)
439 {
440    int i, j;
441    N0 >>= 1;
442    for (i=0;i<stride;i++)
443       for (j=0;j<N0;j++)
444       {
445          celt_norm tmp = X[stride*2*j+i];
446          X[stride*2*j+i] = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i] + X[stride*(2*j+1)+i]);
447          X[stride*(2*j+1)+i] = MULT16_16_Q15(QCONST16(.7070678f,15), tmp - X[stride*(2*j+1)+i]);
448       }
449 }
450
451 static int compute_qn(int N, int b, int offset, int stereo)
452 {
453    static const celt_int16 exp2_table8[8] =
454       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
455    int qn, qb;
456    int N2 = 2*N-1;
457    if (stereo && N==2)
458       N2--;
459    qb = (b+N2*offset)/N2;
460    if (qb > (b>>1)-(1<<BITRES))
461       qb = (b>>1)-(1<<BITRES);
462
463    if (qb<0)
464        qb = 0;
465    if (qb>14<<BITRES)
466      qb = 14<<BITRES;
467
468    if (qb<(1<<BITRES>>1)) {
469       qn = 1;
470    } else {
471       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
472       qn = (qn+1)>>1<<1;
473       if (qn>1024)
474          qn = 1024;
475    }
476    return qn;
477 }
478
479
480 /* This function is responsible for encoding and decoding a band for both
481    the mono and stereo case. Even in the mono case, it can split the band
482    in two and transmit the energy difference with the two half-bands. It
483    can be called recursively so bands can end up being split in 8 parts. */
484 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
485       int N, int b, int spread, int B, int tf_change, celt_norm *lowband, int resynth, void *ec,
486       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
487       celt_int32 *seed, celt_word16 gain)
488 {
489    int q;
490    int curr_bits;
491    int stereo, split;
492    int imid=0, iside=0;
493    int N0=N;
494    int N_B=N;
495    int N_B0;
496    int B0=B;
497    int time_divide=0;
498    int recombine=0;
499    celt_word16 mid=0, side=0;
500
501    N_B /= B;
502    N_B0 = N_B;
503
504    split = stereo = Y != NULL;
505
506    /* Special case for one sample */
507    if (N==1)
508    {
509       int c;
510       celt_norm *x = X;
511       for (c=0;c<1+stereo;c++)
512       {
513          int sign=0;
514          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
515          {
516             if (encode)
517             {
518                sign = x[0]<0;
519                ec_enc_bits((ec_enc*)ec, sign, 1);
520             } else {
521                sign = ec_dec_bits((ec_dec*)ec, 1);
522             }
523             *remaining_bits -= 1<<BITRES;
524             b-=1<<BITRES;
525          }
526          if (resynth)
527             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
528          x = Y;
529       }
530       if (lowband_out)
531          lowband_out[0] = X[0];
532       return;
533    }
534
535    /* Band recombining to increase frequency resolution */
536    if (!stereo && B > 1 && level == 0 && tf_change>0)
537    {
538       while (B>1 && tf_change>0)
539       {
540          B>>=1;
541          N_B<<=1;
542          if (encode)
543             haar1(X, N_B, B);
544          if (lowband)
545             haar1(lowband, N_B, B);
546          recombine++;
547          tf_change--;
548       }
549       B0=B;
550       N_B0 = N_B;
551    }
552
553    /* Increasing the time resolution */
554    if (!stereo && level==0)
555    {
556       while ((N_B&1) == 0 && tf_change<0 && B <= (1<<LM))
557       {
558          if (encode)
559             haar1(X, N_B, B);
560          if (lowband)
561             haar1(lowband, N_B, B);
562          B <<= 1;
563          N_B >>= 1;
564          time_divide++;
565          tf_change++;
566       }
567       B0 = B;
568       N_B0 = N_B;
569    }
570
571    /* Reorganize the samples in time order instead of frequency order */
572    if (!stereo && B0>1 && level==0)
573    {
574       if (encode)
575          deinterleave_vector(X, N_B, B0);
576       if (lowband)
577          deinterleave_vector(lowband, N_B, B0);
578    }
579
580    /* If we need more than 32 bits, try splitting the band in two. */
581    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
582    {
583       if (LM>0 || (N&1)==0)
584       {
585          N >>= 1;
586          Y = X+N;
587          split = 1;
588          LM -= 1;
589          B = (B+1)>>1;
590       }
591    }
592
593    if (split)
594    {
595       int qn;
596       int itheta=0;
597       int mbits, sbits, delta;
598       int qalloc;
599       int offset;
600
601       /* Decide on the resolution to give to the split parameter theta */
602       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
603       qn = compute_qn(N, b, offset, stereo);
604
605       qalloc = 0;
606       if (qn!=1)
607       {
608          if (encode)
609          {
610             if (stereo)
611                stereo_split(X, Y, N);
612
613             mid = vector_norm(X, N);
614             side = vector_norm(Y, N);
615             /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
616                      Let's do that at higher complexity */
617             /*mid = renormalise_vector(X, Q15ONE, N, 1);
618             side = renormalise_vector(Y, Q15ONE, N, 1);*/
619
620             /* theta is the atan() of the ratio between the (normalized)
621                side and mid. With just that parameter, we can re-scale both
622                mid and side because we know that 1) they have unit norm and
623                2) they are orthogonal. */
624    #ifdef FIXED_POINT
625             /* 0.63662 = 2/pi */
626             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
627    #else
628             itheta = (int)floor(.5f+16384*0.63662f*atan2(side,mid));
629    #endif
630
631             itheta = (itheta*qn+8192)>>14;
632          }
633
634          /* Entropy coding of the angle. We use a uniform pdf for the
635             first stereo split but a triangular one for the rest. */
636          if (stereo || qn>256 || B>1)
637          {
638             if (encode)
639                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
640             else
641                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
642             qalloc = log2_frac(qn+1,BITRES);
643          } else {
644             int fs=1, ft;
645             ft = ((qn>>1)+1)*((qn>>1)+1);
646             if (encode)
647             {
648                int fl;
649
650                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
651                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
652                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
653
654                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
655             } else {
656                int fl=0;
657                int fm;
658                fm = ec_decode((ec_dec*)ec, ft);
659
660                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
661                {
662                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
663                   fs = itheta + 1;
664                   fl = itheta*(itheta + 1)>>1;
665                }
666                else
667                {
668                   itheta = (2*(qn + 1)
669                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
670                   fs = qn + 1 - itheta;
671                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
672                }
673
674                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
675             }
676             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
677          }
678          itheta = (celt_int32)itheta*16384/qn;
679       } else {
680          if (stereo && encode)
681             intensity_stereo(m, X, Y, bandE, i, N);
682       }
683
684       if (itheta == 0)
685       {
686          imid = 32767;
687          iside = 0;
688          delta = -10000;
689       } else if (itheta == 16384)
690       {
691          imid = 0;
692          iside = 32767;
693          delta = 10000;
694       } else {
695          imid = bitexact_cos(itheta);
696          iside = bitexact_cos(16384-itheta);
697          /* This is the mid vs side allocation that minimizes squared error
698             in that band. */
699          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
700       }
701
702       /* This is a special case for N=2 that only works for stereo and takes
703          advantage of the fact that mid and side are orthogonal to encode
704          the side with just one bit. */
705       if (N==2 && stereo)
706       {
707          int c;
708          int sign=1;
709          celt_norm *x2, *y2;
710          mbits = b-qalloc;
711          sbits = 0;
712          /* Only need one bit for the side */
713          if (itheta != 0 && itheta != 16384)
714             sbits = 1<<BITRES;
715          mbits -= sbits;
716          c = itheta > 8192;
717          *remaining_bits -= qalloc+sbits;
718
719          x2 = c ? Y : X;
720          y2 = c ? X : Y;
721          if (sbits)
722          {
723             if (encode)
724             {
725                /* Here we only need to encode a sign for the side */
726                sign = x2[0]*y2[1] - x2[1]*y2[0] > 0;
727                ec_enc_bits((ec_enc*)ec, sign, 1);
728             } else {
729                sign = ec_dec_bits((ec_dec*)ec, 1);
730             }
731          }
732          sign = 2*sign - 1;
733          quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level+1, seed, gain);
734          y2[0] = -sign*x2[1];
735          y2[1] = sign*x2[0];
736       } else {
737          /* "Normal" split code */
738          celt_norm *next_lowband2=NULL;
739          celt_norm *next_lowband_out1=NULL;
740          int next_level=0;
741
742 #ifdef FIXED_POINT
743          mid = imid;
744          side = iside;
745 #else
746          mid = (1.f/32768)*imid;
747          side = (1.f/32768)*iside;
748 #endif
749
750          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
751          if (B>1 && !stereo)
752             delta >>= 1;
753
754          mbits = (b-qalloc-delta)/2;
755          if (mbits > b-qalloc)
756             mbits = b-qalloc;
757          if (mbits<0)
758             mbits=0;
759          sbits = b-qalloc-mbits;
760          *remaining_bits -= qalloc;
761
762          if (lowband && !stereo)
763             next_lowband2 = lowband+N; /* >32-bit split case */
764
765          /* Only stereo needs to pass on lowband_out. Otherwise, it's
766             handled at the end */
767          if (stereo)
768             next_lowband_out1 = lowband_out;
769          else
770             next_level = level+1;
771
772          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, tf_change,
773                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
774                NULL, next_level, seed, MULT16_16_P15(gain,mid));
775          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, tf_change,
776                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
777                NULL, next_level, seed, MULT16_16_P15(gain,side));
778       }
779
780    } else {
781       /* This is the basic no-split case */
782       q = bits2pulses(m, i, LM, b);
783       curr_bits = pulses2bits(m, i, LM, q);
784       *remaining_bits -= curr_bits;
785
786       /* Ensures we can never bust the budget */
787       while (*remaining_bits < 0 && q > 0)
788       {
789          *remaining_bits += curr_bits;
790          q--;
791          curr_bits = pulses2bits(m, i, LM, q);
792          *remaining_bits -= curr_bits;
793       }
794
795       /* Finally do the actual quantization */
796       if (encode)
797          alg_quant(X, N, q, spread, B, lowband, resynth, (ec_enc*)ec, seed, gain);
798       else
799          alg_unquant(X, N, q, spread, B, lowband, (ec_dec*)ec, seed, gain);
800    }
801
802    /* This code is used by the decoder and by the resynthesis-enabled encoder */
803    if (resynth)
804    {
805       int k;
806
807       /* Undo the sample reorganization going from time order to frequency order */
808       if (!stereo && B0>1 && level==0)
809       {
810          interleave_vector(X, N_B, B0);
811          if (lowband)
812             interleave_vector(lowband, N_B, B0);
813       }
814
815       /* Undo time-freq changes that we did earlier */
816       N_B = N_B0;
817       B = B0;
818       for (k=0;k<time_divide;k++)
819       {
820          B >>= 1;
821          N_B <<= 1;
822          haar1(X, N_B, B);
823          if (lowband)
824             haar1(lowband, N_B, B);
825       }
826
827       for (k=0;k<recombine;k++)
828       {
829          haar1(X, N_B, B);
830          if (lowband)
831             haar1(lowband, N_B, B);
832          N_B>>=1;
833          B <<= 1;
834       }
835
836       /* Scale output for later folding */
837       if (lowband_out && !stereo)
838       {
839          int j;
840          celt_word16 n;
841          n = celt_sqrt(SHL32(EXTEND32(N0),22));
842          for (j=0;j<N0;j++)
843             lowband_out[j] = MULT16_16_Q15(n,X[j]);
844       }
845
846       if (stereo)
847          stereo_merge(X, Y, mid, side, N);
848    }
849 }
850
851 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM)
852 {
853    int i, balance;
854    celt_int32 remaining_bits;
855    const celt_int16 * restrict eBands = m->eBands;
856    celt_norm * restrict norm;
857    VARDECL(celt_norm, _norm);
858    int B;
859    int M;
860    celt_int32 seed;
861    celt_norm *lowband;
862    int update_lowband = 1;
863    int C = _Y != NULL ? 2 : 1;
864    SAVE_STACK;
865
866    M = 1<<LM;
867    B = shortBlocks ? M : 1;
868    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
869    norm = _norm;
870
871    if (encode)
872       seed = ((ec_enc*)ec)->rng;
873    else
874       seed = ((ec_dec*)ec)->rng;
875    balance = 0;
876    lowband = NULL;
877    for (i=start;i<end;i++)
878    {
879       int tell;
880       int b;
881       int N;
882       int curr_balance;
883       celt_norm * restrict X, * restrict Y;
884       int tf_change=0;
885       celt_norm *effective_lowband;
886       
887       X = _X+M*eBands[i];
888       if (_Y!=NULL)
889          Y = _Y+M*eBands[i];
890       else
891          Y = NULL;
892       N = M*eBands[i+1]-M*eBands[i];
893       if (encode)
894          tell = ec_enc_tell((ec_enc*)ec, BITRES);
895       else
896          tell = ec_dec_tell((ec_dec*)ec, BITRES);
897
898       /* Compute how many bits we want to allocate to this band */
899       if (i != start)
900          balance -= tell;
901       remaining_bits = (total_bits<<BITRES)-tell-1;
902       curr_balance = (end-i);
903       if (curr_balance > 3)
904          curr_balance = 3;
905       curr_balance = balance / curr_balance;
906       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
907       if (b<0)
908          b = 0;
909       /* Prevents ridiculous bit depths */
910       if (b > C*16*N<<BITRES)
911          b = C*16*N<<BITRES;
912
913       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
914             lowband = norm+M*eBands[i]-N;
915
916       tf_change = tf_res[i];
917       if (i>=m->effEBands)
918       {
919          X=norm;
920          if (_Y!=NULL)
921             Y = norm;
922       }
923
924       if (tf_change==0 && !shortBlocks && fold)
925          effective_lowband = NULL;
926       else
927          effective_lowband = lowband;
928       quant_band(encode, m, i, X, Y, N, b, fold, B, tf_change,
929             effective_lowband, resynth, ec, &remaining_bits, LM,
930             norm+M*eBands[i], bandE, 0, &seed, Q15ONE);
931
932       balance += pulses[i] + tell;
933
934       /* Update the folding position only as long as we have 2 bit/sample depth */
935       update_lowband = (b>>BITRES)>2*N;
936    }
937    RESTORE_STACK;
938 }
939