Fix stereo for N=2
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    for (c=0;c<C;c++)
74    {
75       for (i=0;i<end;i++)
76       {
77          int j;
78          celt_word32 maxval=0;
79          celt_word32 sum = 0;
80          
81          j=M*eBands[i]; do {
82             maxval = MAX32(maxval, X[j+c*N]);
83             maxval = MAX32(maxval, -X[j+c*N]);
84          } while (++j<M*eBands[i+1]);
85          
86          if (maxval > 0)
87          {
88             int shift = celt_ilog2(maxval)-10;
89             j=M*eBands[i]; do {
90                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
91                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
92             } while (++j<M*eBands[i+1]);
93             /* We're adding one here to make damn sure we never end up with a pitch vector that's
94                larger than unity norm */
95             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
96          } else {
97             bank[i+c*m->nbEBands] = EPSILON;
98          }
99          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
100       }
101    }
102    /*printf ("\n");*/
103 }
104
105 /* Normalise each band such that the energy is one. */
106 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
107 {
108    int i, c, N;
109    const celt_int16 *eBands = m->eBands;
110    const int C = CHANNELS(_C);
111    N = M*m->shortMdctSize;
112    for (c=0;c<C;c++)
113    {
114       i=0; do {
115          celt_word16 g;
116          int j,shift;
117          celt_word16 E;
118          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
119          E = VSHR32(bank[i+c*m->nbEBands], shift);
120          g = EXTRACT16(celt_rcp(SHL32(E,3)));
121          j=M*eBands[i]; do {
122             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
123          } while (++j<M*eBands[i+1]);
124       } while (++i<end);
125    }
126 }
127
128 #else /* FIXED_POINT */
129 /* Compute the amplitude (sqrt energy) in each of the bands */
130 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
131 {
132    int i, c, N;
133    const celt_int16 *eBands = m->eBands;
134    const int C = CHANNELS(_C);
135    N = M*m->shortMdctSize;
136    for (c=0;c<C;c++)
137    {
138       for (i=0;i<end;i++)
139       {
140          int j;
141          celt_word32 sum = 1e-10f;
142          for (j=M*eBands[i];j<M*eBands[i+1];j++)
143             sum += X[j+c*N]*X[j+c*N];
144          bank[i+c*m->nbEBands] = celt_sqrt(sum);
145          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
146       }
147    }
148    /*printf ("\n");*/
149 }
150
151 /* Normalise each band such that the energy is one. */
152 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
153 {
154    int i, c, N;
155    const celt_int16 *eBands = m->eBands;
156    const int C = CHANNELS(_C);
157    N = M*m->shortMdctSize;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<end;i++)
161       {
162          int j;
163          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
164          for (j=M*eBands[i];j<M*eBands[i+1];j++)
165             X[j+c*N] = freq[j+c*N]*g;
166       }
167    }
168 }
169
170 #endif /* FIXED_POINT */
171
172 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
173 {
174    int i, c;
175    const celt_int16 *eBands = m->eBands;
176    const int C = CHANNELS(_C);
177    for (c=0;c<C;c++)
178    {
179       i=0; do {
180          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, M*eBands[i+1]-M*eBands[i], Q15ONE);
181       } while (++i<end);
182    }
183 }
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    for (c=0;c<C;c++)
194    {
195       celt_sig * restrict f;
196       const celt_norm * restrict x;
197       f = freq+c*N;
198       x = X+c*N;
199       for (i=0;i<end;i++)
200       {
201          int j, band_end;
202          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
203          j=M*eBands[i];
204          band_end = M*eBands[i+1];
205          do {
206             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
207             x++;
208          } while (++j<band_end);
209       }
210       for (i=M*eBands[m->nbEBands];i<N;i++)
211          *f++ = 0;
212    }
213 }
214
215 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
216 {
217    int i = bandID;
218    int j;
219    celt_word16 a1, a2;
220    celt_word16 left, right;
221    celt_word16 norm;
222 #ifdef FIXED_POINT
223    int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
224 #endif
225    left = VSHR32(bank[i],shift);
226    right = VSHR32(bank[i+m->nbEBands],shift);
227    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
228    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
229    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
230    for (j=0;j<N;j++)
231    {
232       celt_norm r, l;
233       l = X[j];
234       r = Y[j];
235       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
236       /* Side is not encoded, no need to calculate */
237    }
238 }
239
240 static void stereo_split(celt_norm *X, celt_norm *Y, int N)
241 {
242    int j;
243    for (j=0;j<N;j++)
244    {
245       celt_norm r, l;
246       l = MULT16_16_Q15(QCONST16(.70711f,15), X[j]);
247       r = MULT16_16_Q15(QCONST16(.70711f,15), Y[j]);
248       X[j] = l+r;
249       Y[j] = r-l;
250    }
251 }
252
253 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, celt_word16 side, int N)
254 {
255    int j;
256    celt_word32 xp=0;
257    celt_word32 El, Er;
258 #ifdef FIXED_POINT
259    int kl, kr;
260 #endif
261    celt_word32 t, lgain, rgain;
262
263    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
264    for (j=0;j<N;j++)
265       xp = MAC16_16(xp, X[j], Y[j]);
266    /* mid and side are in Q15, not Q14 like X and Y */
267    El = MULT16_16(mid, mid) + MULT16_16(side, side) - 2*SHL32(xp,2);
268    Er = MULT16_16(mid, mid) + MULT16_16(side, side) + 2*SHL32(xp,2);
269    if (Er < EPSILON)
270       Er = EPSILON;
271    if (El < EPSILON)
272       El = EPSILON;
273
274 #ifdef FIXED_POINT
275    kl = celt_ilog2(El)>>1;
276    kr = celt_ilog2(Er)>>1;
277 #endif
278    t = VSHR32(El, (kl-7)<<1);
279    lgain = celt_rsqrt_norm(t);
280    t = VSHR32(Er, (kr-7)<<1);
281    rgain = celt_rsqrt_norm(t);
282
283    for (j=0;j<N;j++)
284    {
285       celt_norm r, l;
286       l = X[j];
287       r = Y[j];
288       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, l-r), kl));
289       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, l+r), kr));
290    }
291 }
292
293 /* Decide whether we should spread the pulses in the current frame */
294 int folding_decision(const CELTMode *m, celt_norm *X, int *average, int *last_decision, int end, int _C, int M)
295 {
296    int i, c, N0;
297    int sum = 0, nbBands=0;
298    const int C = CHANNELS(_C);
299    const celt_int16 * restrict eBands = m->eBands;
300    int decision;
301    
302    N0 = M*m->shortMdctSize;
303
304    if (M*(eBands[end]-eBands[end-1]) <= 8)
305       return 0;
306    for (c=0;c<C;c++)
307    {
308       for (i=0;i<end;i++)
309       {
310          int j, N, tmp=0;
311          int tcount[3] = {0};
312          celt_norm * restrict x = X+M*eBands[i]+c*N0;
313          N = M*(eBands[i+1]-eBands[i]);
314          if (N<=8)
315             continue;
316          /* Compute rough CDF of |x[j]| */
317          for (j=0;j<N;j++)
318          {
319             celt_word32 x2N; /* Q13 */
320
321             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
322             if (x2N < QCONST16(0.25f,13))
323                tcount[0]++;
324             if (x2N < QCONST16(0.0625f,13))
325                tcount[1]++;
326             if (x2N < QCONST16(0.015625f,13))
327                tcount[2]++;
328          }
329
330          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
331          sum += tmp*256;
332          nbBands++;
333       }
334    }
335    sum /= nbBands;
336    /* Recursive averaging */
337    sum = (sum+*average)>>1;
338    *average = sum;
339    /* Hysteresis */
340    sum = (3*sum + ((*last_decision<<7) + 64) + 2)>>2;
341    /* decision and last_decision do not use the same encoding */
342    if (sum < 128)
343    {
344       decision = 2;
345       *last_decision = 0;
346    } else if (sum < 256)
347    {
348       decision = 1;
349       *last_decision = 1;
350    } else if (sum < 384)
351    {
352       decision = 3;
353       *last_decision = 2;
354    } else {
355       decision = 0;
356       *last_decision = 3;
357    }
358    return decision;
359 }
360
361 #ifdef MEASURE_NORM_MSE
362
363 float MSE[30] = {0};
364 int nbMSEBands = 0;
365 int MSECount[30] = {0};
366
367 void dump_norm_mse(void)
368 {
369    int i;
370    for (i=0;i<nbMSEBands;i++)
371    {
372       printf ("%g ", MSE[i]/MSECount[i]);
373    }
374    printf ("\n");
375 }
376
377 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
378 {
379    static int init = 0;
380    int i;
381    if (!init)
382    {
383       atexit(dump_norm_mse);
384       init = 1;
385    }
386    for (i=0;i<m->nbEBands;i++)
387    {
388       int j;
389       int c;
390       float g;
391       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
392          continue;
393       for (c=0;c<C;c++)
394       {
395          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
396          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
397             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
398       }
399       MSECount[i]+=C;
400    }
401    nbMSEBands = m->nbEBands;
402 }
403
404 #endif
405
406 static void interleave_vector(celt_norm *X, int N0, int stride)
407 {
408    int i,j;
409    VARDECL(celt_norm, tmp);
410    int N;
411    SAVE_STACK;
412    N = N0*stride;
413    ALLOC(tmp, N, celt_norm);
414    for (i=0;i<stride;i++)
415       for (j=0;j<N0;j++)
416          tmp[j*stride+i] = X[i*N0+j];
417    for (j=0;j<N;j++)
418       X[j] = tmp[j];
419    RESTORE_STACK;
420 }
421
422 static void deinterleave_vector(celt_norm *X, int N0, int stride)
423 {
424    int i,j;
425    VARDECL(celt_norm, tmp);
426    int N;
427    SAVE_STACK;
428    N = N0*stride;
429    ALLOC(tmp, N, celt_norm);
430    for (i=0;i<stride;i++)
431       for (j=0;j<N0;j++)
432          tmp[i*N0+j] = X[j*stride+i];
433    for (j=0;j<N;j++)
434       X[j] = tmp[j];
435    RESTORE_STACK;
436 }
437
438 static void haar1(celt_norm *X, int N0, int stride)
439 {
440    int i, j;
441    N0 >>= 1;
442    for (i=0;i<stride;i++)
443       for (j=0;j<N0;j++)
444       {
445          celt_norm tmp = X[stride*2*j+i];
446          X[stride*2*j+i] = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i] + X[stride*(2*j+1)+i]);
447          X[stride*(2*j+1)+i] = MULT16_16_Q15(QCONST16(.7070678f,15), tmp - X[stride*(2*j+1)+i]);
448       }
449 }
450
451 static int compute_qn(int N, int b, int offset, int stereo)
452 {
453    static const celt_int16 exp2_table8[8] =
454       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
455    int qn, qb;
456    int N2 = 2*N-1;
457    if (stereo && N==2)
458       N2--;
459    qb = (b+N2*offset)/N2;
460    if (qb > (b>>1)-(1<<BITRES))
461       qb = (b>>1)-(1<<BITRES);
462
463    if (qb<0)
464        qb = 0;
465    if (qb>8<<BITRES)
466      qb = 8<<BITRES;
467
468    if (qb<(1<<BITRES>>1)) {
469       qn = 1;
470    } else {
471       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
472       qn = (qn+1)>>1<<1;
473    }
474    celt_assert(qn <= 256);
475    return qn;
476 }
477
478
479 /* This function is responsible for encoding and decoding a band for both
480    the mono and stereo case. Even in the mono case, it can split the band
481    in two and transmit the energy difference with the two half-bands. It
482    can be called recursively so bands can end up being split in 8 parts. */
483 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
484       int N, int b, int spread, int B, int tf_change, celt_norm *lowband, int resynth, void *ec,
485       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
486       celt_int32 *seed, celt_word16 gain, celt_norm *lowband_scratch)
487 {
488    int q;
489    int curr_bits;
490    int stereo, split;
491    int imid=0, iside=0;
492    int N0=N;
493    int N_B=N;
494    int N_B0;
495    int B0=B;
496    int time_divide=0;
497    int recombine=0;
498    celt_word16 mid=0, side=0;
499
500    N_B /= B;
501    N_B0 = N_B;
502
503    split = stereo = Y != NULL;
504
505    /* Special case for one sample */
506    if (N==1)
507    {
508       int c;
509       celt_norm *x = X;
510       for (c=0;c<1+stereo;c++)
511       {
512          int sign=0;
513          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
514          {
515             if (encode)
516             {
517                sign = x[0]<0;
518                ec_enc_bits((ec_enc*)ec, sign, 1);
519             } else {
520                sign = ec_dec_bits((ec_dec*)ec, 1);
521             }
522             *remaining_bits -= 1<<BITRES;
523             b-=1<<BITRES;
524          }
525          if (resynth)
526             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
527          x = Y;
528       }
529       if (lowband_out)
530          lowband_out[0] = X[0];
531       return;
532    }
533
534    if (!stereo && level == 0)
535    {
536       int k;
537       if (tf_change>0)
538          recombine = tf_change;
539       /* Band recombining to increase frequency resolution */
540
541       if (lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
542       {
543          int j;
544          for (j=0;j<N;j++)
545             lowband_scratch[j] = lowband[j];
546          lowband = lowband_scratch;
547       }
548
549       for (k=0;k<recombine;k++)
550       {
551          B>>=1;
552          N_B<<=1;
553          if (encode)
554             haar1(X, N_B, B);
555          if (lowband)
556             haar1(lowband, N_B, B);
557       }
558
559       /* Increasing the time resolution */
560       while ((N_B&1) == 0 && tf_change<0)
561       {
562          if (encode)
563             haar1(X, N_B, B);
564          if (lowband)
565             haar1(lowband, N_B, B);
566          B <<= 1;
567          N_B >>= 1;
568          time_divide++;
569          tf_change++;
570       }
571       B0=B;
572       N_B0 = N_B;
573
574       /* Reorganize the samples in time order instead of frequency order */
575       if (B0>1)
576       {
577          if (encode)
578             deinterleave_vector(X, N_B, B0);
579          if (lowband)
580             deinterleave_vector(lowband, N_B, B0);
581       }
582    }
583
584    /* If we need more than 32 bits, try splitting the band in two. */
585    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
586    {
587       if (LM>0 || (N&1)==0)
588       {
589          N >>= 1;
590          Y = X+N;
591          split = 1;
592          LM -= 1;
593          B = (B+1)>>1;
594       }
595    }
596
597    if (split)
598    {
599       int qn;
600       int itheta=0;
601       int mbits, sbits, delta;
602       int qalloc;
603       int offset;
604
605       /* Decide on the resolution to give to the split parameter theta */
606       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
607       qn = compute_qn(N, b, offset, stereo);
608
609       qalloc = 0;
610       if (qn!=1)
611       {
612          if (encode)
613          {
614             if (stereo)
615                stereo_split(X, Y, N);
616
617             mid = vector_norm(X, N);
618             side = vector_norm(Y, N);
619             /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
620                      Let's do that at higher complexity */
621             /*mid = renormalise_vector(X, Q15ONE, N, 1);
622             side = renormalise_vector(Y, Q15ONE, N, 1);*/
623
624             /* theta is the atan() of the ratio between the (normalized)
625                side and mid. With just that parameter, we can re-scale both
626                mid and side because we know that 1) they have unit norm and
627                2) they are orthogonal. */
628    #ifdef FIXED_POINT
629             /* 0.63662 = 2/pi */
630             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
631    #else
632             itheta = (int)floor(.5f+16384*0.63662f*atan2(side,mid));
633    #endif
634
635             itheta = (itheta*qn+8192)>>14;
636          }
637
638          /* Entropy coding of the angle. We use a uniform pdf for the
639             first stereo split but a triangular one for the rest. */
640          if (stereo || B>1)
641          {
642             if (encode)
643                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
644             else
645                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
646             qalloc = log2_frac(qn+1,BITRES);
647          } else {
648             int fs=1, ft;
649             ft = ((qn>>1)+1)*((qn>>1)+1);
650             if (encode)
651             {
652                int fl;
653
654                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
655                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
656                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
657
658                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
659             } else {
660                int fl=0;
661                int fm;
662                fm = ec_decode((ec_dec*)ec, ft);
663
664                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
665                {
666                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
667                   fs = itheta + 1;
668                   fl = itheta*(itheta + 1)>>1;
669                }
670                else
671                {
672                   itheta = (2*(qn + 1)
673                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
674                   fs = qn + 1 - itheta;
675                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
676                }
677
678                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
679             }
680             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
681          }
682          itheta = (celt_int32)itheta*16384/qn;
683       } else {
684          if (stereo && encode)
685             intensity_stereo(m, X, Y, bandE, i, N);
686       }
687
688       if (itheta == 0)
689       {
690          imid = 32767;
691          iside = 0;
692          delta = -10000;
693       } else if (itheta == 16384)
694       {
695          imid = 0;
696          iside = 32767;
697          delta = 10000;
698       } else {
699          imid = bitexact_cos(itheta);
700          iside = bitexact_cos(16384-itheta);
701          /* This is the mid vs side allocation that minimizes squared error
702             in that band. */
703          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
704       }
705
706 #ifdef FIXED_POINT
707       mid = imid;
708       side = iside;
709 #else
710       mid = (1.f/32768)*imid;
711       side = (1.f/32768)*iside;
712 #endif
713
714       /* This is a special case for N=2 that only works for stereo and takes
715          advantage of the fact that mid and side are orthogonal to encode
716          the side with just one bit. */
717       if (N==2 && stereo)
718       {
719          int c;
720          int sign=1;
721          celt_norm *x2, *y2;
722          mbits = b-qalloc;
723          sbits = 0;
724          /* Only need one bit for the side */
725          if (itheta != 0 && itheta != 16384)
726             sbits = 1<<BITRES;
727          mbits -= sbits;
728          c = itheta > 8192;
729          *remaining_bits -= qalloc+sbits;
730
731          x2 = c ? Y : X;
732          y2 = c ? X : Y;
733          if (sbits)
734          {
735             if (encode)
736             {
737                /* Here we only need to encode a sign for the side */
738                sign = x2[0]*y2[1] - x2[1]*y2[0] > 0;
739                ec_enc_bits((ec_enc*)ec, sign, 1);
740             } else {
741                sign = ec_dec_bits((ec_dec*)ec, 1);
742             }
743          }
744          sign = 2*sign - 1;
745          quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level+1, seed, gain, lowband_scratch);
746          y2[0] = -sign*x2[1];
747          y2[1] = sign*x2[0];
748          if (resynth)
749          {
750             celt_norm tmp;
751             X[0] = MULT16_16_Q15(mid, X[0]);
752             X[1] = MULT16_16_Q15(mid, X[1]);
753             Y[0] = MULT16_16_Q15(side, Y[0]);
754             Y[1] = MULT16_16_Q15(side, Y[1]);
755             tmp = X[0];
756             X[0] = SUB16(tmp,Y[0]);
757             Y[0] = ADD16(tmp,Y[0]);
758             tmp = X[1];
759             X[1] = SUB16(tmp,Y[1]);
760             Y[1] = ADD16(tmp,Y[1]);
761          }
762       } else {
763          /* "Normal" split code */
764          celt_norm *next_lowband2=NULL;
765          celt_norm *next_lowband_out1=NULL;
766          int next_level=0;
767
768          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
769          if (B>1 && !stereo)
770             delta >>= 1;
771
772          mbits = (b-qalloc-delta)/2;
773          if (mbits > b-qalloc)
774             mbits = b-qalloc;
775          if (mbits<0)
776             mbits=0;
777          sbits = b-qalloc-mbits;
778          *remaining_bits -= qalloc;
779
780          if (lowband && !stereo)
781             next_lowband2 = lowband+N; /* >32-bit split case */
782
783          /* Only stereo needs to pass on lowband_out. Otherwise, it's
784             handled at the end */
785          if (stereo)
786             next_lowband_out1 = lowband_out;
787          else
788             next_level = level+1;
789
790          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, tf_change,
791                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
792                NULL, next_level, seed, MULT16_16_P15(gain,mid), lowband_scratch);
793          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, tf_change,
794                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
795                NULL, next_level, seed, MULT16_16_P15(gain,side), NULL);
796       }
797
798    } else {
799       /* This is the basic no-split case */
800       q = bits2pulses(m, i, LM, b);
801       curr_bits = pulses2bits(m, i, LM, q);
802       *remaining_bits -= curr_bits;
803
804       /* Ensures we can never bust the budget */
805       while (*remaining_bits < 0 && q > 0)
806       {
807          *remaining_bits += curr_bits;
808          q--;
809          curr_bits = pulses2bits(m, i, LM, q);
810          *remaining_bits -= curr_bits;
811       }
812
813       /* Finally do the actual quantization */
814       if (encode)
815          alg_quant(X, N, q, spread, B, lowband, resynth, (ec_enc*)ec, seed, gain);
816       else
817          alg_unquant(X, N, q, spread, B, lowband, (ec_dec*)ec, seed, gain);
818    }
819
820    /* This code is used by the decoder and by the resynthesis-enabled encoder */
821    if (resynth)
822    {
823       if (stereo)
824       {
825          if (N!=2)
826             stereo_merge(X, Y, mid, side, N);
827       } else if (level == 0)
828       {
829          int k;
830
831          /* Undo the sample reorganization going from time order to frequency order */
832          if (B0>1)
833             interleave_vector(X, N_B, B0);
834
835          /* Undo time-freq changes that we did earlier */
836          N_B = N_B0;
837          B = B0;
838          for (k=0;k<time_divide;k++)
839          {
840             B >>= 1;
841             N_B <<= 1;
842             haar1(X, N_B, B);
843          }
844
845          for (k=0;k<recombine;k++)
846          {
847             haar1(X, N_B, B);
848             N_B>>=1;
849             B <<= 1;
850          }
851
852          /* Scale output for later folding */
853          if (lowband_out)
854          {
855             int j;
856             celt_word16 n;
857             n = celt_sqrt(SHL32(EXTEND32(N0),22));
858             for (j=0;j<N0;j++)
859                lowband_out[j] = MULT16_16_Q15(n,X[j]);
860          }
861       }
862    }
863 }
864
865 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM)
866 {
867    int i, balance;
868    celt_int32 remaining_bits;
869    const celt_int16 * restrict eBands = m->eBands;
870    celt_norm * restrict norm;
871    VARDECL(celt_norm, _norm);
872    VARDECL(celt_norm, lowband_scratch);
873    int B;
874    int M;
875    celt_int32 seed;
876    celt_norm *lowband;
877    int update_lowband = 1;
878    int C = _Y != NULL ? 2 : 1;
879    SAVE_STACK;
880
881    M = 1<<LM;
882    B = shortBlocks ? M : 1;
883    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
884    ALLOC(lowband_scratch, M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]), celt_norm);
885    norm = _norm;
886
887    if (encode)
888       seed = ((ec_enc*)ec)->rng;
889    else
890       seed = ((ec_dec*)ec)->rng;
891    balance = 0;
892    lowband = NULL;
893    for (i=start;i<end;i++)
894    {
895       int tell;
896       int b;
897       int N;
898       int curr_balance;
899       celt_norm * restrict X, * restrict Y;
900       int tf_change=0;
901       celt_norm *effective_lowband;
902       
903       X = _X+M*eBands[i];
904       if (_Y!=NULL)
905          Y = _Y+M*eBands[i];
906       else
907          Y = NULL;
908       N = M*eBands[i+1]-M*eBands[i];
909       if (encode)
910          tell = ec_enc_tell((ec_enc*)ec, BITRES);
911       else
912          tell = ec_dec_tell((ec_dec*)ec, BITRES);
913
914       /* Compute how many bits we want to allocate to this band */
915       if (i != start)
916          balance -= tell;
917       remaining_bits = (total_bits<<BITRES)-tell-1;
918       curr_balance = (end-i);
919       if (curr_balance > 3)
920          curr_balance = 3;
921       curr_balance = balance / curr_balance;
922       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
923       if (b<0)
924          b = 0;
925       /* Prevents ridiculous bit depths */
926       if (b > C*16*N<<BITRES)
927          b = C*16*N<<BITRES;
928
929       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
930             lowband = norm+M*eBands[i]-N;
931
932       tf_change = tf_res[i];
933       if (i>=m->effEBands)
934       {
935          X=norm;
936          if (_Y!=NULL)
937             Y = norm;
938       }
939
940       if (tf_change==0 && !shortBlocks && fold)
941          effective_lowband = NULL;
942       else
943          effective_lowband = lowband;
944       quant_band(encode, m, i, X, Y, N, b, fold, B, tf_change,
945             effective_lowband, resynth, ec, &remaining_bits, LM,
946             norm+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch);
947
948       balance += pulses[i] + tell;
949
950       /* Update the folding position only as long as we have 2 bit/sample depth */
951       update_lowband = (b>>BITRES)>2*N;
952    }
953    RESTORE_STACK;
954 }
955