Fixing the qtheta dependency for the delta allocation
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    for (c=0;c<C;c++)
74    {
75       for (i=0;i<end;i++)
76       {
77          int j;
78          celt_word32 maxval=0;
79          celt_word32 sum = 0;
80          
81          j=M*eBands[i]; do {
82             maxval = MAX32(maxval, X[j+c*N]);
83             maxval = MAX32(maxval, -X[j+c*N]);
84          } while (++j<M*eBands[i+1]);
85          
86          if (maxval > 0)
87          {
88             int shift = celt_ilog2(maxval)-10;
89             j=M*eBands[i]; do {
90                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
91                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
92             } while (++j<M*eBands[i+1]);
93             /* We're adding one here to make damn sure we never end up with a pitch vector that's
94                larger than unity norm */
95             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
96          } else {
97             bank[i+c*m->nbEBands] = EPSILON;
98          }
99          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
100       }
101    }
102    /*printf ("\n");*/
103 }
104
105 /* Normalise each band such that the energy is one. */
106 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
107 {
108    int i, c, N;
109    const celt_int16 *eBands = m->eBands;
110    const int C = CHANNELS(_C);
111    N = M*m->shortMdctSize;
112    for (c=0;c<C;c++)
113    {
114       i=0; do {
115          celt_word16 g;
116          int j,shift;
117          celt_word16 E;
118          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
119          E = VSHR32(bank[i+c*m->nbEBands], shift);
120          g = EXTRACT16(celt_rcp(SHL32(E,3)));
121          j=M*eBands[i]; do {
122             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
123          } while (++j<M*eBands[i+1]);
124       } while (++i<end);
125    }
126 }
127
128 #else /* FIXED_POINT */
129 /* Compute the amplitude (sqrt energy) in each of the bands */
130 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
131 {
132    int i, c, N;
133    const celt_int16 *eBands = m->eBands;
134    const int C = CHANNELS(_C);
135    N = M*m->shortMdctSize;
136    for (c=0;c<C;c++)
137    {
138       for (i=0;i<end;i++)
139       {
140          int j;
141          celt_word32 sum = 1e-10;
142          for (j=M*eBands[i];j<M*eBands[i+1];j++)
143             sum += X[j+c*N]*X[j+c*N];
144          bank[i+c*m->nbEBands] = sqrt(sum);
145          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
146       }
147    }
148    /*printf ("\n");*/
149 }
150
151 /* Normalise each band such that the energy is one. */
152 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
153 {
154    int i, c, N;
155    const celt_int16 *eBands = m->eBands;
156    const int C = CHANNELS(_C);
157    N = M*m->shortMdctSize;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<end;i++)
161       {
162          int j;
163          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
164          for (j=M*eBands[i];j<M*eBands[i+1];j++)
165             X[j+c*N] = freq[j+c*N]*g;
166       }
167    }
168 }
169
170 #endif /* FIXED_POINT */
171
172 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
173 {
174    int i, c;
175    const celt_int16 *eBands = m->eBands;
176    const int C = CHANNELS(_C);
177    for (c=0;c<C;c++)
178    {
179       i=0; do {
180          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, Q15ONE, M*eBands[i+1]-M*eBands[i], 1);
181       } while (++i<end);
182    }
183 }
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    for (c=0;c<C;c++)
194    {
195       celt_sig * restrict f;
196       const celt_norm * restrict x;
197       f = freq+c*N;
198       x = X+c*N;
199       for (i=0;i<end;i++)
200       {
201          int j, band_end;
202          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
203          j=M*eBands[i];
204          band_end = M*eBands[i+1];
205          do {
206             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
207             x++;
208          } while (++j<band_end);
209       }
210       for (i=M*eBands[m->nbEBands];i<N;i++)
211          *f++ = 0;
212    }
213 }
214
215 int compute_pitch_gain(const CELTMode *m, const celt_sig *X, const celt_sig *P, int norm_rate, int *gain_id, int _C, celt_word16 *gain_prod, int M)
216 {
217    int j, c;
218    celt_word16 g;
219    celt_word16 delta;
220    const int C = CHANNELS(_C);
221    celt_word32 Sxy=0, Sxx=0, Syy=0;
222    int len = M*m->pitchEnd;
223    int N = M*m->shortMdctSize;
224 #ifdef FIXED_POINT
225    int shift = 0;
226    celt_word32 maxabs=0;
227
228    for (c=0;c<C;c++)
229    {
230       for (j=0;j<len;j++)
231       {
232          maxabs = MAX32(maxabs, ABS32(X[j+c*N]));
233          maxabs = MAX32(maxabs, ABS32(P[j+c*N]));
234       }
235    }
236    shift = celt_ilog2(maxabs)-12;
237    if (shift<0)
238       shift = 0;
239 #endif
240    delta = PDIV32_16(Q15ONE, len);
241    for (c=0;c<C;c++)
242    {
243       celt_word16 gg = Q15ONE;
244       for (j=0;j<len;j++)
245       {
246          celt_word16 Xj, Pj;
247          Xj = EXTRACT16(SHR32(X[j+c*N], shift));
248          Pj = MULT16_16_P15(gg,EXTRACT16(SHR32(P[j+c*N], shift)));
249          Sxy = MAC16_16(Sxy, Xj, Pj);
250          Sxx = MAC16_16(Sxx, Pj, Pj);
251          Syy = MAC16_16(Syy, Xj, Xj);
252          gg = SUB16(gg, delta);
253       }
254    }
255 #ifdef FIXED_POINT
256    {
257       celt_word32 num, den;
258       celt_word16 fact;
259       fact = MULT16_16(QCONST16(.04f, 14), norm_rate);
260       if (fact < QCONST16(1.f, 14))
261          fact = QCONST16(1.f, 14);
262       num = Sxy;
263       den = EPSILON+Sxx+MULT16_32_Q15(QCONST16(.03f,15),Syy);
264       shift = celt_zlog2(Sxy)-16;
265       if (shift < 0)
266          shift = 0;
267       if (Sxy < MULT16_32_Q15(fact, MULT16_16(celt_sqrt(EPSILON+Sxx),celt_sqrt(EPSILON+Syy))))
268          g = 0;
269       else
270          g = DIV32(SHL32(SHR32(num,shift),14),ADD32(EPSILON,SHR32(den,shift)));
271
272       /* This MUST round down so that we don't over-estimate the gain */
273       *gain_id = EXTRACT16(SHR32(MULT16_16(20,(g-QCONST16(.5f,14))),14));
274    }
275 #else
276    {
277       float fact = .04f*norm_rate;
278       if (fact < 1)
279          fact = 1;
280       g = Sxy/(.1f+Sxx+.03f*Syy);
281       if (Sxy < .5f*fact*celt_sqrt(1+Sxx*Syy))
282          g = 0;
283       /* This MUST round down so that we don't over-estimate the gain */
284       *gain_id = floor(20*(g-.5f));
285    }
286 #endif
287    /* This prevents the pitch gain from being above 1.0 for too long by bounding the 
288       maximum error amplification factor to 2.0 */
289    g = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),*gain_id));
290    *gain_prod = MAX16(QCONST32(1.f, 13), MULT16_16_Q14(*gain_prod,g));
291    if (*gain_prod>QCONST32(2.f, 13))
292    {
293       *gain_id=9;
294       *gain_prod = QCONST32(2.f, 13);
295    }
296
297    if (*gain_id < 0)
298    {
299       *gain_id = 0;
300       return 0;
301    } else {
302       if (*gain_id > 15)
303          *gain_id = 15;
304       return 1;
305    }
306 }
307
308 void apply_pitch(const CELTMode *m, celt_sig *X, const celt_sig *P, int gain_id, int pred, int _C, int M)
309 {
310    int j, c, N;
311    celt_word16 gain;
312    celt_word16 delta;
313    const int C = CHANNELS(_C);
314    int len = M*m->pitchEnd;
315
316    N = M*m->shortMdctSize;
317    gain = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),gain_id));
318    delta = PDIV32_16(gain, len);
319    if (pred)
320       gain = -gain;
321    else
322       delta = -delta;
323    for (c=0;c<C;c++)
324    {
325       celt_word16 gg = gain;
326       for (j=0;j<len;j++)
327       {
328          X[j+c*N] += SHL32(MULT16_32_Q15(gg,P[j+c*N]),1);
329          gg = ADD16(gg, delta);
330       }
331    }
332 }
333
334 static void stereo_band_mix(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int stereo_mode, int bandID, int dir, int N)
335 {
336    int i = bandID;
337    int j;
338    celt_word16 a1, a2;
339    if (stereo_mode==0)
340    {
341       /* Do mid-side when not doing intensity stereo */
342       a1 = QCONST16(.70711f,14);
343       a2 = dir*QCONST16(.70711f,14);
344    } else {
345       celt_word16 left, right;
346       celt_word16 norm;
347 #ifdef FIXED_POINT
348       int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
349 #endif
350       left = VSHR32(bank[i],shift);
351       right = VSHR32(bank[i+m->nbEBands],shift);
352       norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
353       a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
354       a2 = dir*DIV32_16(SHL32(EXTEND32(right),14),norm);
355    }
356    for (j=0;j<N;j++)
357    {
358       celt_norm r, l;
359       l = X[j];
360       r = Y[j];
361       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
362       Y[j] = MULT16_16_Q14(a1,r) - MULT16_16_Q14(a2,l);
363    }
364 }
365
366 /* Decide whether we should spread the pulses in the current frame */
367 int folding_decision(const CELTMode *m, celt_norm *X, celt_word16 *average, int *last_decision, int end, int _C, int M)
368 {
369    int i, c, N0;
370    int NR=0;
371    celt_word32 ratio = EPSILON;
372    const int C = CHANNELS(_C);
373    const celt_int16 * restrict eBands = m->eBands;
374    
375    N0 = M*m->shortMdctSize;
376
377    for (c=0;c<C;c++)
378    {
379    for (i=0;i<end;i++)
380    {
381       int j, N;
382       int max_i=0;
383       celt_word16 max_val=EPSILON;
384       celt_word32 floor_ener=EPSILON;
385       celt_norm * restrict x = X+M*eBands[i]+c*N0;
386       N = M*eBands[i+1]-M*eBands[i];
387       for (j=0;j<N;j++)
388       {
389          if (ABS16(x[j])>max_val)
390          {
391             max_val = ABS16(x[j]);
392             max_i = j;
393          }
394       }
395 #if 0
396       for (j=0;j<N;j++)
397       {
398          if (abs(j-max_i)>2)
399             floor_ener += x[j]*x[j];
400       }
401 #else
402       floor_ener = QCONST32(1.,28)-MULT16_16(max_val,max_val);
403       if (max_i < N-1)
404          floor_ener -= MULT16_16(x[(max_i+1)], x[(max_i+1)]);
405       if (max_i < N-2)
406          floor_ener -= MULT16_16(x[(max_i+2)], x[(max_i+2)]);
407       if (max_i > 0)
408          floor_ener -= MULT16_16(x[(max_i-1)], x[(max_i-1)]);
409       if (max_i > 1)
410          floor_ener -= MULT16_16(x[(max_i-2)], x[(max_i-2)]);
411       floor_ener = MAX32(floor_ener, EPSILON);
412 #endif
413       if (N>7)
414       {
415          celt_word16 r;
416          celt_word16 den = celt_sqrt(floor_ener);
417          den = MAX32(QCONST16(.02f, 15), den);
418          r = DIV32_16(SHL32(EXTEND32(max_val),8),den);
419          ratio = ADD32(ratio, EXTEND32(r));
420          NR++;
421       }
422    }
423    }
424    if (NR>0)
425       ratio = DIV32_16(ratio, NR);
426    ratio = ADD32(HALF32(ratio), HALF32(*average));
427    if (!*last_decision)
428    {
429       *last_decision = (ratio < QCONST16(1.8f,8));
430    } else {
431       *last_decision = (ratio < QCONST16(3.f,8));
432    }
433    *average = EXTRACT16(ratio);
434    return *last_decision;
435 }
436
437 static void interleave_vector(celt_norm *X, int N0, int stride)
438 {
439    int i,j;
440    VARDECL(celt_norm, tmp);
441    int N;
442    SAVE_STACK;
443    N = N0*stride;
444    ALLOC(tmp, N, celt_norm);
445    for (i=0;i<stride;i++)
446       for (j=0;j<N0;j++)
447          tmp[j*stride+i] = X[i*N0+j];
448    for (j=0;j<N;j++)
449       X[j] = tmp[j];
450    RESTORE_STACK;
451 }
452
453 static void deinterleave_vector(celt_norm *X, int N0, int stride)
454 {
455    int i,j;
456    VARDECL(celt_norm, tmp);
457    int N;
458    SAVE_STACK;
459    N = N0*stride;
460    ALLOC(tmp, N, celt_norm);
461    for (i=0;i<stride;i++)
462       for (j=0;j<N0;j++)
463          tmp[i*N0+j] = X[j*stride+i];
464    for (j=0;j<N;j++)
465       X[j] = tmp[j];
466    RESTORE_STACK;
467 }
468
469 static void haar1(celt_norm *X, int N0, int stride)
470 {
471    int i, j;
472    N0 >>= 1;
473    for (i=0;i<stride;i++)
474       for (j=0;j<N0;j++)
475       {
476          celt_norm tmp = X[stride*2*j+i];
477          X[stride*2*j+i] = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i] + X[stride*(2*j+1)+i]);
478          X[stride*(2*j+1)+i] = MULT16_16_Q15(QCONST16(.7070678f,15), tmp - X[stride*(2*j+1)+i]);
479       }
480 }
481
482 /* This function is responsible for encoding and decoding a band for both
483    the mono and stereo case. Even in the mono case, it can split the band
484    in two and transmit the energy difference with the two half-bands. It
485    can be called recursively so bands can end up being split in 8 parts. */
486 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
487       int N, int b, int spread, int B, int tf_change, celt_norm *lowband, int resynth, void *ec,
488       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level, celt_int32 *seed)
489 {
490    int q;
491    int curr_bits;
492    int stereo, split;
493    int imid=0, iside=0;
494    int N0=N;
495    int N_B=N;
496    int N_B0;
497    int B0=B;
498    int time_divide=0;
499    int recombine=0;
500
501    N_B /= B;
502    N_B0 = N_B;
503
504    split = stereo = Y != NULL;
505
506    /* Special case for one sample */
507    if (N==1)
508    {
509       int c;
510       celt_norm *x = X;
511       for (c=0;c<1+stereo;c++)
512       {
513          int sign=0;
514          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
515          {
516             if (encode)
517             {
518                sign = x[0]<0;
519                ec_enc_bits((ec_enc*)ec, sign, 1);
520             } else {
521                sign = ec_dec_bits((ec_dec*)ec, 1);
522             }
523             *remaining_bits -= 1<<BITRES;
524             b-=1<<BITRES;
525          }
526          if (resynth)
527             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
528          x = Y;
529       }
530       if (lowband_out)
531          lowband_out[0] = X[0];
532       return;
533    }
534
535    /* Band recombining to increase frequency resolution */
536    if (!stereo && B > 1 && level == 0 && tf_change>0)
537    {
538       while (B>1 && tf_change>0)
539       {
540          B>>=1;
541          N_B<<=1;
542          if (encode)
543             haar1(X, N_B, B);
544          if (lowband)
545             haar1(lowband, N_B, B);
546          recombine++;
547          tf_change--;
548       }
549       B0=B;
550       N_B0 = N_B;
551    }
552
553    /* Increasing the time resolution */
554    if (!stereo && level==0)
555    {
556       while ((N_B&1) == 0 && tf_change<0 && B <= (1<<LM))
557       {
558          if (encode)
559             haar1(X, N_B, B);
560          if (lowband)
561             haar1(lowband, N_B, B);
562          B <<= 1;
563          N_B >>= 1;
564          time_divide++;
565          tf_change++;
566       }
567       B0 = B;
568       N_B0 = N_B;
569    }
570
571    /* Reorganize the samples in time order instead of frequency order */
572    if (!stereo && B0>1 && level==0)
573    {
574       if (encode)
575          deinterleave_vector(X, N_B, B0);
576       if (lowband)
577          deinterleave_vector(lowband, N_B, B0);
578    }
579
580    /* If we need more than 32 bits, try splitting the band in two. */
581    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
582    {
583       if (LM>0 || (N&1)==0)
584       {
585          N >>= 1;
586          Y = X+N;
587          split = 1;
588          LM -= 1;
589          B = (B+1)>>1;
590       }
591    }
592
593    if (split)
594    {
595       int qb;
596       int itheta=0;
597       int mbits, sbits, delta;
598       int qalloc;
599       celt_word16 mid, side;
600       int offset, N2;
601       offset = ((m->logN[i]+(LM<<BITRES))>>1)-QTHETA_OFFSET;
602
603       /* Decide on the resolution to give to the split parameter theta */
604       N2 = 2*N-1;
605       if (stereo && N==2)
606          N2--;
607       qb = (b+N2*offset)/(N2<<BITRES);
608       if (qb > (b>>(BITRES+1))-1)
609          qb = (b>>(BITRES+1))-1;
610
611       if (qb<0)
612          qb = 0;
613       if (qb>14)
614          qb = 14;
615
616       qalloc = 0;
617       if (qb!=0)
618       {
619          int shift;
620          shift = 14-qb;
621
622          if (encode)
623          {
624             if (stereo)
625                stereo_band_mix(m, X, Y, bandE, qb==0, i, 1, N);
626
627             mid = renormalise_vector(X, Q15ONE, N, 1);
628             side = renormalise_vector(Y, Q15ONE, N, 1);
629
630             /* theta is the atan() of the ratio between the (normalized)
631                side and mid. With just that parameter, we can re-scale both
632                mid and side because we know that 1) they have unit norm and
633                2) they are orthogonal. */
634    #ifdef FIXED_POINT
635             /* 0.63662 = 2/pi */
636             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
637    #else
638             itheta = floor(.5f+16384*0.63662f*atan2(side,mid));
639    #endif
640
641             itheta = (itheta+(1<<shift>>1))>>shift;
642          }
643
644          /* Entropy coding of the angle. We use a uniform pdf for the
645             first stereo split but a triangular one for the rest. */
646          if (stereo || qb>8 || B>1)
647          {
648             if (encode)
649                ec_enc_uint((ec_enc*)ec, itheta, (1<<qb)+1);
650             else
651                itheta = ec_dec_uint((ec_dec*)ec, (1<<qb)+1);
652             qalloc = log2_frac((1<<qb)+1,BITRES);
653          } else {
654             int fs=1, ft;
655             ft = ((1<<qb>>1)+1)*((1<<qb>>1)+1);
656             if (encode)
657             {
658                int fl;
659
660                fs = itheta <= (1<<qb>>1) ? itheta + 1 : (1<<qb) + 1 - itheta;
661                fl = itheta <= (1<<qb>>1) ? itheta*(itheta + 1)>>1 :
662                 ft - (((1<<qb) + 1 - itheta)*((1<<qb) + 2 - itheta)>>1);
663
664                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
665             } else {
666                int fl=0;
667                int fm;
668                fm = ec_decode((ec_dec*)ec, ft);
669
670                if (fm < ((1<<qb>>1)*((1<<qb>>1) + 1)>>1))
671                {
672                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
673                   fs = itheta + 1;
674                   fl = itheta*(itheta + 1)>>1;
675                }
676                else
677                {
678                   itheta = (2*((1<<qb) + 1)
679                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
680                   fs = (1<<qb) + 1 - itheta;
681                   fl = ft - (((1<<qb) + 1 - itheta)*((1<<qb) + 2 - itheta)>>1);
682                }
683
684                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
685             }
686             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
687          }
688          itheta <<= shift;
689       }
690
691       if (itheta == 0)
692       {
693          imid = 32767;
694          iside = 0;
695          delta = -10000;
696       } else if (itheta == 16384)
697       {
698          imid = 0;
699          iside = 32767;
700          delta = 10000;
701       } else {
702          imid = bitexact_cos(itheta);
703          iside = bitexact_cos(16384-itheta);
704          /* This is the mid vs side allocation that minimizes squared error
705             in that band. */
706          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
707       }
708
709       /* This is a special case for N=2 that only works for stereo and takes
710          advantage of the fact that mid and side are orthogonal to encode
711          the side with just one bit. */
712       if (N==2 && stereo)
713       {
714          int c, c2;
715          int sign=1;
716          celt_norm v[2], w[2];
717          celt_norm *x2, *y2;
718          mbits = b-qalloc;
719          sbits = 0;
720          /* Only need one bit for the side */
721          if (itheta != 0 && itheta != 16384)
722             sbits = 1<<BITRES;
723          mbits -= sbits;
724          c = itheta > 8192 ? 1 : 0;
725          *remaining_bits -= qalloc+sbits;
726
727          x2 = X;
728          y2 = Y;
729          if (encode)
730          {
731             c2 = 1-c;
732
733             /* v is the largest vector between mid and side. w is the other */
734             if (c==0)
735             {
736                v[0] = x2[0];
737                v[1] = x2[1];
738                w[0] = y2[0];
739                w[1] = y2[1];
740             } else {
741                v[0] = y2[0];
742                v[1] = y2[1];
743                w[0] = x2[0];
744                w[1] = x2[1];
745             }
746             /* Here we only need to encode a sign for the side */
747             if (v[0]*w[1] - v[1]*w[0] > 0)
748                sign = 1;
749             else
750                sign = -1;
751          }
752          quant_band(encode, m, i, v, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level+1, seed);
753          if (sbits)
754          {
755             if (encode)
756             {
757                ec_enc_bits((ec_enc*)ec, sign==1, 1);
758             } else {
759                sign = 2*ec_dec_bits((ec_dec*)ec, 1)-1;
760             }
761          } else {
762             sign = 1;
763          }
764          w[0] = -sign*v[1];
765          w[1] = sign*v[0];
766          if (c==0)
767          {
768             x2[0] = v[0];
769             x2[1] = v[1];
770             y2[0] = w[0];
771             y2[1] = w[1];
772          } else {
773             x2[0] = w[0];
774             x2[1] = w[1];
775             y2[0] = v[0];
776             y2[1] = v[1];
777          }
778       } else
779       {
780          /* "Normal" split code */
781          celt_norm *next_lowband2=NULL;
782          celt_norm *next_lowband_out1=NULL;
783          int next_level=0;
784
785          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
786          if (B>1 && !stereo)
787             delta >>= 1;
788
789          mbits = (b-qalloc-delta)/2;
790          if (mbits > b-qalloc)
791             mbits = b-qalloc;
792          if (mbits<0)
793             mbits=0;
794          sbits = b-qalloc-mbits;
795          *remaining_bits -= qalloc;
796
797          if (lowband && !stereo)
798             next_lowband2 = lowband+N; /* >32-bit split case */
799
800          /* Only stereo needs to pass on lowband_out. Otherwise, it's handled at the end */
801          if (stereo)
802             next_lowband_out1 = lowband_out;
803          else
804             next_level = level+1;
805
806          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, next_lowband_out1, NULL, next_level, seed);
807          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, tf_change, next_lowband2, resynth, ec, remaining_bits, LM, NULL, NULL, level, seed);
808       }
809
810    } else {
811       /* This is the basic no-split case */
812       q = bits2pulses(m, m->bits[LM][i], N, b);
813       curr_bits = pulses2bits(m->bits[LM][i], N, q);
814       *remaining_bits -= curr_bits;
815
816       /* Ensures we can never bust the budget */
817       while (*remaining_bits < 0 && q > 0)
818       {
819          *remaining_bits += curr_bits;
820          q--;
821          curr_bits = pulses2bits(m->bits[LM][i], N, q);
822          *remaining_bits -= curr_bits;
823       }
824
825       /* Finally do the actual quantization */
826       if (encode)
827          alg_quant(X, N, q, spread, B, lowband, resynth, (ec_enc*)ec, seed);
828       else
829          alg_unquant(X, N, q, spread, B, lowband, (ec_dec*)ec, seed);
830    }
831
832    /* This code is used by the decoder and by the resynthesis-enabled encoder */
833    if (resynth)
834    {
835       int k;
836
837       if (split)
838       {
839          int j;
840          celt_word16 mid, side;
841 #ifdef FIXED_POINT
842          mid = imid;
843          side = iside;
844 #else
845          mid = (1.f/32768)*imid;
846          side = (1.f/32768)*iside;
847 #endif
848          for (j=0;j<N;j++)
849             X[j] = MULT16_16_Q15(X[j], mid);
850          for (j=0;j<N;j++)
851             Y[j] = MULT16_16_Q15(Y[j], side);
852       }
853
854       /* Undo the sample reorganization going from time order to frequency order */
855       if (!stereo && B0>1 && level==0)
856       {
857          interleave_vector(X, N_B, B0);
858          if (lowband)
859             interleave_vector(lowband, N_B, B0);
860       }
861
862       /* Undo time-freq changes that we did earlier */
863       N_B = N_B0;
864       B = B0;
865       for (k=0;k<time_divide;k++)
866       {
867          B >>= 1;
868          N_B <<= 1;
869          haar1(X, N_B, B);
870          if (lowband)
871             haar1(lowband, N_B, B);
872       }
873
874       for (k=0;k<recombine;k++)
875       {
876          haar1(X, N_B, B);
877          if (lowband)
878             haar1(lowband, N_B, B);
879          N_B>>=1;
880          B <<= 1;
881       }
882
883       /* Scale output for later folding */
884       if (lowband_out && !stereo)
885       {
886          int j;
887          celt_word16 n;
888          n = celt_sqrt(SHL32(EXTEND32(N0),22));
889          for (j=0;j<N0;j++)
890             lowband_out[j] = MULT16_16_Q15(n,X[j]);
891       }
892
893       if (stereo)
894       {
895          stereo_band_mix(m, X, Y, bandE, 0, i, -1, N);
896          /* We only need to renormalize because quantization may not
897             have preserved orthogonality of mid and side */
898          renormalise_vector(X, Q15ONE, N, 1);
899          renormalise_vector(Y, Q15ONE, N, 1);
900       }
901    }
902 }
903
904 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM)
905 {
906    int i, balance;
907    celt_int32 remaining_bits;
908    const celt_int16 * restrict eBands = m->eBands;
909    celt_norm * restrict norm;
910    VARDECL(celt_norm, _norm);
911    int B;
912    int M;
913    int spread;
914    celt_int32 seed;
915    celt_norm *lowband;
916    int update_lowband = 1;
917    int C = _Y != NULL ? 2 : 1;
918    SAVE_STACK;
919
920    M = 1<<LM;
921    B = shortBlocks ? M : 1;
922    spread = fold ? B : 0;
923    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
924    norm = _norm;
925
926    if (encode)
927       seed = ((ec_enc*)ec)->rng;
928    else
929       seed = ((ec_dec*)ec)->rng;
930    balance = 0;
931    lowband = NULL;
932    for (i=start;i<end;i++)
933    {
934       int tell;
935       int b;
936       int N;
937       int curr_balance;
938       celt_norm * restrict X, * restrict Y;
939       int tf_change=0;
940       celt_norm *effective_lowband;
941       
942       X = _X+M*eBands[i];
943       if (_Y!=NULL)
944          Y = _Y+M*eBands[i];
945       else
946          Y = NULL;
947       N = M*eBands[i+1]-M*eBands[i];
948       if (encode)
949          tell = ec_enc_tell((ec_enc*)ec, BITRES);
950       else
951          tell = ec_dec_tell((ec_dec*)ec, BITRES);
952
953       /* Compute how many bits we want to allocate to this band */
954       if (i != start)
955          balance -= tell;
956       remaining_bits = (total_bits<<BITRES)-tell-1;
957       curr_balance = (end-i);
958       if (curr_balance > 3)
959          curr_balance = 3;
960       curr_balance = balance / curr_balance;
961       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
962       if (b<0)
963          b = 0;
964       /* Prevents ridiculous bit depths */
965       if (b > C*16*N<<BITRES)
966          b = C*16*N<<BITRES;
967
968       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
969             lowband = norm+M*eBands[i]-N;
970
971       tf_change = tf_res[i];
972       if (i>=m->effEBands)
973       {
974          X=norm;
975          if (_Y!=NULL)
976             Y = norm;
977       }
978
979       if (tf_change==0 && !shortBlocks && fold)
980          effective_lowband = NULL;
981       else
982          effective_lowband = lowband;
983       quant_band(encode, m, i, X, Y, N, b, fold, B, tf_change, effective_lowband, resynth, ec, &remaining_bits, LM, norm+M*eBands[i], bandE, 0, &seed);
984
985       balance += pulses[i] + tell;
986
987       /* Update the folding position only as long as we have 2 bit/sample depth */
988       update_lowband = (b>>BITRES)>2*N;
989    }
990    RESTORE_STACK;
991 }
992