New TF decision code based on L1-norm. Needs more work.
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    for (c=0;c<C;c++)
74    {
75       for (i=0;i<end;i++)
76       {
77          int j;
78          celt_word32 maxval=0;
79          celt_word32 sum = 0;
80          
81          j=M*eBands[i]; do {
82             maxval = MAX32(maxval, X[j+c*N]);
83             maxval = MAX32(maxval, -X[j+c*N]);
84          } while (++j<M*eBands[i+1]);
85          
86          if (maxval > 0)
87          {
88             int shift = celt_ilog2(maxval)-10;
89             j=M*eBands[i]; do {
90                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
91                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
92             } while (++j<M*eBands[i+1]);
93             /* We're adding one here to make damn sure we never end up with a pitch vector that's
94                larger than unity norm */
95             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
96          } else {
97             bank[i+c*m->nbEBands] = EPSILON;
98          }
99          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
100       }
101    }
102    /*printf ("\n");*/
103 }
104
105 /* Normalise each band such that the energy is one. */
106 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
107 {
108    int i, c, N;
109    const celt_int16 *eBands = m->eBands;
110    const int C = CHANNELS(_C);
111    N = M*m->shortMdctSize;
112    for (c=0;c<C;c++)
113    {
114       i=0; do {
115          celt_word16 g;
116          int j,shift;
117          celt_word16 E;
118          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
119          E = VSHR32(bank[i+c*m->nbEBands], shift);
120          g = EXTRACT16(celt_rcp(SHL32(E,3)));
121          j=M*eBands[i]; do {
122             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
123          } while (++j<M*eBands[i+1]);
124       } while (++i<end);
125    }
126 }
127
128 #else /* FIXED_POINT */
129 /* Compute the amplitude (sqrt energy) in each of the bands */
130 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
131 {
132    int i, c, N;
133    const celt_int16 *eBands = m->eBands;
134    const int C = CHANNELS(_C);
135    N = M*m->shortMdctSize;
136    for (c=0;c<C;c++)
137    {
138       for (i=0;i<end;i++)
139       {
140          int j;
141          celt_word32 sum = 1e-10f;
142          for (j=M*eBands[i];j<M*eBands[i+1];j++)
143             sum += X[j+c*N]*X[j+c*N];
144          bank[i+c*m->nbEBands] = celt_sqrt(sum);
145          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
146       }
147    }
148    /*printf ("\n");*/
149 }
150
151 /* Normalise each band such that the energy is one. */
152 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
153 {
154    int i, c, N;
155    const celt_int16 *eBands = m->eBands;
156    const int C = CHANNELS(_C);
157    N = M*m->shortMdctSize;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<end;i++)
161       {
162          int j;
163          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
164          for (j=M*eBands[i];j<M*eBands[i+1];j++)
165             X[j+c*N] = freq[j+c*N]*g;
166       }
167    }
168 }
169
170 #endif /* FIXED_POINT */
171
172 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
173 {
174    int i, c;
175    const celt_int16 *eBands = m->eBands;
176    const int C = CHANNELS(_C);
177    for (c=0;c<C;c++)
178    {
179       i=0; do {
180          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, M*eBands[i+1]-M*eBands[i], Q15ONE);
181       } while (++i<end);
182    }
183 }
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    for (c=0;c<C;c++)
194    {
195       celt_sig * restrict f;
196       const celt_norm * restrict x;
197       f = freq+c*N;
198       x = X+c*N;
199       for (i=0;i<end;i++)
200       {
201          int j, band_end;
202          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
203          j=M*eBands[i];
204          band_end = M*eBands[i+1];
205          do {
206             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
207             x++;
208          } while (++j<band_end);
209       }
210       for (i=M*eBands[m->nbEBands];i<N;i++)
211          *f++ = 0;
212    }
213 }
214
215 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
216 {
217    int i = bandID;
218    int j;
219    celt_word16 a1, a2;
220    celt_word16 left, right;
221    celt_word16 norm;
222 #ifdef FIXED_POINT
223    int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
224 #endif
225    left = VSHR32(bank[i],shift);
226    right = VSHR32(bank[i+m->nbEBands],shift);
227    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
228    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
229    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
230    for (j=0;j<N;j++)
231    {
232       celt_norm r, l;
233       l = X[j];
234       r = Y[j];
235       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
236       /* Side is not encoded, no need to calculate */
237    }
238 }
239
240 static void stereo_split(celt_norm *X, celt_norm *Y, int N)
241 {
242    int j;
243    for (j=0;j<N;j++)
244    {
245       celt_norm r, l;
246       l = MULT16_16_Q15(QCONST16(.70711f,15), X[j]);
247       r = MULT16_16_Q15(QCONST16(.70711f,15), Y[j]);
248       X[j] = l+r;
249       Y[j] = r-l;
250    }
251 }
252
253 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, celt_word16 side, int N)
254 {
255    int j;
256    celt_word32 xp=0;
257    celt_word32 El, Er;
258 #ifdef FIXED_POINT
259    int kl, kr;
260 #endif
261    celt_word32 t, lgain, rgain;
262
263    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
264    for (j=0;j<N;j++)
265       xp = MAC16_16(xp, X[j], Y[j]);
266    /* mid and side are in Q15, not Q14 like X and Y */
267    mid = SHR32(mid, 1);
268    side = SHR32(side, 1);
269    El = MULT16_16(mid, mid) + MULT16_16(side, side) - 2*xp;
270    Er = MULT16_16(mid, mid) + MULT16_16(side, side) + 2*xp;
271    if (Er < EPSILON)
272       Er = EPSILON;
273    if (El < EPSILON)
274       El = EPSILON;
275
276 #ifdef FIXED_POINT
277    kl = celt_ilog2(El)>>1;
278    kr = celt_ilog2(Er)>>1;
279 #endif
280    t = VSHR32(El, (kl-7)<<1);
281    lgain = celt_rsqrt_norm(t);
282    t = VSHR32(Er, (kr-7)<<1);
283    rgain = celt_rsqrt_norm(t);
284
285 #ifdef FIXED_POINT
286    if (kl < 7)
287       kl = 7;
288    if (kr < 7)
289       kr = 7;
290 #endif
291
292    for (j=0;j<N;j++)
293    {
294       celt_norm r, l;
295       l = X[j];
296       r = Y[j];
297       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, SUB16(l,r)), kl+1));
298       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, ADD16(l,r)), kr+1));
299    }
300 }
301
302 /* Decide whether we should spread the pulses in the current frame */
303 int folding_decision(const CELTMode *m, celt_norm *X, int *average, int *last_decision, int end, int _C, int M)
304 {
305    int i, c, N0;
306    int sum = 0, nbBands=0;
307    const int C = CHANNELS(_C);
308    const celt_int16 * restrict eBands = m->eBands;
309    int decision;
310    
311    N0 = M*m->shortMdctSize;
312
313    if (M*(eBands[end]-eBands[end-1]) <= 8)
314       return 0;
315    for (c=0;c<C;c++)
316    {
317       for (i=0;i<end;i++)
318       {
319          int j, N, tmp=0;
320          int tcount[3] = {0};
321          celt_norm * restrict x = X+M*eBands[i]+c*N0;
322          N = M*(eBands[i+1]-eBands[i]);
323          if (N<=8)
324             continue;
325          /* Compute rough CDF of |x[j]| */
326          for (j=0;j<N;j++)
327          {
328             celt_word32 x2N; /* Q13 */
329
330             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
331             if (x2N < QCONST16(0.25f,13))
332                tcount[0]++;
333             if (x2N < QCONST16(0.0625f,13))
334                tcount[1]++;
335             if (x2N < QCONST16(0.015625f,13))
336                tcount[2]++;
337          }
338
339          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
340          sum += tmp*256;
341          nbBands++;
342       }
343    }
344    sum /= nbBands;
345    /* Recursive averaging */
346    sum = (sum+*average)>>1;
347    *average = sum;
348    /* Hysteresis */
349    sum = (3*sum + ((*last_decision<<7) + 64) + 2)>>2;
350    /* decision and last_decision do not use the same encoding */
351    if (sum < 80)
352    {
353       decision = 2;
354       *last_decision = 0;
355    } else if (sum < 256)
356    {
357       decision = 1;
358       *last_decision = 1;
359    } else if (sum < 384)
360    {
361       decision = 3;
362       *last_decision = 2;
363    } else {
364       decision = 0;
365       *last_decision = 3;
366    }
367    return decision;
368 }
369
370 #ifdef MEASURE_NORM_MSE
371
372 float MSE[30] = {0};
373 int nbMSEBands = 0;
374 int MSECount[30] = {0};
375
376 void dump_norm_mse(void)
377 {
378    int i;
379    for (i=0;i<nbMSEBands;i++)
380    {
381       printf ("%g ", MSE[i]/MSECount[i]);
382    }
383    printf ("\n");
384 }
385
386 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
387 {
388    static int init = 0;
389    int i;
390    if (!init)
391    {
392       atexit(dump_norm_mse);
393       init = 1;
394    }
395    for (i=0;i<m->nbEBands;i++)
396    {
397       int j;
398       int c;
399       float g;
400       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
401          continue;
402       for (c=0;c<C;c++)
403       {
404          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
405          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
406             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
407       }
408       MSECount[i]+=C;
409    }
410    nbMSEBands = m->nbEBands;
411 }
412
413 #endif
414
415 static void interleave_vector(celt_norm *X, int N0, int stride)
416 {
417    int i,j;
418    VARDECL(celt_norm, tmp);
419    int N;
420    SAVE_STACK;
421    N = N0*stride;
422    ALLOC(tmp, N, celt_norm);
423    for (i=0;i<stride;i++)
424       for (j=0;j<N0;j++)
425          tmp[j*stride+i] = X[i*N0+j];
426    for (j=0;j<N;j++)
427       X[j] = tmp[j];
428    RESTORE_STACK;
429 }
430
431 static void deinterleave_vector(celt_norm *X, int N0, int stride)
432 {
433    int i,j;
434    VARDECL(celt_norm, tmp);
435    int N;
436    SAVE_STACK;
437    N = N0*stride;
438    ALLOC(tmp, N, celt_norm);
439    for (i=0;i<stride;i++)
440       for (j=0;j<N0;j++)
441          tmp[i*N0+j] = X[j*stride+i];
442    for (j=0;j<N;j++)
443       X[j] = tmp[j];
444    RESTORE_STACK;
445 }
446
447 void haar1(celt_norm *X, int N0, int stride)
448 {
449    int i, j;
450    N0 >>= 1;
451    for (i=0;i<stride;i++)
452       for (j=0;j<N0;j++)
453       {
454          celt_norm tmp1, tmp2;
455          tmp1 = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i]);
456          tmp2 = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*(2*j+1)+i]);
457          X[stride*2*j+i] = tmp1 + tmp2;
458          X[stride*(2*j+1)+i] = tmp1 - tmp2;
459       }
460 }
461
462 static int compute_qn(int N, int b, int offset, int stereo)
463 {
464    static const celt_int16 exp2_table8[8] =
465       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
466    int qn, qb;
467    int N2 = 2*N-1;
468    if (stereo && N==2)
469       N2--;
470    qb = (b+N2*offset)/N2;
471    if (qb > (b>>1)-(1<<BITRES))
472       qb = (b>>1)-(1<<BITRES);
473
474    if (qb<0)
475        qb = 0;
476    if (qb>8<<BITRES)
477      qb = 8<<BITRES;
478
479    if (qb<(1<<BITRES>>1)) {
480       qn = 1;
481    } else {
482       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
483       qn = (qn+1)>>1<<1;
484    }
485    celt_assert(qn <= 256);
486    return qn;
487 }
488
489
490 /* This function is responsible for encoding and decoding a band for both
491    the mono and stereo case. Even in the mono case, it can split the band
492    in two and transmit the energy difference with the two half-bands. It
493    can be called recursively so bands can end up being split in 8 parts. */
494 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
495       int N, int b, int spread, int B, int tf_change, celt_norm *lowband, int resynth, void *ec,
496       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
497       celt_int32 *seed, celt_word16 gain, celt_norm *lowband_scratch)
498 {
499    int q;
500    int curr_bits;
501    int stereo, split;
502    int imid=0, iside=0;
503    int N0=N;
504    int N_B=N;
505    int N_B0;
506    int B0=B;
507    int time_divide=0;
508    int recombine=0;
509    celt_word16 mid=0, side=0;
510
511    N_B /= B;
512    N_B0 = N_B;
513
514    split = stereo = Y != NULL;
515
516    /* Special case for one sample */
517    if (N==1)
518    {
519       int c;
520       celt_norm *x = X;
521       for (c=0;c<1+stereo;c++)
522       {
523          int sign=0;
524          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
525          {
526             if (encode)
527             {
528                sign = x[0]<0;
529                ec_enc_bits((ec_enc*)ec, sign, 1);
530             } else {
531                sign = ec_dec_bits((ec_dec*)ec, 1);
532             }
533             *remaining_bits -= 1<<BITRES;
534             b-=1<<BITRES;
535          }
536          if (resynth)
537             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
538          x = Y;
539       }
540       if (lowband_out)
541          lowband_out[0] = X[0];
542       return;
543    }
544
545    if (!stereo && level == 0)
546    {
547       int k;
548       if (tf_change>0)
549          recombine = tf_change;
550       /* Band recombining to increase frequency resolution */
551
552       if (lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
553       {
554          int j;
555          for (j=0;j<N;j++)
556             lowband_scratch[j] = lowband[j];
557          lowband = lowband_scratch;
558       }
559
560       for (k=0;k<recombine;k++)
561       {
562          B>>=1;
563          N_B<<=1;
564          if (encode)
565             haar1(X, N_B, B);
566          if (lowband)
567             haar1(lowband, N_B, B);
568       }
569
570       /* Increasing the time resolution */
571       while ((N_B&1) == 0 && tf_change<0)
572       {
573          if (encode)
574             haar1(X, N_B, B);
575          if (lowband)
576             haar1(lowband, N_B, B);
577          B <<= 1;
578          N_B >>= 1;
579          time_divide++;
580          tf_change++;
581       }
582       B0=B;
583       N_B0 = N_B;
584
585       /* Reorganize the samples in time order instead of frequency order */
586       if (B0>1)
587       {
588          if (encode)
589             deinterleave_vector(X, N_B, B0);
590          if (lowband)
591             deinterleave_vector(lowband, N_B, B0);
592       }
593    }
594
595    /* If we need more than 32 bits, try splitting the band in two. */
596    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
597    {
598       if (LM>0 || (N&1)==0)
599       {
600          N >>= 1;
601          Y = X+N;
602          split = 1;
603          LM -= 1;
604          B = (B+1)>>1;
605       }
606    }
607
608    if (split)
609    {
610       int qn;
611       int itheta=0;
612       int mbits, sbits, delta;
613       int qalloc;
614       int offset;
615
616       /* Decide on the resolution to give to the split parameter theta */
617       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
618       qn = compute_qn(N, b, offset, stereo);
619
620       qalloc = 0;
621       if (qn!=1)
622       {
623          if (encode)
624          {
625             if (stereo)
626                stereo_split(X, Y, N);
627
628             mid = vector_norm(X, N);
629             side = vector_norm(Y, N);
630             /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
631                      Let's do that at higher complexity */
632             /*mid = renormalise_vector(X, Q15ONE, N, 1);
633             side = renormalise_vector(Y, Q15ONE, N, 1);*/
634
635             /* theta is the atan() of the ratio between the (normalized)
636                side and mid. With just that parameter, we can re-scale both
637                mid and side because we know that 1) they have unit norm and
638                2) they are orthogonal. */
639    #ifdef FIXED_POINT
640             /* 0.63662 = 2/pi */
641             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
642    #else
643             itheta = (int)floor(.5f+16384*0.63662f*atan2(side,mid));
644    #endif
645
646             itheta = (itheta*qn+8192)>>14;
647          }
648
649          /* Entropy coding of the angle. We use a uniform pdf for the
650             first stereo split but a triangular one for the rest. */
651          if (stereo || B>1)
652          {
653             if (encode)
654                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
655             else
656                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
657             qalloc = log2_frac(qn+1,BITRES);
658          } else {
659             int fs=1, ft;
660             ft = ((qn>>1)+1)*((qn>>1)+1);
661             if (encode)
662             {
663                int fl;
664
665                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
666                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
667                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
668
669                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
670             } else {
671                int fl=0;
672                int fm;
673                fm = ec_decode((ec_dec*)ec, ft);
674
675                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
676                {
677                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
678                   fs = itheta + 1;
679                   fl = itheta*(itheta + 1)>>1;
680                }
681                else
682                {
683                   itheta = (2*(qn + 1)
684                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
685                   fs = qn + 1 - itheta;
686                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
687                }
688
689                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
690             }
691             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
692          }
693          itheta = (celt_int32)itheta*16384/qn;
694       } else {
695          if (stereo && encode)
696             intensity_stereo(m, X, Y, bandE, i, N);
697       }
698
699       if (itheta == 0)
700       {
701          imid = 32767;
702          iside = 0;
703          delta = -10000;
704       } else if (itheta == 16384)
705       {
706          imid = 0;
707          iside = 32767;
708          delta = 10000;
709       } else {
710          imid = bitexact_cos(itheta);
711          iside = bitexact_cos(16384-itheta);
712          /* This is the mid vs side allocation that minimizes squared error
713             in that band. */
714          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
715       }
716
717 #ifdef FIXED_POINT
718       mid = imid;
719       side = iside;
720 #else
721       mid = (1.f/32768)*imid;
722       side = (1.f/32768)*iside;
723 #endif
724
725       /* This is a special case for N=2 that only works for stereo and takes
726          advantage of the fact that mid and side are orthogonal to encode
727          the side with just one bit. */
728       if (N==2 && stereo)
729       {
730          int c;
731          int sign=1;
732          celt_norm *x2, *y2;
733          mbits = b-qalloc;
734          sbits = 0;
735          /* Only need one bit for the side */
736          if (itheta != 0 && itheta != 16384)
737             sbits = 1<<BITRES;
738          mbits -= sbits;
739          c = itheta > 8192;
740          *remaining_bits -= qalloc+sbits;
741
742          x2 = c ? Y : X;
743          y2 = c ? X : Y;
744          if (sbits)
745          {
746             if (encode)
747             {
748                /* Here we only need to encode a sign for the side */
749                sign = x2[0]*y2[1] - x2[1]*y2[0] > 0;
750                ec_enc_bits((ec_enc*)ec, sign, 1);
751             } else {
752                sign = ec_dec_bits((ec_dec*)ec, 1);
753             }
754          }
755          sign = 2*sign - 1;
756          quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level, seed, gain, lowband_scratch);
757          y2[0] = -sign*x2[1];
758          y2[1] = sign*x2[0];
759          if (resynth)
760          {
761             celt_norm tmp;
762             X[0] = MULT16_16_Q15(mid, X[0]);
763             X[1] = MULT16_16_Q15(mid, X[1]);
764             Y[0] = MULT16_16_Q15(side, Y[0]);
765             Y[1] = MULT16_16_Q15(side, Y[1]);
766             tmp = X[0];
767             X[0] = SUB16(tmp,Y[0]);
768             Y[0] = ADD16(tmp,Y[0]);
769             tmp = X[1];
770             X[1] = SUB16(tmp,Y[1]);
771             Y[1] = ADD16(tmp,Y[1]);
772          }
773       } else {
774          /* "Normal" split code */
775          celt_norm *next_lowband2=NULL;
776          celt_norm *next_lowband_out1=NULL;
777          int next_level=0;
778
779          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
780          if (B>1 && !stereo)
781             delta >>= 1;
782
783          mbits = (b-qalloc-delta)/2;
784          if (mbits > b-qalloc)
785             mbits = b-qalloc;
786          if (mbits<0)
787             mbits=0;
788          sbits = b-qalloc-mbits;
789          *remaining_bits -= qalloc;
790
791          if (lowband && !stereo)
792             next_lowband2 = lowband+N; /* >32-bit split case */
793
794          /* Only stereo needs to pass on lowband_out. Otherwise, it's
795             handled at the end */
796          if (stereo)
797             next_lowband_out1 = lowband_out;
798          else
799             next_level = level+1;
800
801          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, tf_change,
802                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
803                NULL, next_level, seed, MULT16_16_P15(gain,mid), lowband_scratch);
804          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, tf_change,
805                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
806                NULL, next_level, seed, MULT16_16_P15(gain,side), NULL);
807       }
808
809    } else {
810       /* This is the basic no-split case */
811       q = bits2pulses(m, i, LM, b);
812       curr_bits = pulses2bits(m, i, LM, q);
813       *remaining_bits -= curr_bits;
814
815       /* Ensures we can never bust the budget */
816       while (*remaining_bits < 0 && q > 0)
817       {
818          *remaining_bits += curr_bits;
819          q--;
820          curr_bits = pulses2bits(m, i, LM, q);
821          *remaining_bits -= curr_bits;
822       }
823
824       /* Finally do the actual quantization */
825       if (encode)
826          alg_quant(X, N, q, spread, B, lowband, resynth, (ec_enc*)ec, seed, gain);
827       else
828          alg_unquant(X, N, q, spread, B, lowband, (ec_dec*)ec, seed, gain);
829    }
830
831    /* This code is used by the decoder and by the resynthesis-enabled encoder */
832    if (resynth)
833    {
834       if (stereo)
835       {
836          if (N!=2)
837             stereo_merge(X, Y, mid, side, N);
838       } else if (level == 0)
839       {
840          int k;
841
842          /* Undo the sample reorganization going from time order to frequency order */
843          if (B0>1)
844             interleave_vector(X, N_B, B0);
845
846          /* Undo time-freq changes that we did earlier */
847          N_B = N_B0;
848          B = B0;
849          for (k=0;k<time_divide;k++)
850          {
851             B >>= 1;
852             N_B <<= 1;
853             haar1(X, N_B, B);
854          }
855
856          for (k=0;k<recombine;k++)
857          {
858             haar1(X, N_B, B);
859             N_B>>=1;
860             B <<= 1;
861          }
862
863          /* Scale output for later folding */
864          if (lowband_out)
865          {
866             int j;
867             celt_word16 n;
868             n = celt_sqrt(SHL32(EXTEND32(N0),22));
869             for (j=0;j<N0;j++)
870                lowband_out[j] = MULT16_16_Q15(n,X[j]);
871          }
872       }
873    }
874 }
875
876 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM, int codedBands)
877 {
878    int i, balance;
879    celt_int32 remaining_bits;
880    const celt_int16 * restrict eBands = m->eBands;
881    celt_norm * restrict norm;
882    VARDECL(celt_norm, _norm);
883    VARDECL(celt_norm, lowband_scratch);
884    int B;
885    int M;
886    celt_int32 seed;
887    celt_norm *lowband;
888    int update_lowband = 1;
889    int C = _Y != NULL ? 2 : 1;
890    SAVE_STACK;
891
892    M = 1<<LM;
893    B = shortBlocks ? M : 1;
894    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
895    ALLOC(lowband_scratch, M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]), celt_norm);
896    norm = _norm;
897
898    if (encode)
899       seed = ((ec_enc*)ec)->rng;
900    else
901       seed = ((ec_dec*)ec)->rng;
902    balance = 0;
903    lowband = NULL;
904    for (i=start;i<end;i++)
905    {
906       int tell;
907       int b;
908       int N;
909       int curr_balance;
910       celt_norm * restrict X, * restrict Y;
911       int tf_change=0;
912       celt_norm *effective_lowband;
913       
914       X = _X+M*eBands[i];
915       if (_Y!=NULL)
916          Y = _Y+M*eBands[i];
917       else
918          Y = NULL;
919       N = M*eBands[i+1]-M*eBands[i];
920       if (encode)
921          tell = ec_enc_tell((ec_enc*)ec, BITRES);
922       else
923          tell = ec_dec_tell((ec_dec*)ec, BITRES);
924
925       /* Compute how many bits we want to allocate to this band */
926       if (i != start)
927          balance -= tell;
928       remaining_bits = (total_bits<<BITRES)-tell-1;
929       if (i <= codedBands-1)
930       {
931          curr_balance = (codedBands-i);
932          if (curr_balance > 3)
933             curr_balance = 3;
934          curr_balance = balance / curr_balance;
935          b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
936          if (b<0)
937             b = 0;
938       } else {
939          b = 0;
940       }
941       /* Prevents ridiculous bit depths */
942       if (b > C*16*N<<BITRES)
943          b = C*16*N<<BITRES;
944
945       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
946             lowband = norm+M*eBands[i]-N;
947
948       tf_change = tf_res[i];
949       if (i>=m->effEBands)
950       {
951          X=norm;
952          if (_Y!=NULL)
953             Y = norm;
954       }
955
956       if (tf_change==0 && !shortBlocks && fold==2)
957          effective_lowband = NULL;
958       else
959          effective_lowband = lowband;
960       quant_band(encode, m, i, X, Y, N, b, fold, B, tf_change,
961             effective_lowband, resynth, ec, &remaining_bits, LM,
962             norm+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch);
963
964       balance += pulses[i] + tell;
965
966       /* Update the folding position only as long as we have 2 bit/sample depth */
967       update_lowband = (b>>BITRES)>2*N;
968    }
969    RESTORE_STACK;
970 }
971