Fixes a fixed-point overflow in haar1()
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    for (c=0;c<C;c++)
74    {
75       for (i=0;i<end;i++)
76       {
77          int j;
78          celt_word32 maxval=0;
79          celt_word32 sum = 0;
80          
81          j=M*eBands[i]; do {
82             maxval = MAX32(maxval, X[j+c*N]);
83             maxval = MAX32(maxval, -X[j+c*N]);
84          } while (++j<M*eBands[i+1]);
85          
86          if (maxval > 0)
87          {
88             int shift = celt_ilog2(maxval)-10;
89             j=M*eBands[i]; do {
90                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
91                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
92             } while (++j<M*eBands[i+1]);
93             /* We're adding one here to make damn sure we never end up with a pitch vector that's
94                larger than unity norm */
95             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
96          } else {
97             bank[i+c*m->nbEBands] = EPSILON;
98          }
99          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
100       }
101    }
102    /*printf ("\n");*/
103 }
104
105 /* Normalise each band such that the energy is one. */
106 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
107 {
108    int i, c, N;
109    const celt_int16 *eBands = m->eBands;
110    const int C = CHANNELS(_C);
111    N = M*m->shortMdctSize;
112    for (c=0;c<C;c++)
113    {
114       i=0; do {
115          celt_word16 g;
116          int j,shift;
117          celt_word16 E;
118          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
119          E = VSHR32(bank[i+c*m->nbEBands], shift);
120          g = EXTRACT16(celt_rcp(SHL32(E,3)));
121          j=M*eBands[i]; do {
122             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
123          } while (++j<M*eBands[i+1]);
124       } while (++i<end);
125    }
126 }
127
128 #else /* FIXED_POINT */
129 /* Compute the amplitude (sqrt energy) in each of the bands */
130 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
131 {
132    int i, c, N;
133    const celt_int16 *eBands = m->eBands;
134    const int C = CHANNELS(_C);
135    N = M*m->shortMdctSize;
136    for (c=0;c<C;c++)
137    {
138       for (i=0;i<end;i++)
139       {
140          int j;
141          celt_word32 sum = 1e-10f;
142          for (j=M*eBands[i];j<M*eBands[i+1];j++)
143             sum += X[j+c*N]*X[j+c*N];
144          bank[i+c*m->nbEBands] = celt_sqrt(sum);
145          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
146       }
147    }
148    /*printf ("\n");*/
149 }
150
151 /* Normalise each band such that the energy is one. */
152 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
153 {
154    int i, c, N;
155    const celt_int16 *eBands = m->eBands;
156    const int C = CHANNELS(_C);
157    N = M*m->shortMdctSize;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<end;i++)
161       {
162          int j;
163          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
164          for (j=M*eBands[i];j<M*eBands[i+1];j++)
165             X[j+c*N] = freq[j+c*N]*g;
166       }
167    }
168 }
169
170 #endif /* FIXED_POINT */
171
172 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
173 {
174    int i, c;
175    const celt_int16 *eBands = m->eBands;
176    const int C = CHANNELS(_C);
177    for (c=0;c<C;c++)
178    {
179       i=0; do {
180          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, M*eBands[i+1]-M*eBands[i], Q15ONE);
181       } while (++i<end);
182    }
183 }
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    for (c=0;c<C;c++)
194    {
195       celt_sig * restrict f;
196       const celt_norm * restrict x;
197       f = freq+c*N;
198       x = X+c*N;
199       for (i=0;i<end;i++)
200       {
201          int j, band_end;
202          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
203          j=M*eBands[i];
204          band_end = M*eBands[i+1];
205          do {
206             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
207             x++;
208          } while (++j<band_end);
209       }
210       for (i=M*eBands[m->nbEBands];i<N;i++)
211          *f++ = 0;
212    }
213 }
214
215 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
216 {
217    int i = bandID;
218    int j;
219    celt_word16 a1, a2;
220    celt_word16 left, right;
221    celt_word16 norm;
222 #ifdef FIXED_POINT
223    int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
224 #endif
225    left = VSHR32(bank[i],shift);
226    right = VSHR32(bank[i+m->nbEBands],shift);
227    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
228    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
229    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
230    for (j=0;j<N;j++)
231    {
232       celt_norm r, l;
233       l = X[j];
234       r = Y[j];
235       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
236       /* Side is not encoded, no need to calculate */
237    }
238 }
239
240 static void stereo_split(celt_norm *X, celt_norm *Y, int N)
241 {
242    int j;
243    for (j=0;j<N;j++)
244    {
245       celt_norm r, l;
246       l = MULT16_16_Q15(QCONST16(.70711f,15), X[j]);
247       r = MULT16_16_Q15(QCONST16(.70711f,15), Y[j]);
248       X[j] = l+r;
249       Y[j] = r-l;
250    }
251 }
252
253 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, celt_word16 side, int N)
254 {
255    int j;
256    celt_word32 xp=0;
257    celt_word32 El, Er;
258 #ifdef FIXED_POINT
259    int kl, kr;
260 #endif
261    celt_word32 t, lgain, rgain;
262
263    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
264    for (j=0;j<N;j++)
265       xp = MAC16_16(xp, X[j], Y[j]);
266    /* mid and side are in Q15, not Q14 like X and Y */
267    El = MULT16_16(mid, mid) + MULT16_16(side, side) - 2*SHL32(xp,2);
268    Er = MULT16_16(mid, mid) + MULT16_16(side, side) + 2*SHL32(xp,2);
269    if (Er < EPSILON)
270       Er = EPSILON;
271    if (El < EPSILON)
272       El = EPSILON;
273
274 #ifdef FIXED_POINT
275    kl = celt_ilog2(El)>>1;
276    kr = celt_ilog2(Er)>>1;
277 #endif
278    t = VSHR32(El, (kl-7)<<1);
279    lgain = celt_rsqrt_norm(t);
280    t = VSHR32(Er, (kr-7)<<1);
281    rgain = celt_rsqrt_norm(t);
282
283    for (j=0;j<N;j++)
284    {
285       celt_norm r, l;
286       l = X[j];
287       r = Y[j];
288       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, l-r), kl));
289       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, l+r), kr));
290    }
291 }
292
293 /* Decide whether we should spread the pulses in the current frame */
294 int folding_decision(const CELTMode *m, celt_norm *X, int *average, int *last_decision, int end, int _C, int M)
295 {
296    int i, c, N0;
297    int sum = 0, nbBands=0;
298    const int C = CHANNELS(_C);
299    const celt_int16 * restrict eBands = m->eBands;
300    int decision;
301    
302    N0 = M*m->shortMdctSize;
303
304    if (M*(eBands[end]-eBands[end-1]) <= 8)
305       return 0;
306    for (c=0;c<C;c++)
307    {
308       for (i=0;i<end;i++)
309       {
310          int j, N, tmp=0;
311          int tcount[3] = {0};
312          celt_norm * restrict x = X+M*eBands[i]+c*N0;
313          N = M*(eBands[i+1]-eBands[i]);
314          if (N<=8)
315             continue;
316          /* Compute rough CDF of |x[j]| */
317          for (j=0;j<N;j++)
318          {
319             celt_word32 x2N; /* Q13 */
320
321             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
322             if (x2N < QCONST16(0.25f,13))
323                tcount[0]++;
324             if (x2N < QCONST16(0.0625f,13))
325                tcount[1]++;
326             if (x2N < QCONST16(0.015625f,13))
327                tcount[2]++;
328          }
329
330          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
331          sum += tmp*256;
332          nbBands++;
333       }
334    }
335    sum /= nbBands;
336    /* Recursive averaging */
337    sum = (sum+*average)>>1;
338    *average = sum;
339    /* Hysteresis */
340    sum = (3*sum + ((*last_decision<<7) + 64) + 2)>>2;
341    /* decision and last_decision do not use the same encoding */
342    if (sum < 128)
343    {
344       decision = 2;
345       *last_decision = 0;
346    } else if (sum < 256)
347    {
348       decision = 1;
349       *last_decision = 1;
350    } else if (sum < 384)
351    {
352       decision = 3;
353       *last_decision = 2;
354    } else {
355       decision = 0;
356       *last_decision = 3;
357    }
358    return decision;
359 }
360
361 #ifdef MEASURE_NORM_MSE
362
363 float MSE[30] = {0};
364 int nbMSEBands = 0;
365 int MSECount[30] = {0};
366
367 void dump_norm_mse(void)
368 {
369    int i;
370    for (i=0;i<nbMSEBands;i++)
371    {
372       printf ("%g ", MSE[i]/MSECount[i]);
373    }
374    printf ("\n");
375 }
376
377 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
378 {
379    static int init = 0;
380    int i;
381    if (!init)
382    {
383       atexit(dump_norm_mse);
384       init = 1;
385    }
386    for (i=0;i<m->nbEBands;i++)
387    {
388       int j;
389       int c;
390       float g;
391       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
392          continue;
393       for (c=0;c<C;c++)
394       {
395          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
396          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
397             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
398       }
399       MSECount[i]+=C;
400    }
401    nbMSEBands = m->nbEBands;
402 }
403
404 #endif
405
406 static void interleave_vector(celt_norm *X, int N0, int stride)
407 {
408    int i,j;
409    VARDECL(celt_norm, tmp);
410    int N;
411    SAVE_STACK;
412    N = N0*stride;
413    ALLOC(tmp, N, celt_norm);
414    for (i=0;i<stride;i++)
415       for (j=0;j<N0;j++)
416          tmp[j*stride+i] = X[i*N0+j];
417    for (j=0;j<N;j++)
418       X[j] = tmp[j];
419    RESTORE_STACK;
420 }
421
422 static void deinterleave_vector(celt_norm *X, int N0, int stride)
423 {
424    int i,j;
425    VARDECL(celt_norm, tmp);
426    int N;
427    SAVE_STACK;
428    N = N0*stride;
429    ALLOC(tmp, N, celt_norm);
430    for (i=0;i<stride;i++)
431       for (j=0;j<N0;j++)
432          tmp[i*N0+j] = X[j*stride+i];
433    for (j=0;j<N;j++)
434       X[j] = tmp[j];
435    RESTORE_STACK;
436 }
437
438 static void haar1(celt_norm *X, int N0, int stride)
439 {
440    int i, j;
441    N0 >>= 1;
442    for (i=0;i<stride;i++)
443       for (j=0;j<N0;j++)
444       {
445          celt_norm tmp1, tmp2;
446          tmp1 = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i]);
447          tmp2 = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*(2*j+1)+i]);
448          X[stride*2*j+i] = tmp1 + tmp2;
449          X[stride*(2*j+1)+i] = tmp1 - tmp2;
450       }
451 }
452
453 static int compute_qn(int N, int b, int offset, int stereo)
454 {
455    static const celt_int16 exp2_table8[8] =
456       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
457    int qn, qb;
458    int N2 = 2*N-1;
459    if (stereo && N==2)
460       N2--;
461    qb = (b+N2*offset)/N2;
462    if (qb > (b>>1)-(1<<BITRES))
463       qb = (b>>1)-(1<<BITRES);
464
465    if (qb<0)
466        qb = 0;
467    if (qb>8<<BITRES)
468      qb = 8<<BITRES;
469
470    if (qb<(1<<BITRES>>1)) {
471       qn = 1;
472    } else {
473       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
474       qn = (qn+1)>>1<<1;
475    }
476    celt_assert(qn <= 256);
477    return qn;
478 }
479
480
481 /* This function is responsible for encoding and decoding a band for both
482    the mono and stereo case. Even in the mono case, it can split the band
483    in two and transmit the energy difference with the two half-bands. It
484    can be called recursively so bands can end up being split in 8 parts. */
485 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
486       int N, int b, int spread, int B, int tf_change, celt_norm *lowband, int resynth, void *ec,
487       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
488       celt_int32 *seed, celt_word16 gain, celt_norm *lowband_scratch)
489 {
490    int q;
491    int curr_bits;
492    int stereo, split;
493    int imid=0, iside=0;
494    int N0=N;
495    int N_B=N;
496    int N_B0;
497    int B0=B;
498    int time_divide=0;
499    int recombine=0;
500    celt_word16 mid=0, side=0;
501
502    N_B /= B;
503    N_B0 = N_B;
504
505    split = stereo = Y != NULL;
506
507    /* Special case for one sample */
508    if (N==1)
509    {
510       int c;
511       celt_norm *x = X;
512       for (c=0;c<1+stereo;c++)
513       {
514          int sign=0;
515          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
516          {
517             if (encode)
518             {
519                sign = x[0]<0;
520                ec_enc_bits((ec_enc*)ec, sign, 1);
521             } else {
522                sign = ec_dec_bits((ec_dec*)ec, 1);
523             }
524             *remaining_bits -= 1<<BITRES;
525             b-=1<<BITRES;
526          }
527          if (resynth)
528             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
529          x = Y;
530       }
531       if (lowband_out)
532          lowband_out[0] = X[0];
533       return;
534    }
535
536    if (!stereo && level == 0)
537    {
538       int k;
539       if (tf_change>0)
540          recombine = tf_change;
541       /* Band recombining to increase frequency resolution */
542
543       if (lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
544       {
545          int j;
546          for (j=0;j<N;j++)
547             lowband_scratch[j] = lowband[j];
548          lowband = lowband_scratch;
549       }
550
551       for (k=0;k<recombine;k++)
552       {
553          B>>=1;
554          N_B<<=1;
555          if (encode)
556             haar1(X, N_B, B);
557          if (lowband)
558             haar1(lowband, N_B, B);
559       }
560
561       /* Increasing the time resolution */
562       while ((N_B&1) == 0 && tf_change<0)
563       {
564          if (encode)
565             haar1(X, N_B, B);
566          if (lowband)
567             haar1(lowband, N_B, B);
568          B <<= 1;
569          N_B >>= 1;
570          time_divide++;
571          tf_change++;
572       }
573       B0=B;
574       N_B0 = N_B;
575
576       /* Reorganize the samples in time order instead of frequency order */
577       if (B0>1)
578       {
579          if (encode)
580             deinterleave_vector(X, N_B, B0);
581          if (lowband)
582             deinterleave_vector(lowband, N_B, B0);
583       }
584    }
585
586    /* If we need more than 32 bits, try splitting the band in two. */
587    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
588    {
589       if (LM>0 || (N&1)==0)
590       {
591          N >>= 1;
592          Y = X+N;
593          split = 1;
594          LM -= 1;
595          B = (B+1)>>1;
596       }
597    }
598
599    if (split)
600    {
601       int qn;
602       int itheta=0;
603       int mbits, sbits, delta;
604       int qalloc;
605       int offset;
606
607       /* Decide on the resolution to give to the split parameter theta */
608       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
609       qn = compute_qn(N, b, offset, stereo);
610
611       qalloc = 0;
612       if (qn!=1)
613       {
614          if (encode)
615          {
616             if (stereo)
617                stereo_split(X, Y, N);
618
619             mid = vector_norm(X, N);
620             side = vector_norm(Y, N);
621             /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
622                      Let's do that at higher complexity */
623             /*mid = renormalise_vector(X, Q15ONE, N, 1);
624             side = renormalise_vector(Y, Q15ONE, N, 1);*/
625
626             /* theta is the atan() of the ratio between the (normalized)
627                side and mid. With just that parameter, we can re-scale both
628                mid and side because we know that 1) they have unit norm and
629                2) they are orthogonal. */
630    #ifdef FIXED_POINT
631             /* 0.63662 = 2/pi */
632             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
633    #else
634             itheta = (int)floor(.5f+16384*0.63662f*atan2(side,mid));
635    #endif
636
637             itheta = (itheta*qn+8192)>>14;
638          }
639
640          /* Entropy coding of the angle. We use a uniform pdf for the
641             first stereo split but a triangular one for the rest. */
642          if (stereo || B>1)
643          {
644             if (encode)
645                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
646             else
647                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
648             qalloc = log2_frac(qn+1,BITRES);
649          } else {
650             int fs=1, ft;
651             ft = ((qn>>1)+1)*((qn>>1)+1);
652             if (encode)
653             {
654                int fl;
655
656                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
657                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
658                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
659
660                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
661             } else {
662                int fl=0;
663                int fm;
664                fm = ec_decode((ec_dec*)ec, ft);
665
666                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
667                {
668                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
669                   fs = itheta + 1;
670                   fl = itheta*(itheta + 1)>>1;
671                }
672                else
673                {
674                   itheta = (2*(qn + 1)
675                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
676                   fs = qn + 1 - itheta;
677                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
678                }
679
680                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
681             }
682             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
683          }
684          itheta = (celt_int32)itheta*16384/qn;
685       } else {
686          if (stereo && encode)
687             intensity_stereo(m, X, Y, bandE, i, N);
688       }
689
690       if (itheta == 0)
691       {
692          imid = 32767;
693          iside = 0;
694          delta = -10000;
695       } else if (itheta == 16384)
696       {
697          imid = 0;
698          iside = 32767;
699          delta = 10000;
700       } else {
701          imid = bitexact_cos(itheta);
702          iside = bitexact_cos(16384-itheta);
703          /* This is the mid vs side allocation that minimizes squared error
704             in that band. */
705          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
706       }
707
708 #ifdef FIXED_POINT
709       mid = imid;
710       side = iside;
711 #else
712       mid = (1.f/32768)*imid;
713       side = (1.f/32768)*iside;
714 #endif
715
716       /* This is a special case for N=2 that only works for stereo and takes
717          advantage of the fact that mid and side are orthogonal to encode
718          the side with just one bit. */
719       if (N==2 && stereo)
720       {
721          int c;
722          int sign=1;
723          celt_norm *x2, *y2;
724          mbits = b-qalloc;
725          sbits = 0;
726          /* Only need one bit for the side */
727          if (itheta != 0 && itheta != 16384)
728             sbits = 1<<BITRES;
729          mbits -= sbits;
730          c = itheta > 8192;
731          *remaining_bits -= qalloc+sbits;
732
733          x2 = c ? Y : X;
734          y2 = c ? X : Y;
735          if (sbits)
736          {
737             if (encode)
738             {
739                /* Here we only need to encode a sign for the side */
740                sign = x2[0]*y2[1] - x2[1]*y2[0] > 0;
741                ec_enc_bits((ec_enc*)ec, sign, 1);
742             } else {
743                sign = ec_dec_bits((ec_dec*)ec, 1);
744             }
745          }
746          sign = 2*sign - 1;
747          quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level+1, seed, gain, lowband_scratch);
748          y2[0] = -sign*x2[1];
749          y2[1] = sign*x2[0];
750          if (resynth)
751          {
752             celt_norm tmp;
753             X[0] = MULT16_16_Q15(mid, X[0]);
754             X[1] = MULT16_16_Q15(mid, X[1]);
755             Y[0] = MULT16_16_Q15(side, Y[0]);
756             Y[1] = MULT16_16_Q15(side, Y[1]);
757             tmp = X[0];
758             X[0] = SUB16(tmp,Y[0]);
759             Y[0] = ADD16(tmp,Y[0]);
760             tmp = X[1];
761             X[1] = SUB16(tmp,Y[1]);
762             Y[1] = ADD16(tmp,Y[1]);
763          }
764       } else {
765          /* "Normal" split code */
766          celt_norm *next_lowband2=NULL;
767          celt_norm *next_lowband_out1=NULL;
768          int next_level=0;
769
770          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
771          if (B>1 && !stereo)
772             delta >>= 1;
773
774          mbits = (b-qalloc-delta)/2;
775          if (mbits > b-qalloc)
776             mbits = b-qalloc;
777          if (mbits<0)
778             mbits=0;
779          sbits = b-qalloc-mbits;
780          *remaining_bits -= qalloc;
781
782          if (lowband && !stereo)
783             next_lowband2 = lowband+N; /* >32-bit split case */
784
785          /* Only stereo needs to pass on lowband_out. Otherwise, it's
786             handled at the end */
787          if (stereo)
788             next_lowband_out1 = lowband_out;
789          else
790             next_level = level+1;
791
792          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, tf_change,
793                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
794                NULL, next_level, seed, MULT16_16_P15(gain,mid), lowband_scratch);
795          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, tf_change,
796                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
797                NULL, next_level, seed, MULT16_16_P15(gain,side), NULL);
798       }
799
800    } else {
801       /* This is the basic no-split case */
802       q = bits2pulses(m, i, LM, b);
803       curr_bits = pulses2bits(m, i, LM, q);
804       *remaining_bits -= curr_bits;
805
806       /* Ensures we can never bust the budget */
807       while (*remaining_bits < 0 && q > 0)
808       {
809          *remaining_bits += curr_bits;
810          q--;
811          curr_bits = pulses2bits(m, i, LM, q);
812          *remaining_bits -= curr_bits;
813       }
814
815       /* Finally do the actual quantization */
816       if (encode)
817          alg_quant(X, N, q, spread, B, lowband, resynth, (ec_enc*)ec, seed, gain);
818       else
819          alg_unquant(X, N, q, spread, B, lowband, (ec_dec*)ec, seed, gain);
820    }
821
822    /* This code is used by the decoder and by the resynthesis-enabled encoder */
823    if (resynth)
824    {
825       if (stereo)
826       {
827          if (N!=2)
828             stereo_merge(X, Y, mid, side, N);
829       } else if (level == 0)
830       {
831          int k;
832
833          /* Undo the sample reorganization going from time order to frequency order */
834          if (B0>1)
835             interleave_vector(X, N_B, B0);
836
837          /* Undo time-freq changes that we did earlier */
838          N_B = N_B0;
839          B = B0;
840          for (k=0;k<time_divide;k++)
841          {
842             B >>= 1;
843             N_B <<= 1;
844             haar1(X, N_B, B);
845          }
846
847          for (k=0;k<recombine;k++)
848          {
849             haar1(X, N_B, B);
850             N_B>>=1;
851             B <<= 1;
852          }
853
854          /* Scale output for later folding */
855          if (lowband_out)
856          {
857             int j;
858             celt_word16 n;
859             n = celt_sqrt(SHL32(EXTEND32(N0),22));
860             for (j=0;j<N0;j++)
861                lowband_out[j] = MULT16_16_Q15(n,X[j]);
862          }
863       }
864    }
865 }
866
867 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM)
868 {
869    int i, balance;
870    celt_int32 remaining_bits;
871    const celt_int16 * restrict eBands = m->eBands;
872    celt_norm * restrict norm;
873    VARDECL(celt_norm, _norm);
874    VARDECL(celt_norm, lowband_scratch);
875    int B;
876    int M;
877    celt_int32 seed;
878    celt_norm *lowband;
879    int update_lowband = 1;
880    int C = _Y != NULL ? 2 : 1;
881    SAVE_STACK;
882
883    M = 1<<LM;
884    B = shortBlocks ? M : 1;
885    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
886    ALLOC(lowband_scratch, M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]), celt_norm);
887    norm = _norm;
888
889    if (encode)
890       seed = ((ec_enc*)ec)->rng;
891    else
892       seed = ((ec_dec*)ec)->rng;
893    balance = 0;
894    lowband = NULL;
895    for (i=start;i<end;i++)
896    {
897       int tell;
898       int b;
899       int N;
900       int curr_balance;
901       celt_norm * restrict X, * restrict Y;
902       int tf_change=0;
903       celt_norm *effective_lowband;
904       
905       X = _X+M*eBands[i];
906       if (_Y!=NULL)
907          Y = _Y+M*eBands[i];
908       else
909          Y = NULL;
910       N = M*eBands[i+1]-M*eBands[i];
911       if (encode)
912          tell = ec_enc_tell((ec_enc*)ec, BITRES);
913       else
914          tell = ec_dec_tell((ec_dec*)ec, BITRES);
915
916       /* Compute how many bits we want to allocate to this band */
917       if (i != start)
918          balance -= tell;
919       remaining_bits = (total_bits<<BITRES)-tell-1;
920       curr_balance = (end-i);
921       if (curr_balance > 3)
922          curr_balance = 3;
923       curr_balance = balance / curr_balance;
924       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
925       if (b<0)
926          b = 0;
927       /* Prevents ridiculous bit depths */
928       if (b > C*16*N<<BITRES)
929          b = C*16*N<<BITRES;
930
931       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
932             lowband = norm+M*eBands[i]-N;
933
934       tf_change = tf_res[i];
935       if (i>=m->effEBands)
936       {
937          X=norm;
938          if (_Y!=NULL)
939             Y = norm;
940       }
941
942       if (tf_change==0 && !shortBlocks && fold)
943          effective_lowband = NULL;
944       else
945          effective_lowband = lowband;
946       quant_band(encode, m, i, X, Y, N, b, fold, B, tf_change,
947             effective_lowband, resynth, ec, &remaining_bits, LM,
948             norm+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch);
949
950       balance += pulses[i] + tell;
951
952       /* Update the folding position only as long as we have 2 bit/sample depth */
953       update_lowband = (b>>BITRES)>2*N;
954    }
955    RESTORE_STACK;
956 }
957