Added stereo_merge(), which does the renormalisation too
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    for (c=0;c<C;c++)
74    {
75       for (i=0;i<end;i++)
76       {
77          int j;
78          celt_word32 maxval=0;
79          celt_word32 sum = 0;
80          
81          j=M*eBands[i]; do {
82             maxval = MAX32(maxval, X[j+c*N]);
83             maxval = MAX32(maxval, -X[j+c*N]);
84          } while (++j<M*eBands[i+1]);
85          
86          if (maxval > 0)
87          {
88             int shift = celt_ilog2(maxval)-10;
89             j=M*eBands[i]; do {
90                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
91                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
92             } while (++j<M*eBands[i+1]);
93             /* We're adding one here to make damn sure we never end up with a pitch vector that's
94                larger than unity norm */
95             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
96          } else {
97             bank[i+c*m->nbEBands] = EPSILON;
98          }
99          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
100       }
101    }
102    /*printf ("\n");*/
103 }
104
105 /* Normalise each band such that the energy is one. */
106 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
107 {
108    int i, c, N;
109    const celt_int16 *eBands = m->eBands;
110    const int C = CHANNELS(_C);
111    N = M*m->shortMdctSize;
112    for (c=0;c<C;c++)
113    {
114       i=0; do {
115          celt_word16 g;
116          int j,shift;
117          celt_word16 E;
118          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
119          E = VSHR32(bank[i+c*m->nbEBands], shift);
120          g = EXTRACT16(celt_rcp(SHL32(E,3)));
121          j=M*eBands[i]; do {
122             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
123          } while (++j<M*eBands[i+1]);
124       } while (++i<end);
125    }
126 }
127
128 #else /* FIXED_POINT */
129 /* Compute the amplitude (sqrt energy) in each of the bands */
130 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
131 {
132    int i, c, N;
133    const celt_int16 *eBands = m->eBands;
134    const int C = CHANNELS(_C);
135    N = M*m->shortMdctSize;
136    for (c=0;c<C;c++)
137    {
138       for (i=0;i<end;i++)
139       {
140          int j;
141          celt_word32 sum = 1e-10f;
142          for (j=M*eBands[i];j<M*eBands[i+1];j++)
143             sum += X[j+c*N]*X[j+c*N];
144          bank[i+c*m->nbEBands] = celt_sqrt(sum);
145          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
146       }
147    }
148    /*printf ("\n");*/
149 }
150
151 /* Normalise each band such that the energy is one. */
152 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
153 {
154    int i, c, N;
155    const celt_int16 *eBands = m->eBands;
156    const int C = CHANNELS(_C);
157    N = M*m->shortMdctSize;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<end;i++)
161       {
162          int j;
163          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
164          for (j=M*eBands[i];j<M*eBands[i+1];j++)
165             X[j+c*N] = freq[j+c*N]*g;
166       }
167    }
168 }
169
170 #endif /* FIXED_POINT */
171
172 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
173 {
174    int i, c;
175    const celt_int16 *eBands = m->eBands;
176    const int C = CHANNELS(_C);
177    for (c=0;c<C;c++)
178    {
179       i=0; do {
180          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, M*eBands[i+1]-M*eBands[i], Q15ONE);
181       } while (++i<end);
182    }
183 }
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    for (c=0;c<C;c++)
194    {
195       celt_sig * restrict f;
196       const celt_norm * restrict x;
197       f = freq+c*N;
198       x = X+c*N;
199       for (i=0;i<end;i++)
200       {
201          int j, band_end;
202          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
203          j=M*eBands[i];
204          band_end = M*eBands[i+1];
205          do {
206             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
207             x++;
208          } while (++j<band_end);
209       }
210       for (i=M*eBands[m->nbEBands];i<N;i++)
211          *f++ = 0;
212    }
213 }
214
215 static void stereo_band_mix(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int stereo_mode, int bandID, int dir, int N)
216 {
217    int i = bandID;
218    int j;
219    celt_word16 a1, a2;
220    if (stereo_mode==0)
221    {
222       /* Do mid-side when not doing intensity stereo */
223       a1 = QCONST16(.70711f,14);
224       a2 = dir*QCONST16(.70711f,14);
225    } else {
226       celt_word16 left, right;
227       celt_word16 norm;
228 #ifdef FIXED_POINT
229       int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
230 #endif
231       left = VSHR32(bank[i],shift);
232       right = VSHR32(bank[i+m->nbEBands],shift);
233       norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
234       a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
235       a2 = dir*DIV32_16(SHL32(EXTEND32(right),14),norm);
236    }
237    for (j=0;j<N;j++)
238    {
239       celt_norm r, l;
240       l = X[j];
241       r = Y[j];
242       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
243       Y[j] = MULT16_16_Q14(a1,r) - MULT16_16_Q14(a2,l);
244    }
245 }
246
247 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, celt_word16 side, int N)
248 {
249    int j;
250    celt_word32 xp=0;
251    celt_word32 El, Er;
252 #ifdef FIXED_POINT
253    int kl, kr;
254 #endif
255    celt_word32 t, lgain, rgain;
256
257    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
258    for (j=0;j<N;j++)
259       xp = MAC16_16(xp, X[j], Y[j]);
260    /* mid and side are in Q15, not Q14 like X and Y */
261    El = MULT16_16(mid, mid) + MULT16_16(side, side) - 2*SHL32(xp,2);
262    Er = MULT16_16(mid, mid) + MULT16_16(side, side) + 2*SHL32(xp,2);
263    if (Er < EPSILON)
264       Er = EPSILON;
265    if (El < EPSILON)
266       El = EPSILON;
267
268 #ifdef FIXED_POINT
269    kl = celt_ilog2(El)>>1;
270    kr = celt_ilog2(Er)>>1;
271 #endif
272    t = VSHR32(El, (kl-7)<<1);
273    lgain = celt_rsqrt_norm(t);
274    t = VSHR32(Er, (kr-7)<<1);
275    rgain = celt_rsqrt_norm(t);
276
277    for (j=0;j<N;j++)
278    {
279       celt_norm r, l;
280       l = X[j];
281       r = Y[j];
282       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, l-r), kl));
283       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, l+r), kr));
284    }
285 }
286
287 /* Decide whether we should spread the pulses in the current frame */
288 int folding_decision(const CELTMode *m, celt_norm *X, int *average, int *last_decision, int end, int _C, int M)
289 {
290    int i, c, N0;
291    int sum = 0, nbBands=0;
292    const int C = CHANNELS(_C);
293    const celt_int16 * restrict eBands = m->eBands;
294    int decision;
295    
296    N0 = M*m->shortMdctSize;
297
298    if (M*(eBands[end]-eBands[end-1]) <= 8)
299       return 0;
300    for (c=0;c<C;c++)
301    {
302       for (i=0;i<end;i++)
303       {
304          int j, N, tmp=0;
305          int tcount[3] = {0};
306          celt_norm * restrict x = X+M*eBands[i]+c*N0;
307          N = M*(eBands[i+1]-eBands[i]);
308          if (N<=8)
309             continue;
310          /* Compute rough CDF of |x[j]| */
311          for (j=0;j<N;j++)
312          {
313             celt_word32 x2N; /* Q13 */
314
315             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
316             if (x2N < QCONST16(0.25f,13))
317                tcount[0]++;
318             if (x2N < QCONST16(0.0625f,13))
319                tcount[1]++;
320             if (x2N < QCONST16(0.015625f,13))
321                tcount[2]++;
322          }
323
324          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
325          sum += tmp*256;
326          nbBands++;
327       }
328    }
329    sum /= nbBands;
330    /* Recursive averaging */
331    sum = (sum+*average)>>1;
332    *average = sum;
333    /* Hysteresis */
334    sum = (3*sum + ((*last_decision<<7) + 64) + 2)>>2;
335    /* decision and last_decision do not use the same encoding */
336    if (sum < 128)
337    {
338       decision = 2;
339       *last_decision = 0;
340    } else if (sum < 256)
341    {
342       decision = 1;
343       *last_decision = 1;
344    } else if (sum < 384)
345    {
346       decision = 3;
347       *last_decision = 2;
348    } else {
349       decision = 0;
350       *last_decision = 3;
351    }
352    return decision;
353 }
354
355 #ifdef MEASURE_NORM_MSE
356
357 float MSE[30] = {0};
358 int nbMSEBands = 0;
359 int MSECount[30] = {0};
360
361 void dump_norm_mse(void)
362 {
363    int i;
364    for (i=0;i<nbMSEBands;i++)
365    {
366       printf ("%g ", MSE[i]/MSECount[i]);
367    }
368    printf ("\n");
369 }
370
371 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
372 {
373    static int init = 0;
374    int i;
375    if (!init)
376    {
377       atexit(dump_norm_mse);
378       init = 1;
379    }
380    for (i=0;i<m->nbEBands;i++)
381    {
382       int j;
383       int c;
384       float g;
385       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
386          continue;
387       for (c=0;c<C;c++)
388       {
389          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
390          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
391             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
392       }
393       MSECount[i]+=C;
394    }
395    nbMSEBands = m->nbEBands;
396 }
397
398 #endif
399
400 static void interleave_vector(celt_norm *X, int N0, int stride)
401 {
402    int i,j;
403    VARDECL(celt_norm, tmp);
404    int N;
405    SAVE_STACK;
406    N = N0*stride;
407    ALLOC(tmp, N, celt_norm);
408    for (i=0;i<stride;i++)
409       for (j=0;j<N0;j++)
410          tmp[j*stride+i] = X[i*N0+j];
411    for (j=0;j<N;j++)
412       X[j] = tmp[j];
413    RESTORE_STACK;
414 }
415
416 static void deinterleave_vector(celt_norm *X, int N0, int stride)
417 {
418    int i,j;
419    VARDECL(celt_norm, tmp);
420    int N;
421    SAVE_STACK;
422    N = N0*stride;
423    ALLOC(tmp, N, celt_norm);
424    for (i=0;i<stride;i++)
425       for (j=0;j<N0;j++)
426          tmp[i*N0+j] = X[j*stride+i];
427    for (j=0;j<N;j++)
428       X[j] = tmp[j];
429    RESTORE_STACK;
430 }
431
432 static void haar1(celt_norm *X, int N0, int stride)
433 {
434    int i, j;
435    N0 >>= 1;
436    for (i=0;i<stride;i++)
437       for (j=0;j<N0;j++)
438       {
439          celt_norm tmp = X[stride*2*j+i];
440          X[stride*2*j+i] = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i] + X[stride*(2*j+1)+i]);
441          X[stride*(2*j+1)+i] = MULT16_16_Q15(QCONST16(.7070678f,15), tmp - X[stride*(2*j+1)+i]);
442       }
443 }
444
445 static int compute_qn(int N, int b, int offset, int stereo)
446 {
447    static const celt_int16 exp2_table8[8] =
448       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
449    int qn, qb;
450    int N2 = 2*N-1;
451    if (stereo && N==2)
452       N2--;
453    qb = (b+N2*offset)/N2;
454    if (qb > (b>>1)-(1<<BITRES))
455       qb = (b>>1)-(1<<BITRES);
456
457    if (qb<0)
458        qb = 0;
459    if (qb>14<<BITRES)
460      qb = 14<<BITRES;
461
462    if (qb<(1<<BITRES>>1)) {
463       qn = 1;
464    } else {
465       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
466       qn = (qn+1)>>1<<1;
467       if (qn>1024)
468          qn = 1024;
469    }
470    return qn;
471 }
472
473
474 /* This function is responsible for encoding and decoding a band for both
475    the mono and stereo case. Even in the mono case, it can split the band
476    in two and transmit the energy difference with the two half-bands. It
477    can be called recursively so bands can end up being split in 8 parts. */
478 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
479       int N, int b, int spread, int B, int tf_change, celt_norm *lowband, int resynth, void *ec,
480       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
481       celt_int32 *seed, celt_word16 gain)
482 {
483    int q;
484    int curr_bits;
485    int stereo, split;
486    int imid=0, iside=0;
487    int N0=N;
488    int N_B=N;
489    int N_B0;
490    int B0=B;
491    int time_divide=0;
492    int recombine=0;
493    celt_word16 mid=0, side=0;
494
495    N_B /= B;
496    N_B0 = N_B;
497
498    split = stereo = Y != NULL;
499
500    /* Special case for one sample */
501    if (N==1)
502    {
503       int c;
504       celt_norm *x = X;
505       for (c=0;c<1+stereo;c++)
506       {
507          int sign=0;
508          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
509          {
510             if (encode)
511             {
512                sign = x[0]<0;
513                ec_enc_bits((ec_enc*)ec, sign, 1);
514             } else {
515                sign = ec_dec_bits((ec_dec*)ec, 1);
516             }
517             *remaining_bits -= 1<<BITRES;
518             b-=1<<BITRES;
519          }
520          if (resynth)
521             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
522          x = Y;
523       }
524       if (lowband_out)
525          lowband_out[0] = X[0];
526       return;
527    }
528
529    /* Band recombining to increase frequency resolution */
530    if (!stereo && B > 1 && level == 0 && tf_change>0)
531    {
532       while (B>1 && tf_change>0)
533       {
534          B>>=1;
535          N_B<<=1;
536          if (encode)
537             haar1(X, N_B, B);
538          if (lowband)
539             haar1(lowband, N_B, B);
540          recombine++;
541          tf_change--;
542       }
543       B0=B;
544       N_B0 = N_B;
545    }
546
547    /* Increasing the time resolution */
548    if (!stereo && level==0)
549    {
550       while ((N_B&1) == 0 && tf_change<0 && B <= (1<<LM))
551       {
552          if (encode)
553             haar1(X, N_B, B);
554          if (lowband)
555             haar1(lowband, N_B, B);
556          B <<= 1;
557          N_B >>= 1;
558          time_divide++;
559          tf_change++;
560       }
561       B0 = B;
562       N_B0 = N_B;
563    }
564
565    /* Reorganize the samples in time order instead of frequency order */
566    if (!stereo && B0>1 && level==0)
567    {
568       if (encode)
569          deinterleave_vector(X, N_B, B0);
570       if (lowband)
571          deinterleave_vector(lowband, N_B, B0);
572    }
573
574    /* If we need more than 32 bits, try splitting the band in two. */
575    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
576    {
577       if (LM>0 || (N&1)==0)
578       {
579          N >>= 1;
580          Y = X+N;
581          split = 1;
582          LM -= 1;
583          B = (B+1)>>1;
584       }
585    }
586
587    if (split)
588    {
589       int qn;
590       int itheta=0;
591       int mbits, sbits, delta;
592       int qalloc;
593       int offset;
594
595       /* Decide on the resolution to give to the split parameter theta */
596       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
597       qn = compute_qn(N, b, offset, stereo);
598
599       qalloc = 0;
600       if (qn!=1)
601       {
602          if (encode)
603          {
604             if (stereo)
605                stereo_band_mix(m, X, Y, bandE, 0, i, 1, N);
606
607             mid = vector_norm(X, N);
608             side = vector_norm(Y, N);
609             /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
610                      Let's do that at higher complexity */
611             /*mid = renormalise_vector(X, Q15ONE, N, 1);
612             side = renormalise_vector(Y, Q15ONE, N, 1);*/
613
614             /* theta is the atan() of the ratio between the (normalized)
615                side and mid. With just that parameter, we can re-scale both
616                mid and side because we know that 1) they have unit norm and
617                2) they are orthogonal. */
618    #ifdef FIXED_POINT
619             /* 0.63662 = 2/pi */
620             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
621    #else
622             itheta = (int)floor(.5f+16384*0.63662f*atan2(side,mid));
623    #endif
624
625             itheta = (itheta*qn+8192)>>14;
626          }
627
628          /* Entropy coding of the angle. We use a uniform pdf for the
629             first stereo split but a triangular one for the rest. */
630          if (stereo || qn>256 || B>1)
631          {
632             if (encode)
633                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
634             else
635                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
636             qalloc = log2_frac(qn+1,BITRES);
637          } else {
638             int fs=1, ft;
639             ft = ((qn>>1)+1)*((qn>>1)+1);
640             if (encode)
641             {
642                int fl;
643
644                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
645                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
646                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
647
648                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
649             } else {
650                int fl=0;
651                int fm;
652                fm = ec_decode((ec_dec*)ec, ft);
653
654                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
655                {
656                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
657                   fs = itheta + 1;
658                   fl = itheta*(itheta + 1)>>1;
659                }
660                else
661                {
662                   itheta = (2*(qn + 1)
663                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
664                   fs = qn + 1 - itheta;
665                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
666                }
667
668                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
669             }
670             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
671          }
672          itheta = (celt_int32)itheta*16384/qn;
673       } else {
674          if (stereo && encode)
675             stereo_band_mix(m, X, Y, bandE, 1, i, 1, N);
676       }
677
678       if (itheta == 0)
679       {
680          imid = 32767;
681          iside = 0;
682          delta = -10000;
683       } else if (itheta == 16384)
684       {
685          imid = 0;
686          iside = 32767;
687          delta = 10000;
688       } else {
689          imid = bitexact_cos(itheta);
690          iside = bitexact_cos(16384-itheta);
691          /* This is the mid vs side allocation that minimizes squared error
692             in that band. */
693          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
694       }
695
696       /* This is a special case for N=2 that only works for stereo and takes
697          advantage of the fact that mid and side are orthogonal to encode
698          the side with just one bit. */
699       if (N==2 && stereo)
700       {
701          int c;
702          int sign=1;
703          celt_norm *x2, *y2;
704          mbits = b-qalloc;
705          sbits = 0;
706          /* Only need one bit for the side */
707          if (itheta != 0 && itheta != 16384)
708             sbits = 1<<BITRES;
709          mbits -= sbits;
710          c = itheta > 8192;
711          *remaining_bits -= qalloc+sbits;
712
713          x2 = c ? Y : X;
714          y2 = c ? X : Y;
715          if (sbits)
716          {
717             if (encode)
718             {
719                /* Here we only need to encode a sign for the side */
720                sign = x2[0]*y2[1] - x2[1]*y2[0] > 0;
721                ec_enc_bits((ec_enc*)ec, sign, 1);
722             } else {
723                sign = ec_dec_bits((ec_dec*)ec, 1);
724             }
725          }
726          sign = 2*sign - 1;
727          quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level+1, seed, gain);
728          y2[0] = -sign*x2[1];
729          y2[1] = sign*x2[0];
730       } else {
731          /* "Normal" split code */
732          celt_norm *next_lowband2=NULL;
733          celt_norm *next_lowband_out1=NULL;
734          int next_level=0;
735
736 #ifdef FIXED_POINT
737          mid = imid;
738          side = iside;
739 #else
740          mid = (1.f/32768)*imid;
741          side = (1.f/32768)*iside;
742 #endif
743
744          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
745          if (B>1 && !stereo)
746             delta >>= 1;
747
748          mbits = (b-qalloc-delta)/2;
749          if (mbits > b-qalloc)
750             mbits = b-qalloc;
751          if (mbits<0)
752             mbits=0;
753          sbits = b-qalloc-mbits;
754          *remaining_bits -= qalloc;
755
756          if (lowband && !stereo)
757             next_lowband2 = lowband+N; /* >32-bit split case */
758
759          /* Only stereo needs to pass on lowband_out. Otherwise, it's
760             handled at the end */
761          if (stereo)
762             next_lowband_out1 = lowband_out;
763          else
764             next_level = level+1;
765
766          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, tf_change,
767                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
768                NULL, next_level, seed, MULT16_16_P15(gain,mid));
769          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, tf_change,
770                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
771                NULL, next_level, seed, MULT16_16_P15(gain,side));
772       }
773
774    } else {
775       /* This is the basic no-split case */
776       q = bits2pulses(m, i, LM, b);
777       curr_bits = pulses2bits(m, i, LM, q);
778       *remaining_bits -= curr_bits;
779
780       /* Ensures we can never bust the budget */
781       while (*remaining_bits < 0 && q > 0)
782       {
783          *remaining_bits += curr_bits;
784          q--;
785          curr_bits = pulses2bits(m, i, LM, q);
786          *remaining_bits -= curr_bits;
787       }
788
789       /* Finally do the actual quantization */
790       if (encode)
791          alg_quant(X, N, q, spread, B, lowband, resynth, (ec_enc*)ec, seed, gain);
792       else
793          alg_unquant(X, N, q, spread, B, lowband, (ec_dec*)ec, seed, gain);
794    }
795
796    /* This code is used by the decoder and by the resynthesis-enabled encoder */
797    if (resynth)
798    {
799       int k;
800
801       /* Undo the sample reorganization going from time order to frequency order */
802       if (!stereo && B0>1 && level==0)
803       {
804          interleave_vector(X, N_B, B0);
805          if (lowband)
806             interleave_vector(lowband, N_B, B0);
807       }
808
809       /* Undo time-freq changes that we did earlier */
810       N_B = N_B0;
811       B = B0;
812       for (k=0;k<time_divide;k++)
813       {
814          B >>= 1;
815          N_B <<= 1;
816          haar1(X, N_B, B);
817          if (lowband)
818             haar1(lowband, N_B, B);
819       }
820
821       for (k=0;k<recombine;k++)
822       {
823          haar1(X, N_B, B);
824          if (lowband)
825             haar1(lowband, N_B, B);
826          N_B>>=1;
827          B <<= 1;
828       }
829
830       /* Scale output for later folding */
831       if (lowband_out && !stereo)
832       {
833          int j;
834          celt_word16 n;
835          n = celt_sqrt(SHL32(EXTEND32(N0),22));
836          for (j=0;j<N0;j++)
837             lowband_out[j] = MULT16_16_Q15(n,X[j]);
838       }
839
840       if (stereo)
841          stereo_merge(X, Y, mid, side, N);
842    }
843 }
844
845 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM)
846 {
847    int i, balance;
848    celt_int32 remaining_bits;
849    const celt_int16 * restrict eBands = m->eBands;
850    celt_norm * restrict norm;
851    VARDECL(celt_norm, _norm);
852    int B;
853    int M;
854    celt_int32 seed;
855    celt_norm *lowband;
856    int update_lowband = 1;
857    int C = _Y != NULL ? 2 : 1;
858    SAVE_STACK;
859
860    M = 1<<LM;
861    B = shortBlocks ? M : 1;
862    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
863    norm = _norm;
864
865    if (encode)
866       seed = ((ec_enc*)ec)->rng;
867    else
868       seed = ((ec_dec*)ec)->rng;
869    balance = 0;
870    lowband = NULL;
871    for (i=start;i<end;i++)
872    {
873       int tell;
874       int b;
875       int N;
876       int curr_balance;
877       celt_norm * restrict X, * restrict Y;
878       int tf_change=0;
879       celt_norm *effective_lowband;
880       
881       X = _X+M*eBands[i];
882       if (_Y!=NULL)
883          Y = _Y+M*eBands[i];
884       else
885          Y = NULL;
886       N = M*eBands[i+1]-M*eBands[i];
887       if (encode)
888          tell = ec_enc_tell((ec_enc*)ec, BITRES);
889       else
890          tell = ec_dec_tell((ec_dec*)ec, BITRES);
891
892       /* Compute how many bits we want to allocate to this band */
893       if (i != start)
894          balance -= tell;
895       remaining_bits = (total_bits<<BITRES)-tell-1;
896       curr_balance = (end-i);
897       if (curr_balance > 3)
898          curr_balance = 3;
899       curr_balance = balance / curr_balance;
900       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
901       if (b<0)
902          b = 0;
903       /* Prevents ridiculous bit depths */
904       if (b > C*16*N<<BITRES)
905          b = C*16*N<<BITRES;
906
907       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
908             lowband = norm+M*eBands[i]-N;
909
910       tf_change = tf_res[i];
911       if (i>=m->effEBands)
912       {
913          X=norm;
914          if (_Y!=NULL)
915             Y = norm;
916       }
917
918       if (tf_change==0 && !shortBlocks && fold)
919          effective_lowband = NULL;
920       else
921          effective_lowband = lowband;
922       quant_band(encode, m, i, X, Y, N, b, fold, B, tf_change,
923             effective_lowband, resynth, ec, &remaining_bits, LM,
924             norm+M*eBands[i], bandE, 0, &seed, Q15ONE);
925
926       balance += pulses[i] + tell;
927
928       /* Update the folding position only as long as we have 2 bit/sample depth */
929       update_lowband = (b>>BITRES)>2*N;
930    }
931    RESTORE_STACK;
932 }
933