Measuring the normalized error directly within the encoder
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    for (c=0;c<C;c++)
74    {
75       for (i=0;i<end;i++)
76       {
77          int j;
78          celt_word32 maxval=0;
79          celt_word32 sum = 0;
80          
81          j=M*eBands[i]; do {
82             maxval = MAX32(maxval, X[j+c*N]);
83             maxval = MAX32(maxval, -X[j+c*N]);
84          } while (++j<M*eBands[i+1]);
85          
86          if (maxval > 0)
87          {
88             int shift = celt_ilog2(maxval)-10;
89             j=M*eBands[i]; do {
90                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
91                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
92             } while (++j<M*eBands[i+1]);
93             /* We're adding one here to make damn sure we never end up with a pitch vector that's
94                larger than unity norm */
95             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
96          } else {
97             bank[i+c*m->nbEBands] = EPSILON;
98          }
99          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
100       }
101    }
102    /*printf ("\n");*/
103 }
104
105 /* Normalise each band such that the energy is one. */
106 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
107 {
108    int i, c, N;
109    const celt_int16 *eBands = m->eBands;
110    const int C = CHANNELS(_C);
111    N = M*m->shortMdctSize;
112    for (c=0;c<C;c++)
113    {
114       i=0; do {
115          celt_word16 g;
116          int j,shift;
117          celt_word16 E;
118          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
119          E = VSHR32(bank[i+c*m->nbEBands], shift);
120          g = EXTRACT16(celt_rcp(SHL32(E,3)));
121          j=M*eBands[i]; do {
122             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
123          } while (++j<M*eBands[i+1]);
124       } while (++i<end);
125    }
126 }
127
128 #else /* FIXED_POINT */
129 /* Compute the amplitude (sqrt energy) in each of the bands */
130 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
131 {
132    int i, c, N;
133    const celt_int16 *eBands = m->eBands;
134    const int C = CHANNELS(_C);
135    N = M*m->shortMdctSize;
136    for (c=0;c<C;c++)
137    {
138       for (i=0;i<end;i++)
139       {
140          int j;
141          celt_word32 sum = 1e-10;
142          for (j=M*eBands[i];j<M*eBands[i+1];j++)
143             sum += X[j+c*N]*X[j+c*N];
144          bank[i+c*m->nbEBands] = sqrt(sum);
145          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
146       }
147    }
148    /*printf ("\n");*/
149 }
150
151 /* Normalise each band such that the energy is one. */
152 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
153 {
154    int i, c, N;
155    const celt_int16 *eBands = m->eBands;
156    const int C = CHANNELS(_C);
157    N = M*m->shortMdctSize;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<end;i++)
161       {
162          int j;
163          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
164          for (j=M*eBands[i];j<M*eBands[i+1];j++)
165             X[j+c*N] = freq[j+c*N]*g;
166       }
167    }
168 }
169
170 #endif /* FIXED_POINT */
171
172 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
173 {
174    int i, c;
175    const celt_int16 *eBands = m->eBands;
176    const int C = CHANNELS(_C);
177    for (c=0;c<C;c++)
178    {
179       i=0; do {
180          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, Q15ONE, M*eBands[i+1]-M*eBands[i], 1);
181       } while (++i<end);
182    }
183 }
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    for (c=0;c<C;c++)
194    {
195       celt_sig * restrict f;
196       const celt_norm * restrict x;
197       f = freq+c*N;
198       x = X+c*N;
199       for (i=0;i<end;i++)
200       {
201          int j, band_end;
202          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
203          j=M*eBands[i];
204          band_end = M*eBands[i+1];
205          do {
206             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
207             x++;
208          } while (++j<band_end);
209       }
210       for (i=M*eBands[m->nbEBands];i<N;i++)
211          *f++ = 0;
212    }
213 }
214
215 int compute_pitch_gain(const CELTMode *m, const celt_sig *X, const celt_sig *P, int norm_rate, int *gain_id, int _C, celt_word16 *gain_prod, int M)
216 {
217    int j, c;
218    celt_word16 g;
219    celt_word16 delta;
220    const int C = CHANNELS(_C);
221    celt_word32 Sxy=0, Sxx=0, Syy=0;
222    int len = M*m->pitchEnd;
223    int N = M*m->shortMdctSize;
224 #ifdef FIXED_POINT
225    int shift = 0;
226    celt_word32 maxabs=0;
227
228    for (c=0;c<C;c++)
229    {
230       for (j=0;j<len;j++)
231       {
232          maxabs = MAX32(maxabs, ABS32(X[j+c*N]));
233          maxabs = MAX32(maxabs, ABS32(P[j+c*N]));
234       }
235    }
236    shift = celt_ilog2(maxabs)-12;
237    if (shift<0)
238       shift = 0;
239 #endif
240    delta = PDIV32_16(Q15ONE, len);
241    for (c=0;c<C;c++)
242    {
243       celt_word16 gg = Q15ONE;
244       for (j=0;j<len;j++)
245       {
246          celt_word16 Xj, Pj;
247          Xj = EXTRACT16(SHR32(X[j+c*N], shift));
248          Pj = MULT16_16_P15(gg,EXTRACT16(SHR32(P[j+c*N], shift)));
249          Sxy = MAC16_16(Sxy, Xj, Pj);
250          Sxx = MAC16_16(Sxx, Pj, Pj);
251          Syy = MAC16_16(Syy, Xj, Xj);
252          gg = SUB16(gg, delta);
253       }
254    }
255 #ifdef FIXED_POINT
256    {
257       celt_word32 num, den;
258       celt_word16 fact;
259       fact = MULT16_16(QCONST16(.04f, 14), norm_rate);
260       if (fact < QCONST16(1.f, 14))
261          fact = QCONST16(1.f, 14);
262       num = Sxy;
263       den = EPSILON+Sxx+MULT16_32_Q15(QCONST16(.03f,15),Syy);
264       shift = celt_zlog2(Sxy)-16;
265       if (shift < 0)
266          shift = 0;
267       if (Sxy < MULT16_32_Q15(fact, MULT16_16(celt_sqrt(EPSILON+Sxx),celt_sqrt(EPSILON+Syy))))
268          g = 0;
269       else
270          g = DIV32(SHL32(SHR32(num,shift),14),ADD32(EPSILON,SHR32(den,shift)));
271
272       /* This MUST round down so that we don't over-estimate the gain */
273       *gain_id = EXTRACT16(SHR32(MULT16_16(20,(g-QCONST16(.5f,14))),14));
274    }
275 #else
276    {
277       float fact = .04f*norm_rate;
278       if (fact < 1)
279          fact = 1;
280       g = Sxy/(.1f+Sxx+.03f*Syy);
281       if (Sxy < .5f*fact*celt_sqrt(1+Sxx*Syy))
282          g = 0;
283       /* This MUST round down so that we don't over-estimate the gain */
284       *gain_id = floor(20*(g-.5f));
285    }
286 #endif
287    /* This prevents the pitch gain from being above 1.0 for too long by bounding the 
288       maximum error amplification factor to 2.0 */
289    g = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),*gain_id));
290    *gain_prod = MAX16(QCONST32(1.f, 13), MULT16_16_Q14(*gain_prod,g));
291    if (*gain_prod>QCONST32(2.f, 13))
292    {
293       *gain_id=9;
294       *gain_prod = QCONST32(2.f, 13);
295    }
296
297    if (*gain_id < 0)
298    {
299       *gain_id = 0;
300       return 0;
301    } else {
302       if (*gain_id > 15)
303          *gain_id = 15;
304       return 1;
305    }
306 }
307
308 void apply_pitch(const CELTMode *m, celt_sig *X, const celt_sig *P, int gain_id, int pred, int _C, int M)
309 {
310    int j, c, N;
311    celt_word16 gain;
312    celt_word16 delta;
313    const int C = CHANNELS(_C);
314    int len = M*m->pitchEnd;
315
316    N = M*m->shortMdctSize;
317    gain = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),gain_id));
318    delta = PDIV32_16(gain, len);
319    if (pred)
320       gain = -gain;
321    else
322       delta = -delta;
323    for (c=0;c<C;c++)
324    {
325       celt_word16 gg = gain;
326       for (j=0;j<len;j++)
327       {
328          X[j+c*N] += SHL32(MULT16_32_Q15(gg,P[j+c*N]),1);
329          gg = ADD16(gg, delta);
330       }
331    }
332 }
333
334 static void stereo_band_mix(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int stereo_mode, int bandID, int dir, int N)
335 {
336    int i = bandID;
337    int j;
338    celt_word16 a1, a2;
339    if (stereo_mode==0)
340    {
341       /* Do mid-side when not doing intensity stereo */
342       a1 = QCONST16(.70711f,14);
343       a2 = dir*QCONST16(.70711f,14);
344    } else {
345       celt_word16 left, right;
346       celt_word16 norm;
347 #ifdef FIXED_POINT
348       int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
349 #endif
350       left = VSHR32(bank[i],shift);
351       right = VSHR32(bank[i+m->nbEBands],shift);
352       norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
353       a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
354       a2 = dir*DIV32_16(SHL32(EXTEND32(right),14),norm);
355    }
356    for (j=0;j<N;j++)
357    {
358       celt_norm r, l;
359       l = X[j];
360       r = Y[j];
361       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
362       Y[j] = MULT16_16_Q14(a1,r) - MULT16_16_Q14(a2,l);
363    }
364 }
365
366 /* Decide whether we should spread the pulses in the current frame */
367 int folding_decision(const CELTMode *m, celt_norm *X, celt_word16 *average, int *last_decision, int end, int _C, int M)
368 {
369    int i, c, N0;
370    int NR=0;
371    celt_word32 ratio = EPSILON;
372    const int C = CHANNELS(_C);
373    const celt_int16 * restrict eBands = m->eBands;
374    
375    N0 = M*m->shortMdctSize;
376
377    for (c=0;c<C;c++)
378    {
379    for (i=0;i<end;i++)
380    {
381       int j, N;
382       int max_i=0;
383       celt_word16 max_val=EPSILON;
384       celt_word32 floor_ener=EPSILON;
385       celt_norm * restrict x = X+M*eBands[i]+c*N0;
386       N = M*eBands[i+1]-M*eBands[i];
387       for (j=0;j<N;j++)
388       {
389          if (ABS16(x[j])>max_val)
390          {
391             max_val = ABS16(x[j]);
392             max_i = j;
393          }
394       }
395 #if 0
396       for (j=0;j<N;j++)
397       {
398          if (abs(j-max_i)>2)
399             floor_ener += x[j]*x[j];
400       }
401 #else
402       floor_ener = QCONST32(1.,28)-MULT16_16(max_val,max_val);
403       if (max_i < N-1)
404          floor_ener -= MULT16_16(x[(max_i+1)], x[(max_i+1)]);
405       if (max_i < N-2)
406          floor_ener -= MULT16_16(x[(max_i+2)], x[(max_i+2)]);
407       if (max_i > 0)
408          floor_ener -= MULT16_16(x[(max_i-1)], x[(max_i-1)]);
409       if (max_i > 1)
410          floor_ener -= MULT16_16(x[(max_i-2)], x[(max_i-2)]);
411       floor_ener = MAX32(floor_ener, EPSILON);
412 #endif
413       if (N>7)
414       {
415          celt_word16 r;
416          celt_word16 den = celt_sqrt(floor_ener);
417          den = MAX32(QCONST16(.02f, 15), den);
418          r = DIV32_16(SHL32(EXTEND32(max_val),8),den);
419          ratio = ADD32(ratio, EXTEND32(r));
420          NR++;
421       }
422    }
423    }
424    if (NR>0)
425       ratio = DIV32_16(ratio, NR);
426    ratio = ADD32(HALF32(ratio), HALF32(*average));
427    if (!*last_decision)
428    {
429       *last_decision = (ratio < QCONST16(1.8f,8));
430    } else {
431       *last_decision = (ratio < QCONST16(3.f,8));
432    }
433    *average = EXTRACT16(ratio);
434    return *last_decision;
435 }
436
437 #ifdef MEASURE_NORM_MSE
438
439 float MSE[30] = {0};
440 int nbMSEBands = 0;
441 int MSECount[30] = {0};
442
443 void dump_norm_mse(void)
444 {
445    int i;
446    for (i=0;i<nbMSEBands;i++)
447    {
448       printf ("%f ", MSE[i]/MSECount[i]);
449    }
450    printf ("\n");
451 }
452
453 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M)
454 {
455    static int init = 0;
456    int i;
457    if (!init)
458    {
459       atexit(dump_norm_mse);
460       init = 1;
461    }
462    for (i=0;i<m->nbEBands;i++)
463    {
464       int j;
465       float g = bandE[i]/(1e-15+bandE0[i]);
466       if (bandE0[i]<1)
467          continue;
468       for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
469          MSE[i] += (g*X[j]-X0[j])*(g*X[j]-X0[j]);
470       MSECount[i]++;
471    }
472    nbMSEBands = m->nbEBands;
473 }
474
475 #endif
476
477 static void interleave_vector(celt_norm *X, int N0, int stride)
478 {
479    int i,j;
480    VARDECL(celt_norm, tmp);
481    int N;
482    SAVE_STACK;
483    N = N0*stride;
484    ALLOC(tmp, N, celt_norm);
485    for (i=0;i<stride;i++)
486       for (j=0;j<N0;j++)
487          tmp[j*stride+i] = X[i*N0+j];
488    for (j=0;j<N;j++)
489       X[j] = tmp[j];
490    RESTORE_STACK;
491 }
492
493 static void deinterleave_vector(celt_norm *X, int N0, int stride)
494 {
495    int i,j;
496    VARDECL(celt_norm, tmp);
497    int N;
498    SAVE_STACK;
499    N = N0*stride;
500    ALLOC(tmp, N, celt_norm);
501    for (i=0;i<stride;i++)
502       for (j=0;j<N0;j++)
503          tmp[i*N0+j] = X[j*stride+i];
504    for (j=0;j<N;j++)
505       X[j] = tmp[j];
506    RESTORE_STACK;
507 }
508
509 static void haar1(celt_norm *X, int N0, int stride)
510 {
511    int i, j;
512    N0 >>= 1;
513    for (i=0;i<stride;i++)
514       for (j=0;j<N0;j++)
515       {
516          celt_norm tmp = X[stride*2*j+i];
517          X[stride*2*j+i] = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i] + X[stride*(2*j+1)+i]);
518          X[stride*(2*j+1)+i] = MULT16_16_Q15(QCONST16(.7070678f,15), tmp - X[stride*(2*j+1)+i]);
519       }
520 }
521
522 /* This function is responsible for encoding and decoding a band for both
523    the mono and stereo case. Even in the mono case, it can split the band
524    in two and transmit the energy difference with the two half-bands. It
525    can be called recursively so bands can end up being split in 8 parts. */
526 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
527       int N, int b, int spread, int B, int tf_change, celt_norm *lowband, int resynth, void *ec,
528       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level, celt_int32 *seed)
529 {
530    int q;
531    int curr_bits;
532    int stereo, split;
533    int imid=0, iside=0;
534    int N0=N;
535    int N_B=N;
536    int N_B0;
537    int B0=B;
538    int time_divide=0;
539    int recombine=0;
540
541    N_B /= B;
542    N_B0 = N_B;
543
544    split = stereo = Y != NULL;
545
546    /* Special case for one sample */
547    if (N==1)
548    {
549       int c;
550       celt_norm *x = X;
551       for (c=0;c<1+stereo;c++)
552       {
553          int sign=0;
554          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
555          {
556             if (encode)
557             {
558                sign = x[0]<0;
559                ec_enc_bits((ec_enc*)ec, sign, 1);
560             } else {
561                sign = ec_dec_bits((ec_dec*)ec, 1);
562             }
563             *remaining_bits -= 1<<BITRES;
564             b-=1<<BITRES;
565          }
566          if (resynth)
567             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
568          x = Y;
569       }
570       if (lowband_out)
571          lowband_out[0] = X[0];
572       return;
573    }
574
575    /* Band recombining to increase frequency resolution */
576    if (!stereo && B > 1 && level == 0 && tf_change>0)
577    {
578       while (B>1 && tf_change>0)
579       {
580          B>>=1;
581          N_B<<=1;
582          if (encode)
583             haar1(X, N_B, B);
584          if (lowband)
585             haar1(lowband, N_B, B);
586          recombine++;
587          tf_change--;
588       }
589       B0=B;
590       N_B0 = N_B;
591    }
592
593    /* Increasing the time resolution */
594    if (!stereo && level==0)
595    {
596       while ((N_B&1) == 0 && tf_change<0 && B <= (1<<LM))
597       {
598          if (encode)
599             haar1(X, N_B, B);
600          if (lowband)
601             haar1(lowband, N_B, B);
602          B <<= 1;
603          N_B >>= 1;
604          time_divide++;
605          tf_change++;
606       }
607       B0 = B;
608       N_B0 = N_B;
609    }
610
611    /* Reorganize the samples in time order instead of frequency order */
612    if (!stereo && B0>1 && level==0)
613    {
614       if (encode)
615          deinterleave_vector(X, N_B, B0);
616       if (lowband)
617          deinterleave_vector(lowband, N_B, B0);
618    }
619
620    /* If we need more than 32 bits, try splitting the band in two. */
621    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
622    {
623       if (LM>0 || (N&1)==0)
624       {
625          N >>= 1;
626          Y = X+N;
627          split = 1;
628          LM -= 1;
629          B = (B+1)>>1;
630       }
631    }
632
633    if (split)
634    {
635       int qb;
636       int itheta=0;
637       int mbits, sbits, delta;
638       int qalloc;
639       celt_word16 mid, side;
640       int offset, N2;
641       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
642
643       /* Decide on the resolution to give to the split parameter theta */
644       N2 = 2*N-1;
645       if (stereo && N==2)
646          N2--;
647       qb = (b+N2*offset)/(N2<<BITRES);
648       if (qb > (b>>(BITRES+1))-1)
649          qb = (b>>(BITRES+1))-1;
650
651       if (qb<0)
652          qb = 0;
653       if (qb>14)
654          qb = 14;
655
656       qalloc = 0;
657       if (qb!=0)
658       {
659          int shift;
660          shift = 14-qb;
661
662          if (encode)
663          {
664             if (stereo)
665                stereo_band_mix(m, X, Y, bandE, qb==0, i, 1, N);
666
667             mid = renormalise_vector(X, Q15ONE, N, 1);
668             side = renormalise_vector(Y, Q15ONE, N, 1);
669
670             /* theta is the atan() of the ratio between the (normalized)
671                side and mid. With just that parameter, we can re-scale both
672                mid and side because we know that 1) they have unit norm and
673                2) they are orthogonal. */
674    #ifdef FIXED_POINT
675             /* 0.63662 = 2/pi */
676             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
677    #else
678             itheta = floor(.5f+16384*0.63662f*atan2(side,mid));
679    #endif
680
681             itheta = (itheta+(1<<shift>>1))>>shift;
682          }
683
684          /* Entropy coding of the angle. We use a uniform pdf for the
685             first stereo split but a triangular one for the rest. */
686          if (stereo || qb>8 || B>1)
687          {
688             if (encode)
689                ec_enc_uint((ec_enc*)ec, itheta, (1<<qb)+1);
690             else
691                itheta = ec_dec_uint((ec_dec*)ec, (1<<qb)+1);
692             qalloc = log2_frac((1<<qb)+1,BITRES);
693          } else {
694             int fs=1, ft;
695             ft = ((1<<qb>>1)+1)*((1<<qb>>1)+1);
696             if (encode)
697             {
698                int fl;
699
700                fs = itheta <= (1<<qb>>1) ? itheta + 1 : (1<<qb) + 1 - itheta;
701                fl = itheta <= (1<<qb>>1) ? itheta*(itheta + 1)>>1 :
702                 ft - (((1<<qb) + 1 - itheta)*((1<<qb) + 2 - itheta)>>1);
703
704                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
705             } else {
706                int fl=0;
707                int fm;
708                fm = ec_decode((ec_dec*)ec, ft);
709
710                if (fm < ((1<<qb>>1)*((1<<qb>>1) + 1)>>1))
711                {
712                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
713                   fs = itheta + 1;
714                   fl = itheta*(itheta + 1)>>1;
715                }
716                else
717                {
718                   itheta = (2*((1<<qb) + 1)
719                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
720                   fs = (1<<qb) + 1 - itheta;
721                   fl = ft - (((1<<qb) + 1 - itheta)*((1<<qb) + 2 - itheta)>>1);
722                }
723
724                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
725             }
726             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
727          }
728          itheta <<= shift;
729       }
730
731       if (itheta == 0)
732       {
733          imid = 32767;
734          iside = 0;
735          delta = -10000;
736       } else if (itheta == 16384)
737       {
738          imid = 0;
739          iside = 32767;
740          delta = 10000;
741       } else {
742          imid = bitexact_cos(itheta);
743          iside = bitexact_cos(16384-itheta);
744          /* This is the mid vs side allocation that minimizes squared error
745             in that band. */
746          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
747       }
748
749       /* This is a special case for N=2 that only works for stereo and takes
750          advantage of the fact that mid and side are orthogonal to encode
751          the side with just one bit. */
752       if (N==2 && stereo)
753       {
754          int c, c2;
755          int sign=1;
756          celt_norm v[2], w[2];
757          celt_norm *x2, *y2;
758          mbits = b-qalloc;
759          sbits = 0;
760          /* Only need one bit for the side */
761          if (itheta != 0 && itheta != 16384)
762             sbits = 1<<BITRES;
763          mbits -= sbits;
764          c = itheta > 8192 ? 1 : 0;
765          *remaining_bits -= qalloc+sbits;
766
767          x2 = X;
768          y2 = Y;
769          if (encode)
770          {
771             c2 = 1-c;
772
773             /* v is the largest vector between mid and side. w is the other */
774             if (c==0)
775             {
776                v[0] = x2[0];
777                v[1] = x2[1];
778                w[0] = y2[0];
779                w[1] = y2[1];
780             } else {
781                v[0] = y2[0];
782                v[1] = y2[1];
783                w[0] = x2[0];
784                w[1] = x2[1];
785             }
786             /* Here we only need to encode a sign for the side */
787             if (v[0]*w[1] - v[1]*w[0] > 0)
788                sign = 1;
789             else
790                sign = -1;
791          }
792          quant_band(encode, m, i, v, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level+1, seed);
793          if (sbits)
794          {
795             if (encode)
796             {
797                ec_enc_bits((ec_enc*)ec, sign==1, 1);
798             } else {
799                sign = 2*ec_dec_bits((ec_dec*)ec, 1)-1;
800             }
801          } else {
802             sign = 1;
803          }
804          w[0] = -sign*v[1];
805          w[1] = sign*v[0];
806          if (c==0)
807          {
808             x2[0] = v[0];
809             x2[1] = v[1];
810             y2[0] = w[0];
811             y2[1] = w[1];
812          } else {
813             x2[0] = w[0];
814             x2[1] = w[1];
815             y2[0] = v[0];
816             y2[1] = v[1];
817          }
818       } else
819       {
820          /* "Normal" split code */
821          celt_norm *next_lowband2=NULL;
822          celt_norm *next_lowband_out1=NULL;
823          int next_level=0;
824
825          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
826          if (B>1 && !stereo)
827             delta >>= 1;
828
829          mbits = (b-qalloc-delta)/2;
830          if (mbits > b-qalloc)
831             mbits = b-qalloc;
832          if (mbits<0)
833             mbits=0;
834          sbits = b-qalloc-mbits;
835          *remaining_bits -= qalloc;
836
837          if (lowband && !stereo)
838             next_lowband2 = lowband+N; /* >32-bit split case */
839
840          /* Only stereo needs to pass on lowband_out. Otherwise, it's handled at the end */
841          if (stereo)
842             next_lowband_out1 = lowband_out;
843          else
844             next_level = level+1;
845
846          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, next_lowband_out1, NULL, next_level, seed);
847          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, tf_change, next_lowband2, resynth, ec, remaining_bits, LM, NULL, NULL, level, seed);
848       }
849
850    } else {
851       /* This is the basic no-split case */
852       q = bits2pulses(m, m->bits[LM][i], N, b);
853       curr_bits = pulses2bits(m->bits[LM][i], N, q);
854       *remaining_bits -= curr_bits;
855
856       /* Ensures we can never bust the budget */
857       while (*remaining_bits < 0 && q > 0)
858       {
859          *remaining_bits += curr_bits;
860          q--;
861          curr_bits = pulses2bits(m->bits[LM][i], N, q);
862          *remaining_bits -= curr_bits;
863       }
864
865       /* Finally do the actual quantization */
866       if (encode)
867          alg_quant(X, N, q, spread, B, lowband, resynth, (ec_enc*)ec, seed);
868       else
869          alg_unquant(X, N, q, spread, B, lowband, (ec_dec*)ec, seed);
870    }
871
872    /* This code is used by the decoder and by the resynthesis-enabled encoder */
873    if (resynth)
874    {
875       int k;
876
877       if (split)
878       {
879          int j;
880          celt_word16 mid, side;
881 #ifdef FIXED_POINT
882          mid = imid;
883          side = iside;
884 #else
885          mid = (1.f/32768)*imid;
886          side = (1.f/32768)*iside;
887 #endif
888          for (j=0;j<N;j++)
889             X[j] = MULT16_16_Q15(X[j], mid);
890          for (j=0;j<N;j++)
891             Y[j] = MULT16_16_Q15(Y[j], side);
892       }
893
894       /* Undo the sample reorganization going from time order to frequency order */
895       if (!stereo && B0>1 && level==0)
896       {
897          interleave_vector(X, N_B, B0);
898          if (lowband)
899             interleave_vector(lowband, N_B, B0);
900       }
901
902       /* Undo time-freq changes that we did earlier */
903       N_B = N_B0;
904       B = B0;
905       for (k=0;k<time_divide;k++)
906       {
907          B >>= 1;
908          N_B <<= 1;
909          haar1(X, N_B, B);
910          if (lowband)
911             haar1(lowband, N_B, B);
912       }
913
914       for (k=0;k<recombine;k++)
915       {
916          haar1(X, N_B, B);
917          if (lowband)
918             haar1(lowband, N_B, B);
919          N_B>>=1;
920          B <<= 1;
921       }
922
923       /* Scale output for later folding */
924       if (lowband_out && !stereo)
925       {
926          int j;
927          celt_word16 n;
928          n = celt_sqrt(SHL32(EXTEND32(N0),22));
929          for (j=0;j<N0;j++)
930             lowband_out[j] = MULT16_16_Q15(n,X[j]);
931       }
932
933       if (stereo)
934       {
935          stereo_band_mix(m, X, Y, bandE, 0, i, -1, N);
936          /* We only need to renormalize because quantization may not
937             have preserved orthogonality of mid and side */
938          renormalise_vector(X, Q15ONE, N, 1);
939          renormalise_vector(Y, Q15ONE, N, 1);
940       }
941    }
942 }
943
944 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM)
945 {
946    int i, balance;
947    celt_int32 remaining_bits;
948    const celt_int16 * restrict eBands = m->eBands;
949    celt_norm * restrict norm;
950    VARDECL(celt_norm, _norm);
951    int B;
952    int M;
953    int spread;
954    celt_int32 seed;
955    celt_norm *lowband;
956    int update_lowband = 1;
957    int C = _Y != NULL ? 2 : 1;
958    SAVE_STACK;
959
960    M = 1<<LM;
961    B = shortBlocks ? M : 1;
962    spread = fold ? B : 0;
963    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
964    norm = _norm;
965
966    if (encode)
967       seed = ((ec_enc*)ec)->rng;
968    else
969       seed = ((ec_dec*)ec)->rng;
970    balance = 0;
971    lowband = NULL;
972    for (i=start;i<end;i++)
973    {
974       int tell;
975       int b;
976       int N;
977       int curr_balance;
978       celt_norm * restrict X, * restrict Y;
979       int tf_change=0;
980       celt_norm *effective_lowband;
981       
982       X = _X+M*eBands[i];
983       if (_Y!=NULL)
984          Y = _Y+M*eBands[i];
985       else
986          Y = NULL;
987       N = M*eBands[i+1]-M*eBands[i];
988       if (encode)
989          tell = ec_enc_tell((ec_enc*)ec, BITRES);
990       else
991          tell = ec_dec_tell((ec_dec*)ec, BITRES);
992
993       /* Compute how many bits we want to allocate to this band */
994       if (i != start)
995          balance -= tell;
996       remaining_bits = (total_bits<<BITRES)-tell-1;
997       curr_balance = (end-i);
998       if (curr_balance > 3)
999          curr_balance = 3;
1000       curr_balance = balance / curr_balance;
1001       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
1002       if (b<0)
1003          b = 0;
1004       /* Prevents ridiculous bit depths */
1005       if (b > C*16*N<<BITRES)
1006          b = C*16*N<<BITRES;
1007
1008       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
1009             lowband = norm+M*eBands[i]-N;
1010
1011       tf_change = tf_res[i];
1012       if (i>=m->effEBands)
1013       {
1014          X=norm;
1015          if (_Y!=NULL)
1016             Y = norm;
1017       }
1018
1019       if (tf_change==0 && !shortBlocks && fold)
1020          effective_lowband = NULL;
1021       else
1022          effective_lowband = lowband;
1023       quant_band(encode, m, i, X, Y, N, b, fold, B, tf_change, effective_lowband, resynth, ec, &remaining_bits, LM, norm+M*eBands[i], bandE, 0, &seed);
1024
1025       balance += pulses[i] + tell;
1026
1027       /* Update the folding position only as long as we have 2 bit/sample depth */
1028       update_lowband = (b>>BITRES)>2*N;
1029    }
1030    RESTORE_STACK;
1031 }
1032