Fixes a bug that could turn off folding at low rate when specifying
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48
49 #ifdef FIXED_POINT
50 /* Compute the amplitude (sqrt energy) in each of the bands */
51 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
52 {
53    int i, c, N;
54    const celt_int16 *eBands = m->eBands;
55    const int C = CHANNELS(_C);
56    N = M*m->shortMdctSize;
57    for (c=0;c<C;c++)
58    {
59       for (i=0;i<end;i++)
60       {
61          int j;
62          celt_word32 maxval=0;
63          celt_word32 sum = 0;
64          
65          j=M*eBands[i]; do {
66             maxval = MAX32(maxval, X[j+c*N]);
67             maxval = MAX32(maxval, -X[j+c*N]);
68          } while (++j<M*eBands[i+1]);
69          
70          if (maxval > 0)
71          {
72             int shift = celt_ilog2(maxval)-10;
73             j=M*eBands[i]; do {
74                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
75                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
76             } while (++j<M*eBands[i+1]);
77             /* We're adding one here to make damn sure we never end up with a pitch vector that's
78                larger than unity norm */
79             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
80          } else {
81             bank[i+c*m->nbEBands] = EPSILON;
82          }
83          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
84       }
85    }
86    /*printf ("\n");*/
87 }
88
89 /* Normalise each band such that the energy is one. */
90 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
91 {
92    int i, c, N;
93    const celt_int16 *eBands = m->eBands;
94    const int C = CHANNELS(_C);
95    N = M*m->shortMdctSize;
96    for (c=0;c<C;c++)
97    {
98       i=0; do {
99          celt_word16 g;
100          int j,shift;
101          celt_word16 E;
102          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
103          E = VSHR32(bank[i+c*m->nbEBands], shift);
104          g = EXTRACT16(celt_rcp(SHL32(E,3)));
105          j=M*eBands[i]; do {
106             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
107          } while (++j<M*eBands[i+1]);
108       } while (++i<end);
109    }
110 }
111
112 #else /* FIXED_POINT */
113 /* Compute the amplitude (sqrt energy) in each of the bands */
114 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
115 {
116    int i, c, N;
117    const celt_int16 *eBands = m->eBands;
118    const int C = CHANNELS(_C);
119    N = M*m->shortMdctSize;
120    for (c=0;c<C;c++)
121    {
122       for (i=0;i<end;i++)
123       {
124          int j;
125          celt_word32 sum = 1e-10;
126          for (j=M*eBands[i];j<M*eBands[i+1];j++)
127             sum += X[j+c*N]*X[j+c*N];
128          bank[i+c*m->nbEBands] = sqrt(sum);
129          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
130       }
131    }
132    /*printf ("\n");*/
133 }
134
135 /* Normalise each band such that the energy is one. */
136 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
137 {
138    int i, c, N;
139    const celt_int16 *eBands = m->eBands;
140    const int C = CHANNELS(_C);
141    N = M*m->shortMdctSize;
142    for (c=0;c<C;c++)
143    {
144       for (i=0;i<end;i++)
145       {
146          int j;
147          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
148          for (j=M*eBands[i];j<M*eBands[i+1];j++)
149             X[j+c*N] = freq[j+c*N]*g;
150       }
151    }
152 }
153
154 #endif /* FIXED_POINT */
155
156 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
157 {
158    int i, c;
159    const celt_int16 *eBands = m->eBands;
160    const int C = CHANNELS(_C);
161    for (c=0;c<C;c++)
162    {
163       i=0; do {
164          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, Q15ONE, M*eBands[i+1]-M*eBands[i], 1);
165       } while (++i<end);
166    }
167 }
168
169 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
170 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
171 {
172    int i, c, N;
173    const celt_int16 *eBands = m->eBands;
174    const int C = CHANNELS(_C);
175    N = M*m->shortMdctSize;
176    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
177    for (c=0;c<C;c++)
178    {
179       celt_sig * restrict f;
180       const celt_norm * restrict x;
181       f = freq+c*N;
182       x = X+c*N;
183       for (i=0;i<end;i++)
184       {
185          int j, band_end;
186          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
187          j=M*eBands[i];
188          band_end = M*eBands[i+1];
189          do {
190             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
191             x++;
192          } while (++j<band_end);
193       }
194       for (i=M*eBands[m->nbEBands];i<N;i++)
195          *f++ = 0;
196    }
197 }
198
199 int compute_pitch_gain(const CELTMode *m, const celt_sig *X, const celt_sig *P, int norm_rate, int *gain_id, int _C, celt_word16 *gain_prod, int M)
200 {
201    int j, c;
202    celt_word16 g;
203    celt_word16 delta;
204    const int C = CHANNELS(_C);
205    celt_word32 Sxy=0, Sxx=0, Syy=0;
206    int len = M*m->pitchEnd;
207    int N = M*m->shortMdctSize;
208 #ifdef FIXED_POINT
209    int shift = 0;
210    celt_word32 maxabs=0;
211
212    for (c=0;c<C;c++)
213    {
214       for (j=0;j<len;j++)
215       {
216          maxabs = MAX32(maxabs, ABS32(X[j+c*N]));
217          maxabs = MAX32(maxabs, ABS32(P[j+c*N]));
218       }
219    }
220    shift = celt_ilog2(maxabs)-12;
221    if (shift<0)
222       shift = 0;
223 #endif
224    delta = PDIV32_16(Q15ONE, len);
225    for (c=0;c<C;c++)
226    {
227       celt_word16 gg = Q15ONE;
228       for (j=0;j<len;j++)
229       {
230          celt_word16 Xj, Pj;
231          Xj = EXTRACT16(SHR32(X[j+c*N], shift));
232          Pj = MULT16_16_P15(gg,EXTRACT16(SHR32(P[j+c*N], shift)));
233          Sxy = MAC16_16(Sxy, Xj, Pj);
234          Sxx = MAC16_16(Sxx, Pj, Pj);
235          Syy = MAC16_16(Syy, Xj, Xj);
236          gg = SUB16(gg, delta);
237       }
238    }
239 #ifdef FIXED_POINT
240    {
241       celt_word32 num, den;
242       celt_word16 fact;
243       fact = MULT16_16(QCONST16(.04f, 14), norm_rate);
244       if (fact < QCONST16(1.f, 14))
245          fact = QCONST16(1.f, 14);
246       num = Sxy;
247       den = EPSILON+Sxx+MULT16_32_Q15(QCONST16(.03f,15),Syy);
248       shift = celt_zlog2(Sxy)-16;
249       if (shift < 0)
250          shift = 0;
251       if (Sxy < MULT16_32_Q15(fact, MULT16_16(celt_sqrt(EPSILON+Sxx),celt_sqrt(EPSILON+Syy))))
252          g = 0;
253       else
254          g = DIV32(SHL32(SHR32(num,shift),14),ADD32(EPSILON,SHR32(den,shift)));
255
256       /* This MUST round down so that we don't over-estimate the gain */
257       *gain_id = EXTRACT16(SHR32(MULT16_16(20,(g-QCONST16(.5f,14))),14));
258    }
259 #else
260    {
261       float fact = .04f*norm_rate;
262       if (fact < 1)
263          fact = 1;
264       g = Sxy/(.1f+Sxx+.03f*Syy);
265       if (Sxy < .5f*fact*celt_sqrt(1+Sxx*Syy))
266          g = 0;
267       /* This MUST round down so that we don't over-estimate the gain */
268       *gain_id = floor(20*(g-.5f));
269    }
270 #endif
271    /* This prevents the pitch gain from being above 1.0 for too long by bounding the 
272       maximum error amplification factor to 2.0 */
273    g = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),*gain_id));
274    *gain_prod = MAX16(QCONST32(1.f, 13), MULT16_16_Q14(*gain_prod,g));
275    if (*gain_prod>QCONST32(2.f, 13))
276    {
277       *gain_id=9;
278       *gain_prod = QCONST32(2.f, 13);
279    }
280
281    if (*gain_id < 0)
282    {
283       *gain_id = 0;
284       return 0;
285    } else {
286       if (*gain_id > 15)
287          *gain_id = 15;
288       return 1;
289    }
290 }
291
292 void apply_pitch(const CELTMode *m, celt_sig *X, const celt_sig *P, int gain_id, int pred, int _C, int M)
293 {
294    int j, c, N;
295    celt_word16 gain;
296    celt_word16 delta;
297    const int C = CHANNELS(_C);
298    int len = M*m->pitchEnd;
299
300    N = M*m->shortMdctSize;
301    gain = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),gain_id));
302    delta = PDIV32_16(gain, len);
303    if (pred)
304       gain = -gain;
305    else
306       delta = -delta;
307    for (c=0;c<C;c++)
308    {
309       celt_word16 gg = gain;
310       for (j=0;j<len;j++)
311       {
312          X[j+c*N] += SHL32(MULT16_32_Q15(gg,P[j+c*N]),1);
313          gg = ADD16(gg, delta);
314       }
315    }
316 }
317
318 static void stereo_band_mix(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int stereo_mode, int bandID, int dir, int N)
319 {
320    int i = bandID;
321    int j;
322    celt_word16 a1, a2;
323    if (stereo_mode==0)
324    {
325       /* Do mid-side when not doing intensity stereo */
326       a1 = QCONST16(.70711f,14);
327       a2 = dir*QCONST16(.70711f,14);
328    } else {
329       celt_word16 left, right;
330       celt_word16 norm;
331 #ifdef FIXED_POINT
332       int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
333 #endif
334       left = VSHR32(bank[i],shift);
335       right = VSHR32(bank[i+m->nbEBands],shift);
336       norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
337       a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
338       a2 = dir*DIV32_16(SHL32(EXTEND32(right),14),norm);
339    }
340    for (j=0;j<N;j++)
341    {
342       celt_norm r, l;
343       l = X[j];
344       r = Y[j];
345       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
346       Y[j] = MULT16_16_Q14(a1,r) - MULT16_16_Q14(a2,l);
347    }
348 }
349
350
351 int folding_decision(const CELTMode *m, celt_norm *X, celt_word16 *average, int *last_decision, int end, int _C, int M)
352 {
353    int i, c, N0;
354    int NR=0;
355    celt_word32 ratio = EPSILON;
356    const int C = CHANNELS(_C);
357    const celt_int16 * restrict eBands = m->eBands;
358    
359    N0 = M*m->shortMdctSize;
360
361    for (c=0;c<C;c++)
362    {
363    for (i=0;i<end;i++)
364    {
365       int j, N;
366       int max_i=0;
367       celt_word16 max_val=EPSILON;
368       celt_word32 floor_ener=EPSILON;
369       celt_norm * restrict x = X+M*eBands[i]+c*N0;
370       N = M*eBands[i+1]-M*eBands[i];
371       for (j=0;j<N;j++)
372       {
373          if (ABS16(x[j])>max_val)
374          {
375             max_val = ABS16(x[j]);
376             max_i = j;
377          }
378       }
379 #if 0
380       for (j=0;j<N;j++)
381       {
382          if (abs(j-max_i)>2)
383             floor_ener += x[j]*x[j];
384       }
385 #else
386       floor_ener = QCONST32(1.,28)-MULT16_16(max_val,max_val);
387       if (max_i < N-1)
388          floor_ener -= MULT16_16(x[(max_i+1)], x[(max_i+1)]);
389       if (max_i < N-2)
390          floor_ener -= MULT16_16(x[(max_i+2)], x[(max_i+2)]);
391       if (max_i > 0)
392          floor_ener -= MULT16_16(x[(max_i-1)], x[(max_i-1)]);
393       if (max_i > 1)
394          floor_ener -= MULT16_16(x[(max_i-2)], x[(max_i-2)]);
395       floor_ener = MAX32(floor_ener, EPSILON);
396 #endif
397       if (N>7)
398       {
399          celt_word16 r;
400          celt_word16 den = celt_sqrt(floor_ener);
401          den = MAX32(QCONST16(.02f, 15), den);
402          r = DIV32_16(SHL32(EXTEND32(max_val),8),den);
403          ratio = ADD32(ratio, EXTEND32(r));
404          NR++;
405       }
406    }
407    }
408    if (NR>0)
409       ratio = DIV32_16(ratio, NR);
410    ratio = ADD32(HALF32(ratio), HALF32(*average));
411    if (!*last_decision)
412    {
413       *last_decision = (ratio < QCONST16(1.8f,8));
414    } else {
415       *last_decision = (ratio < QCONST16(3.f,8));
416    }
417    *average = EXTRACT16(ratio);
418    return *last_decision;
419 }
420
421 static void interleave_vector(celt_norm *X, int N0, int stride)
422 {
423    int i,j;
424    VARDECL(celt_norm, tmp);
425    int N;
426    SAVE_STACK;
427    N = N0*stride;
428    ALLOC(tmp, N, celt_norm);
429    for (i=0;i<stride;i++)
430       for (j=0;j<N0;j++)
431          tmp[j*stride+i] = X[i*N0+j];
432    for (j=0;j<N;j++)
433       X[j] = tmp[j];
434    RESTORE_STACK;
435 }
436
437 static void deinterleave_vector(celt_norm *X, int N0, int stride)
438 {
439    int i,j;
440    VARDECL(celt_norm, tmp);
441    int N;
442    SAVE_STACK;
443    N = N0*stride;
444    ALLOC(tmp, N, celt_norm);
445    for (i=0;i<stride;i++)
446       for (j=0;j<N0;j++)
447          tmp[i*N0+j] = X[j*stride+i];
448    for (j=0;j<N;j++)
449       X[j] = tmp[j];
450    RESTORE_STACK;
451 }
452
453 static void haar1(celt_norm *X, int N0, int stride)
454 {
455    int i, j;
456    N0 >>= 1;
457    for (i=0;i<stride;i++)
458       for (j=0;j<N0;j++)
459       {
460          celt_norm tmp = X[stride*2*j+i];
461          X[stride*2*j+i] = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i] + X[stride*(2*j+1)+i]);
462          X[stride*(2*j+1)+i] = MULT16_16_Q15(QCONST16(.7070678f,15), tmp - X[stride*(2*j+1)+i]);
463       }
464 }
465
466 /* This function is responsible for encoding and decoding a band for both
467    the mono and stereo case. Even in the mono case, it can split the band
468    in two and transmit the energy difference with the two half-bands. It
469    can be called recursively so bands can end up being split in 8 parts. */
470 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
471       int N, int b, int spread, int tf_change, celt_norm *lowband, int resynth, void *ec,
472       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level)
473 {
474    int q;
475    int curr_bits;
476    int stereo, split;
477    int imid=0, iside=0;
478    int N0=N;
479    int N_B=N;
480    int N_B0;
481    int spread0=spread;
482    int time_divide=0;
483    int recombine=0;
484
485    if (spread)
486       N_B /= spread;
487    N_B0 = N_B;
488
489    split = stereo = Y != NULL;
490
491    /* Special case for one sample */
492    if (N==1)
493    {
494       int c;
495       celt_norm *x = X;
496       for (c=0;c<1+stereo;c++)
497       {
498          int sign=0;
499          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
500          {
501             if (encode)
502             {
503                sign = x[0]<0;
504                ec_enc_bits((ec_enc*)ec, sign, 1);
505             } else {
506                sign = ec_dec_bits((ec_dec*)ec, 1);
507             }
508             *remaining_bits -= 1<<BITRES;
509             b-=1<<BITRES;
510          }
511          if (resynth)
512             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
513          x = Y;
514       }
515       if (lowband_out)
516          lowband_out[0] = X[0];
517       return;
518    }
519
520    /* Band recombining to increase frequency resolution */
521    if (!stereo && spread > 1 && level == 0 && tf_change>0)
522    {
523       while (spread>1 && tf_change>0)
524       {
525          spread>>=1;
526          N_B<<=1;
527          if (encode)
528             haar1(X, N_B, spread);
529          if (lowband)
530             haar1(lowband, N_B, spread);
531          recombine++;
532          tf_change--;
533       }
534       spread0=spread;
535       N_B0 = N_B;
536    }
537
538    /* Increasing the time resolution */
539    if (!stereo && level==0)
540    {
541       while ((N_B&1) == 0 && tf_change<0 && spread <= (1<<LM))
542       {
543          if (encode)
544             haar1(X, N_B, spread);
545          if (lowband)
546             haar1(lowband, N_B, spread);
547          spread <<= 1;
548          N_B >>= 1;
549          time_divide++;
550          tf_change++;
551       }
552       spread0 = spread;
553       N_B0 = N_B;
554    }
555
556    /* Reorganize the samples in time order instead of frequency order */
557    if (!stereo && spread0>1 && level==0)
558    {
559       if (encode)
560          deinterleave_vector(X, N_B, spread0);
561       if (lowband)
562          deinterleave_vector(lowband, N_B, spread0);
563    }
564
565    /* If we need more than 32 bits, try splitting the band in two. */
566    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
567    {
568       if (LM>0 || (N&1)==0)
569       {
570          N >>= 1;
571          Y = X+N;
572          split = 1;
573          LM -= 1;
574          spread = (spread+1)>>1;
575       }
576    }
577
578    if (split)
579    {
580       int qb;
581       int itheta=0;
582       int mbits, sbits, delta;
583       int qalloc;
584       celt_word16 mid, side;
585       int offset, N2;
586       offset = m->logN[i]+(LM<<BITRES)-QTHETA_OFFSET;
587
588       /* Decide on the resolution to give to the split parameter theta */
589       N2 = 2*N-1;
590       if (stereo && N>2)
591          N2--;
592       qb = (b+N2*offset)/(N2<<BITRES);
593       if (qb > (b>>(BITRES+1))-1)
594          qb = (b>>(BITRES+1))-1;
595
596       if (qb<0)
597          qb = 0;
598       if (qb>14)
599          qb = 14;
600
601       qalloc = 0;
602       if (qb!=0)
603       {
604          int shift;
605          shift = 14-qb;
606
607          if (encode)
608          {
609             if (stereo)
610                stereo_band_mix(m, X, Y, bandE, qb==0, i, 1, N);
611
612             mid = renormalise_vector(X, Q15ONE, N, 1);
613             side = renormalise_vector(Y, Q15ONE, N, 1);
614
615             /* theta is the atan() of the ration between the (normalized)
616                side and mid. With just that parameter, we can re-scale both
617                mid and side because we know that 1) they have unit norm and
618                2) they are orthogonal. */
619    #ifdef FIXED_POINT
620             /* 0.63662 = 2/pi */
621             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
622    #else
623             itheta = floor(.5f+16384*0.63662f*atan2(side,mid));
624    #endif
625
626             itheta = (itheta+(1<<shift>>1))>>shift;
627          }
628
629          /* Entropy coding of the angle. We use a uniform pdf for the
630             first stereo split but a triangular one for the rest. */
631          if (stereo || qb>9 || spread>1)
632          {
633             if (encode)
634                ec_enc_uint((ec_enc*)ec, itheta, (1<<qb)+1);
635             else
636                itheta = ec_dec_uint((ec_dec*)ec, (1<<qb)+1);
637             qalloc = log2_frac((1<<qb)+1,BITRES);
638          } else {
639             int fs=1, ft;
640             ft = ((1<<qb>>1)+1)*((1<<qb>>1)+1);
641             if (encode)
642             {
643                int j;
644                int fl=0;
645                j=0;
646                while(1)
647                {
648                   if (j==itheta)
649                      break;
650                   fl+=fs;
651                   if (j<(1<<qb>>1))
652                      fs++;
653                   else
654                      fs--;
655                   j++;
656                }
657                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
658             } else {
659                int fl=0;
660                int j, fm;
661                fm = ec_decode((ec_dec*)ec, ft);
662                j=0;
663                while (1)
664                {
665                   if (fm < fl+fs)
666                      break;
667                   fl+=fs;
668                   if (j<(1<<qb>>1))
669                      fs++;
670                   else
671                      fs--;
672                   j++;
673                }
674                itheta = j;
675                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
676             }
677             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
678          }
679          itheta <<= shift;
680       }
681
682       if (itheta == 0)
683       {
684          imid = 32767;
685          iside = 0;
686          delta = -10000;
687       } else if (itheta == 16384)
688       {
689          imid = 0;
690          iside = 32767;
691          delta = 10000;
692       } else {
693          imid = bitexact_cos(itheta);
694          iside = bitexact_cos(16384-itheta);
695          /* This is the mid vs side allocation that minimizes squared error
696             in that band. */
697          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
698       }
699
700       /* This is a special case for N=2 that only works for stereo and takes
701          advantage of the fact that mid and side are orthogonal to encode
702          the side with just one bit. */
703       if (N==2 && stereo)
704       {
705          int c, c2;
706          int sign=1;
707          celt_norm v[2], w[2];
708          celt_norm *x2, *y2;
709          mbits = b-qalloc;
710          sbits = 0;
711          if (itheta != 0 && itheta != 16384)
712             sbits = 1<<BITRES;
713          mbits -= sbits;
714          c = itheta > 8192 ? 1 : 0;
715          *remaining_bits -= qalloc+sbits;
716
717          x2 = X;
718          y2 = Y;
719          if (encode)
720          {
721             c2 = 1-c;
722
723             if (c==0)
724             {
725                v[0] = x2[0];
726                v[1] = x2[1];
727                w[0] = y2[0];
728                w[1] = y2[1];
729             } else {
730                v[0] = y2[0];
731                v[1] = y2[1];
732                w[0] = x2[0];
733                w[1] = x2[1];
734             }
735             /* Here we only need to encode a sign for the side */
736             if (v[0]*w[1] - v[1]*w[0] > 0)
737                sign = 1;
738             else
739                sign = -1;
740          }
741          quant_band(encode, m, i, v, NULL, N, mbits, spread, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level+1);
742          if (sbits)
743          {
744             if (encode)
745             {
746                ec_enc_bits((ec_enc*)ec, sign==1, 1);
747             } else {
748                sign = 2*ec_dec_bits((ec_dec*)ec, 1)-1;
749             }
750          } else {
751             sign = 1;
752          }
753          w[0] = -sign*v[1];
754          w[1] = sign*v[0];
755          if (c==0)
756          {
757             x2[0] = v[0];
758             x2[1] = v[1];
759             y2[0] = w[0];
760             y2[1] = w[1];
761          } else {
762             x2[0] = w[0];
763             x2[1] = w[1];
764             y2[0] = v[0];
765             y2[1] = v[1];
766          }
767       } else
768       {
769          /* "Normal" split code */
770          celt_norm *next_lowband2=NULL;
771          celt_norm *next_lowband_out1=NULL;
772          int next_level=0;
773
774          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
775          if (spread>1 && !stereo)
776             delta >>= 1;
777
778          mbits = (b-qalloc/2-delta)/2;
779          if (mbits > b-qalloc)
780             mbits = b-qalloc;
781          if (mbits<0)
782             mbits=0;
783          sbits = b-qalloc-mbits;
784          *remaining_bits -= qalloc;
785
786          if (lowband && !stereo)
787             next_lowband2 = lowband+N;
788          if (stereo)
789             next_lowband_out1 = lowband_out;
790          else
791             next_level = level+1;
792
793          quant_band(encode, m, i, X, NULL, N, mbits, spread, tf_change, lowband, resynth, ec, remaining_bits, LM, next_lowband_out1, NULL, next_level);
794          quant_band(encode, m, i, Y, NULL, N, sbits, spread, tf_change, next_lowband2, resynth, ec, remaining_bits, LM, NULL, NULL, level);
795       }
796
797    } else {
798       /* This is the basic no-split case */
799       q = bits2pulses(m, m->bits[LM][i], N, b);
800       curr_bits = pulses2bits(m->bits[LM][i], N, q);
801       *remaining_bits -= curr_bits;
802
803       /* Ensures we can never bust the budget */
804       while (*remaining_bits < 0 && q > 0)
805       {
806          *remaining_bits += curr_bits;
807          q--;
808          curr_bits = pulses2bits(m->bits[LM][i], N, q);
809          *remaining_bits -= curr_bits;
810       }
811
812       if (encode)
813          alg_quant(X, N, q, spread, lowband, resynth, (ec_enc*)ec);
814       else
815          alg_unquant(X, N, q, spread, lowband, (ec_dec*)ec);
816    }
817
818    /* This code is used by the decoder and by the resynthesis-enabled encoder */
819    if (resynth)
820    {
821       int k;
822
823       if (split)
824       {
825          int j;
826          celt_word16 mid, side;
827 #ifdef FIXED_POINT
828          mid = imid;
829          side = iside;
830 #else
831          mid = (1.f/32768)*imid;
832          side = (1.f/32768)*iside;
833 #endif
834          for (j=0;j<N;j++)
835             X[j] = MULT16_16_Q15(X[j], mid);
836          for (j=0;j<N;j++)
837             Y[j] = MULT16_16_Q15(Y[j], side);
838       }
839
840       if (!stereo && spread0>1 && level==0)
841       {
842          interleave_vector(X, N_B, spread0);
843          if (lowband)
844             interleave_vector(lowband, N_B, spread0);
845       }
846
847       /* Undo time-freq changes that we did earlier */
848       N_B = N_B0;
849       spread = spread0;
850       for (k=0;k<time_divide;k++)
851       {
852          spread >>= 1;
853          N_B <<= 1;
854          haar1(X, N_B, spread);
855          if (lowband)
856             haar1(lowband, N_B, spread);
857       }
858
859       for (k=0;k<recombine;k++)
860       {
861          haar1(X, N_B, spread);
862          if (lowband)
863             haar1(lowband, N_B, spread);
864          N_B>>=1;
865          spread <<= 1;
866       }
867
868       if (lowband_out && !stereo)
869       {
870          int j;
871          celt_word16 n;
872          n = celt_sqrt(SHL32(EXTEND32(N0),22));
873          for (j=0;j<N0;j++)
874             lowband_out[j] = MULT16_16_Q15(n,X[j]);
875       }
876
877       if (stereo)
878       {
879          stereo_band_mix(m, X, Y, bandE, 0, i, -1, N);
880          renormalise_vector(X, Q15ONE, N, 1);
881          renormalise_vector(Y, Q15ONE, N, 1);
882       }
883    }
884 }
885
886 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM)
887 {
888    int i, balance;
889    celt_int32 remaining_bits;
890    const celt_int16 * restrict eBands = m->eBands;
891    celt_norm * restrict norm;
892    VARDECL(celt_norm, _norm);
893    int B;
894    int M;
895    int spread;
896    celt_norm *lowband;
897    int update_lowband = 1;
898    int C = _Y != NULL ? 2 : 1;
899    SAVE_STACK;
900
901    M = 1<<LM;
902    B = shortBlocks ? M : 1;
903    spread = fold ? B : 0;
904    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
905    norm = _norm;
906
907    balance = 0;
908    lowband = NULL;
909    for (i=start;i<end;i++)
910    {
911       int tell;
912       int b;
913       int N;
914       int curr_balance;
915       celt_norm * restrict X, * restrict Y;
916       int tf_change=0;
917       
918       X = _X+M*eBands[i];
919       if (_Y!=NULL)
920          Y = _Y+M*eBands[i];
921       else
922          Y = NULL;
923       N = M*eBands[i+1]-M*eBands[i];
924       if (encode)
925          tell = ec_enc_tell((ec_enc*)ec, BITRES);
926       else
927          tell = ec_dec_tell((ec_dec*)ec, BITRES);
928
929       if (i != start)
930          balance -= tell;
931       remaining_bits = (total_bits<<BITRES)-tell-1;
932       curr_balance = (end-i);
933       if (curr_balance > 3)
934          curr_balance = 3;
935       curr_balance = balance / curr_balance;
936       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
937       if (b<0)
938          b = 0;
939       /* Prevents ridiculous bit depths */
940       if (b > C*16*N<<BITRES)
941          b = C*16*N<<BITRES;
942
943       if (M*eBands[i]-N >= M*eBands[start])
944       {
945          if (update_lowband || lowband==NULL)
946             lowband = norm+M*eBands[i]-N;
947       } else
948          lowband = NULL;
949
950       tf_change = tf_res[i];
951       if (i>=m->effEBands)
952       {
953          X=norm;
954          if (_Y!=NULL)
955             Y = norm;
956       }
957       quant_band(encode, m, i, X, Y, N, b, spread, tf_change, lowband, resynth, ec, &remaining_bits, LM, norm+M*eBands[i], bandE, 0);
958
959       balance += pulses[i] + tell;
960
961       /* Update the folding position only as long as we have 2 bit/sample depth */
962       update_lowband = (b>>BITRES)>2*N;
963    }
964    RESTORE_STACK;
965 }
966