Folding code moved to quant_band() to prevent duplication.
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    for (c=0;c<C;c++)
74    {
75       for (i=0;i<end;i++)
76       {
77          int j;
78          celt_word32 maxval=0;
79          celt_word32 sum = 0;
80          
81          j=M*eBands[i]; do {
82             maxval = MAX32(maxval, X[j+c*N]);
83             maxval = MAX32(maxval, -X[j+c*N]);
84          } while (++j<M*eBands[i+1]);
85          
86          if (maxval > 0)
87          {
88             int shift = celt_ilog2(maxval)-10;
89             j=M*eBands[i]; do {
90                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
91                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
92             } while (++j<M*eBands[i+1]);
93             /* We're adding one here to make damn sure we never end up with a pitch vector that's
94                larger than unity norm */
95             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
96          } else {
97             bank[i+c*m->nbEBands] = EPSILON;
98          }
99          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
100       }
101    }
102    /*printf ("\n");*/
103 }
104
105 /* Normalise each band such that the energy is one. */
106 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
107 {
108    int i, c, N;
109    const celt_int16 *eBands = m->eBands;
110    const int C = CHANNELS(_C);
111    N = M*m->shortMdctSize;
112    for (c=0;c<C;c++)
113    {
114       i=0; do {
115          celt_word16 g;
116          int j,shift;
117          celt_word16 E;
118          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
119          E = VSHR32(bank[i+c*m->nbEBands], shift);
120          g = EXTRACT16(celt_rcp(SHL32(E,3)));
121          j=M*eBands[i]; do {
122             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
123          } while (++j<M*eBands[i+1]);
124       } while (++i<end);
125    }
126 }
127
128 #else /* FIXED_POINT */
129 /* Compute the amplitude (sqrt energy) in each of the bands */
130 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
131 {
132    int i, c, N;
133    const celt_int16 *eBands = m->eBands;
134    const int C = CHANNELS(_C);
135    N = M*m->shortMdctSize;
136    for (c=0;c<C;c++)
137    {
138       for (i=0;i<end;i++)
139       {
140          int j;
141          celt_word32 sum = 1e-10f;
142          for (j=M*eBands[i];j<M*eBands[i+1];j++)
143             sum += X[j+c*N]*X[j+c*N];
144          bank[i+c*m->nbEBands] = celt_sqrt(sum);
145          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
146       }
147    }
148    /*printf ("\n");*/
149 }
150
151 /* Normalise each band such that the energy is one. */
152 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
153 {
154    int i, c, N;
155    const celt_int16 *eBands = m->eBands;
156    const int C = CHANNELS(_C);
157    N = M*m->shortMdctSize;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<end;i++)
161       {
162          int j;
163          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
164          for (j=M*eBands[i];j<M*eBands[i+1];j++)
165             X[j+c*N] = freq[j+c*N]*g;
166       }
167    }
168 }
169
170 #endif /* FIXED_POINT */
171
172 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
173 {
174    int i, c;
175    const celt_int16 *eBands = m->eBands;
176    const int C = CHANNELS(_C);
177    for (c=0;c<C;c++)
178    {
179       i=0; do {
180          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, M*eBands[i+1]-M*eBands[i], Q15ONE);
181       } while (++i<end);
182    }
183 }
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    for (c=0;c<C;c++)
194    {
195       celt_sig * restrict f;
196       const celt_norm * restrict x;
197       f = freq+c*N;
198       x = X+c*N;
199       for (i=0;i<end;i++)
200       {
201          int j, band_end;
202          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
203          j=M*eBands[i];
204          band_end = M*eBands[i+1];
205          do {
206             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
207             x++;
208          } while (++j<band_end);
209       }
210       for (i=M*eBands[m->nbEBands];i<N;i++)
211          *f++ = 0;
212    }
213 }
214
215 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
216 {
217    int i = bandID;
218    int j;
219    celt_word16 a1, a2;
220    celt_word16 left, right;
221    celt_word16 norm;
222 #ifdef FIXED_POINT
223    int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
224 #endif
225    left = VSHR32(bank[i],shift);
226    right = VSHR32(bank[i+m->nbEBands],shift);
227    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
228    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
229    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
230    for (j=0;j<N;j++)
231    {
232       celt_norm r, l;
233       l = X[j];
234       r = Y[j];
235       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
236       /* Side is not encoded, no need to calculate */
237    }
238 }
239
240 static void stereo_split(celt_norm *X, celt_norm *Y, int N)
241 {
242    int j;
243    for (j=0;j<N;j++)
244    {
245       celt_norm r, l;
246       l = MULT16_16_Q15(QCONST16(.70711f,15), X[j]);
247       r = MULT16_16_Q15(QCONST16(.70711f,15), Y[j]);
248       X[j] = l+r;
249       Y[j] = r-l;
250    }
251 }
252
253 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, celt_word16 side, int N)
254 {
255    int j;
256    celt_word32 xp=0;
257    celt_word32 El, Er;
258 #ifdef FIXED_POINT
259    int kl, kr;
260 #endif
261    celt_word32 t, lgain, rgain;
262
263    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
264    for (j=0;j<N;j++)
265       xp = MAC16_16(xp, X[j], Y[j]);
266    /* mid and side are in Q15, not Q14 like X and Y */
267    mid = SHR32(mid, 1);
268    side = SHR32(side, 1);
269    El = MULT16_16(mid, mid) + MULT16_16(side, side) - 2*xp;
270    Er = MULT16_16(mid, mid) + MULT16_16(side, side) + 2*xp;
271    if (Er < EPSILON)
272       Er = EPSILON;
273    if (El < EPSILON)
274       El = EPSILON;
275
276 #ifdef FIXED_POINT
277    kl = celt_ilog2(El)>>1;
278    kr = celt_ilog2(Er)>>1;
279 #endif
280    t = VSHR32(El, (kl-7)<<1);
281    lgain = celt_rsqrt_norm(t);
282    t = VSHR32(Er, (kr-7)<<1);
283    rgain = celt_rsqrt_norm(t);
284
285 #ifdef FIXED_POINT
286    if (kl < 7)
287       kl = 7;
288    if (kr < 7)
289       kr = 7;
290 #endif
291
292    for (j=0;j<N;j++)
293    {
294       celt_norm r, l;
295       l = X[j];
296       r = Y[j];
297       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, SUB16(l,r)), kl+1));
298       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, ADD16(l,r)), kr+1));
299    }
300 }
301
302 /* Decide whether we should spread the pulses in the current frame */
303 int folding_decision(const CELTMode *m, celt_norm *X, int *average, int *last_decision, int end, int _C, int M)
304 {
305    int i, c, N0;
306    int sum = 0, nbBands=0;
307    const int C = CHANNELS(_C);
308    const celt_int16 * restrict eBands = m->eBands;
309    int decision;
310    
311    N0 = M*m->shortMdctSize;
312
313    if (M*(eBands[end]-eBands[end-1]) <= 8)
314       return 0;
315    for (c=0;c<C;c++)
316    {
317       for (i=0;i<end;i++)
318       {
319          int j, N, tmp=0;
320          int tcount[3] = {0};
321          celt_norm * restrict x = X+M*eBands[i]+c*N0;
322          N = M*(eBands[i+1]-eBands[i]);
323          if (N<=8)
324             continue;
325          /* Compute rough CDF of |x[j]| */
326          for (j=0;j<N;j++)
327          {
328             celt_word32 x2N; /* Q13 */
329
330             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
331             if (x2N < QCONST16(0.25f,13))
332                tcount[0]++;
333             if (x2N < QCONST16(0.0625f,13))
334                tcount[1]++;
335             if (x2N < QCONST16(0.015625f,13))
336                tcount[2]++;
337          }
338
339          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
340          sum += tmp*256;
341          nbBands++;
342       }
343    }
344    sum /= nbBands;
345    /* Recursive averaging */
346    sum = (sum+*average)>>1;
347    *average = sum;
348    /* Hysteresis */
349    sum = (3*sum + ((*last_decision<<7) + 64) + 2)>>2;
350    /* decision and last_decision do not use the same encoding */
351    if (sum < 80)
352    {
353       decision = 2;
354       *last_decision = 0;
355    } else if (sum < 256)
356    {
357       decision = 1;
358       *last_decision = 1;
359    } else if (sum < 384)
360    {
361       decision = 3;
362       *last_decision = 2;
363    } else {
364       decision = 0;
365       *last_decision = 3;
366    }
367    return decision;
368 }
369
370 #ifdef MEASURE_NORM_MSE
371
372 float MSE[30] = {0};
373 int nbMSEBands = 0;
374 int MSECount[30] = {0};
375
376 void dump_norm_mse(void)
377 {
378    int i;
379    for (i=0;i<nbMSEBands;i++)
380    {
381       printf ("%g ", MSE[i]/MSECount[i]);
382    }
383    printf ("\n");
384 }
385
386 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
387 {
388    static int init = 0;
389    int i;
390    if (!init)
391    {
392       atexit(dump_norm_mse);
393       init = 1;
394    }
395    for (i=0;i<m->nbEBands;i++)
396    {
397       int j;
398       int c;
399       float g;
400       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
401          continue;
402       for (c=0;c<C;c++)
403       {
404          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
405          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
406             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
407       }
408       MSECount[i]+=C;
409    }
410    nbMSEBands = m->nbEBands;
411 }
412
413 #endif
414
415 static void interleave_vector(celt_norm *X, int N0, int stride)
416 {
417    int i,j;
418    VARDECL(celt_norm, tmp);
419    int N;
420    SAVE_STACK;
421    N = N0*stride;
422    ALLOC(tmp, N, celt_norm);
423    for (i=0;i<stride;i++)
424       for (j=0;j<N0;j++)
425          tmp[j*stride+i] = X[i*N0+j];
426    for (j=0;j<N;j++)
427       X[j] = tmp[j];
428    RESTORE_STACK;
429 }
430
431 static void deinterleave_vector(celt_norm *X, int N0, int stride)
432 {
433    int i,j;
434    VARDECL(celt_norm, tmp);
435    int N;
436    SAVE_STACK;
437    N = N0*stride;
438    ALLOC(tmp, N, celt_norm);
439    for (i=0;i<stride;i++)
440       for (j=0;j<N0;j++)
441          tmp[i*N0+j] = X[j*stride+i];
442    for (j=0;j<N;j++)
443       X[j] = tmp[j];
444    RESTORE_STACK;
445 }
446
447 void haar1(celt_norm *X, int N0, int stride)
448 {
449    int i, j;
450    N0 >>= 1;
451    for (i=0;i<stride;i++)
452       for (j=0;j<N0;j++)
453       {
454          celt_norm tmp1, tmp2;
455          tmp1 = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i]);
456          tmp2 = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*(2*j+1)+i]);
457          X[stride*2*j+i] = tmp1 + tmp2;
458          X[stride*(2*j+1)+i] = tmp1 - tmp2;
459       }
460 }
461
462 static int compute_qn(int N, int b, int offset, int stereo)
463 {
464    static const celt_int16 exp2_table8[8] =
465       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
466    int qn, qb;
467    int N2 = 2*N-1;
468    if (stereo && N==2)
469       N2--;
470    qb = (b+N2*offset)/N2;
471    if (qb > (b>>1)-(1<<BITRES))
472       qb = (b>>1)-(1<<BITRES);
473
474    if (qb<0)
475        qb = 0;
476    if (qb>8<<BITRES)
477      qb = 8<<BITRES;
478
479    if (qb<(1<<BITRES>>1)) {
480       qn = 1;
481    } else {
482       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
483       qn = (qn+1)>>1<<1;
484    }
485    celt_assert(qn <= 256);
486    return qn;
487 }
488
489 static celt_uint32 lcg_rand(celt_uint32 seed)
490 {
491    return 1664525 * seed + 1013904223;
492 }
493
494 /* This function is responsible for encoding and decoding a band for both
495    the mono and stereo case. Even in the mono case, it can split the band
496    in two and transmit the energy difference with the two half-bands. It
497    can be called recursively so bands can end up being split in 8 parts. */
498 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
499       int N, int b, int spread, int B, int tf_change, celt_norm *lowband, int resynth, void *ec,
500       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
501       celt_int32 *seed, celt_word16 gain, celt_norm *lowband_scratch)
502 {
503    int q;
504    int curr_bits;
505    int stereo, split;
506    int imid=0, iside=0;
507    int N0=N;
508    int N_B=N;
509    int N_B0;
510    int B0=B;
511    int time_divide=0;
512    int recombine=0;
513    celt_word16 mid=0, side=0;
514
515    N_B /= B;
516    N_B0 = N_B;
517
518    split = stereo = Y != NULL;
519
520    /* Special case for one sample */
521    if (N==1)
522    {
523       int c;
524       celt_norm *x = X;
525       for (c=0;c<1+stereo;c++)
526       {
527          int sign=0;
528          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
529          {
530             if (encode)
531             {
532                sign = x[0]<0;
533                ec_enc_bits((ec_enc*)ec, sign, 1);
534             } else {
535                sign = ec_dec_bits((ec_dec*)ec, 1);
536             }
537             *remaining_bits -= 1<<BITRES;
538             b-=1<<BITRES;
539          }
540          if (resynth)
541             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
542          x = Y;
543       }
544       if (lowband_out)
545          lowband_out[0] = X[0];
546       return;
547    }
548
549    if (!stereo && level == 0)
550    {
551       int k;
552       if (tf_change>0)
553          recombine = tf_change;
554       /* Band recombining to increase frequency resolution */
555
556       if (lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
557       {
558          int j;
559          for (j=0;j<N;j++)
560             lowband_scratch[j] = lowband[j];
561          lowband = lowband_scratch;
562       }
563
564       for (k=0;k<recombine;k++)
565       {
566          B>>=1;
567          N_B<<=1;
568          if (encode)
569             haar1(X, N_B, B);
570          if (lowband)
571             haar1(lowband, N_B, B);
572       }
573
574       /* Increasing the time resolution */
575       while ((N_B&1) == 0 && tf_change<0)
576       {
577          if (encode)
578             haar1(X, N_B, B);
579          if (lowband)
580             haar1(lowband, N_B, B);
581          B <<= 1;
582          N_B >>= 1;
583          time_divide++;
584          tf_change++;
585       }
586       B0=B;
587       N_B0 = N_B;
588
589       /* Reorganize the samples in time order instead of frequency order */
590       if (B0>1)
591       {
592          if (encode)
593             deinterleave_vector(X, N_B, B0);
594          if (lowband)
595             deinterleave_vector(lowband, N_B, B0);
596       }
597    }
598
599    /* If we need more than 32 bits, try splitting the band in two. */
600    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
601    {
602       if (LM>0 || (N&1)==0)
603       {
604          N >>= 1;
605          Y = X+N;
606          split = 1;
607          LM -= 1;
608          B = (B+1)>>1;
609       }
610    }
611
612    if (split)
613    {
614       int qn;
615       int itheta=0;
616       int mbits, sbits, delta;
617       int qalloc;
618       int offset;
619
620       /* Decide on the resolution to give to the split parameter theta */
621       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
622       qn = compute_qn(N, b, offset, stereo);
623
624       qalloc = 0;
625       if (qn!=1)
626       {
627          if (encode)
628          {
629             if (stereo)
630                stereo_split(X, Y, N);
631
632             mid = vector_norm(X, N);
633             side = vector_norm(Y, N);
634             /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
635                      Let's do that at higher complexity */
636             /*mid = renormalise_vector(X, Q15ONE, N, 1);
637             side = renormalise_vector(Y, Q15ONE, N, 1);*/
638
639             /* theta is the atan() of the ratio between the (normalized)
640                side and mid. With just that parameter, we can re-scale both
641                mid and side because we know that 1) they have unit norm and
642                2) they are orthogonal. */
643    #ifdef FIXED_POINT
644             /* 0.63662 = 2/pi */
645             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
646    #else
647             itheta = (int)floor(.5f+16384*0.63662f*atan2(side,mid));
648    #endif
649
650             itheta = (itheta*qn+8192)>>14;
651          }
652
653          /* Entropy coding of the angle. We use a uniform pdf for the
654             first stereo split but a triangular one for the rest. */
655          if (stereo || B>1)
656          {
657             if (encode)
658                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
659             else
660                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
661             qalloc = log2_frac(qn+1,BITRES);
662          } else {
663             int fs=1, ft;
664             ft = ((qn>>1)+1)*((qn>>1)+1);
665             if (encode)
666             {
667                int fl;
668
669                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
670                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
671                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
672
673                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
674             } else {
675                int fl=0;
676                int fm;
677                fm = ec_decode((ec_dec*)ec, ft);
678
679                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
680                {
681                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
682                   fs = itheta + 1;
683                   fl = itheta*(itheta + 1)>>1;
684                }
685                else
686                {
687                   itheta = (2*(qn + 1)
688                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
689                   fs = qn + 1 - itheta;
690                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
691                }
692
693                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
694             }
695             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
696          }
697          itheta = (celt_int32)itheta*16384/qn;
698       } else {
699          if (stereo && encode)
700             intensity_stereo(m, X, Y, bandE, i, N);
701       }
702
703       if (itheta == 0)
704       {
705          imid = 32767;
706          iside = 0;
707          delta = -10000;
708       } else if (itheta == 16384)
709       {
710          imid = 0;
711          iside = 32767;
712          delta = 10000;
713       } else {
714          imid = bitexact_cos(itheta);
715          iside = bitexact_cos(16384-itheta);
716          /* This is the mid vs side allocation that minimizes squared error
717             in that band. */
718          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
719       }
720
721 #ifdef FIXED_POINT
722       mid = imid;
723       side = iside;
724 #else
725       mid = (1.f/32768)*imid;
726       side = (1.f/32768)*iside;
727 #endif
728
729       /* This is a special case for N=2 that only works for stereo and takes
730          advantage of the fact that mid and side are orthogonal to encode
731          the side with just one bit. */
732       if (N==2 && stereo)
733       {
734          int c;
735          int sign=1;
736          celt_norm *x2, *y2;
737          mbits = b-qalloc;
738          sbits = 0;
739          /* Only need one bit for the side */
740          if (itheta != 0 && itheta != 16384)
741             sbits = 1<<BITRES;
742          mbits -= sbits;
743          c = itheta > 8192;
744          *remaining_bits -= qalloc+sbits;
745
746          x2 = c ? Y : X;
747          y2 = c ? X : Y;
748          if (sbits)
749          {
750             if (encode)
751             {
752                /* Here we only need to encode a sign for the side */
753                sign = x2[0]*y2[1] - x2[1]*y2[0] > 0;
754                ec_enc_bits((ec_enc*)ec, sign, 1);
755             } else {
756                sign = ec_dec_bits((ec_dec*)ec, 1);
757             }
758          }
759          sign = 2*sign - 1;
760          quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level, seed, gain, lowband_scratch);
761          y2[0] = -sign*x2[1];
762          y2[1] = sign*x2[0];
763          if (resynth)
764          {
765             celt_norm tmp;
766             X[0] = MULT16_16_Q15(mid, X[0]);
767             X[1] = MULT16_16_Q15(mid, X[1]);
768             Y[0] = MULT16_16_Q15(side, Y[0]);
769             Y[1] = MULT16_16_Q15(side, Y[1]);
770             tmp = X[0];
771             X[0] = SUB16(tmp,Y[0]);
772             Y[0] = ADD16(tmp,Y[0]);
773             tmp = X[1];
774             X[1] = SUB16(tmp,Y[1]);
775             Y[1] = ADD16(tmp,Y[1]);
776          }
777       } else {
778          /* "Normal" split code */
779          celt_norm *next_lowband2=NULL;
780          celt_norm *next_lowband_out1=NULL;
781          int next_level=0;
782
783          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
784          if (B>1 && !stereo && itheta > 8192)
785             delta -= delta>>(1+level);
786
787          mbits = (b-qalloc-delta)/2;
788          if (mbits > b-qalloc)
789             mbits = b-qalloc;
790          if (mbits<0)
791             mbits=0;
792          sbits = b-qalloc-mbits;
793          *remaining_bits -= qalloc;
794
795          if (lowband && !stereo)
796             next_lowband2 = lowband+N; /* >32-bit split case */
797
798          /* Only stereo needs to pass on lowband_out. Otherwise, it's
799             handled at the end */
800          if (stereo)
801             next_lowband_out1 = lowband_out;
802          else
803             next_level = level+1;
804
805          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, tf_change,
806                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
807                NULL, next_level, seed, MULT16_16_P15(gain,mid), lowband_scratch);
808          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, tf_change,
809                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
810                NULL, next_level, seed, MULT16_16_P15(gain,side), NULL);
811       }
812
813    } else {
814       /* This is the basic no-split case */
815       q = bits2pulses(m, i, LM, b);
816       curr_bits = pulses2bits(m, i, LM, q);
817       *remaining_bits -= curr_bits;
818
819       /* Ensures we can never bust the budget */
820       while (*remaining_bits < 0 && q > 0)
821       {
822          *remaining_bits += curr_bits;
823          q--;
824          curr_bits = pulses2bits(m, i, LM, q);
825          *remaining_bits -= curr_bits;
826       }
827
828       if (q!=0)
829       {
830          int K = get_pulses(q);
831
832          /* Finally do the actual quantization */
833          if (encode)
834             alg_quant(X, N, K, spread, B, lowband, resynth, (ec_enc*)ec, seed, gain);
835          else
836             alg_unquant(X, N, K, spread, B, lowband, (ec_dec*)ec, seed, gain);
837       } else {
838          int j;
839          if (lowband != NULL && resynth)
840          {
841             if (spread==2 && B<=1)
842             {
843                for (j=0;j<N;j++)
844                {
845                   *seed = lcg_rand(*seed);
846                   X[j] = (int)(*seed)>>20;
847                }
848             } else {
849                for (j=0;j<N;j++)
850                   X[j] = lowband[j];
851             }
852             renormalise_vector(X, N, gain);
853          } else {
854             /* This is important for encoding the side in stereo mode */
855             for (j=0;j<N;j++)
856                X[j] = 0;
857          }
858       }
859    }
860
861    /* This code is used by the decoder and by the resynthesis-enabled encoder */
862    if (resynth)
863    {
864       if (stereo)
865       {
866          if (N!=2)
867             stereo_merge(X, Y, mid, side, N);
868       } else if (level == 0)
869       {
870          int k;
871
872          /* Undo the sample reorganization going from time order to frequency order */
873          if (B0>1)
874             interleave_vector(X, N_B, B0);
875
876          /* Undo time-freq changes that we did earlier */
877          N_B = N_B0;
878          B = B0;
879          for (k=0;k<time_divide;k++)
880          {
881             B >>= 1;
882             N_B <<= 1;
883             haar1(X, N_B, B);
884          }
885
886          for (k=0;k<recombine;k++)
887          {
888             haar1(X, N_B, B);
889             N_B>>=1;
890             B <<= 1;
891          }
892
893          /* Scale output for later folding */
894          if (lowband_out)
895          {
896             int j;
897             celt_word16 n;
898             n = celt_sqrt(SHL32(EXTEND32(N0),22));
899             for (j=0;j<N0;j++)
900                lowband_out[j] = MULT16_16_Q15(n,X[j]);
901          }
902       }
903    }
904 }
905
906 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM, int codedBands)
907 {
908    int i, balance;
909    celt_int32 remaining_bits;
910    const celt_int16 * restrict eBands = m->eBands;
911    celt_norm * restrict norm;
912    VARDECL(celt_norm, _norm);
913    VARDECL(celt_norm, lowband_scratch);
914    int B;
915    int M;
916    celt_int32 seed;
917    celt_norm *lowband;
918    int update_lowband = 1;
919    int C = _Y != NULL ? 2 : 1;
920    SAVE_STACK;
921
922    M = 1<<LM;
923    B = shortBlocks ? M : 1;
924    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
925    ALLOC(lowband_scratch, M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]), celt_norm);
926    norm = _norm;
927
928    if (encode)
929       seed = ((ec_enc*)ec)->rng;
930    else
931       seed = ((ec_dec*)ec)->rng;
932    balance = 0;
933    lowband = NULL;
934    for (i=start;i<end;i++)
935    {
936       int tell;
937       int b;
938       int N;
939       int curr_balance;
940       celt_norm * restrict X, * restrict Y;
941       int tf_change=0;
942       
943       X = _X+M*eBands[i];
944       if (_Y!=NULL)
945          Y = _Y+M*eBands[i];
946       else
947          Y = NULL;
948       N = M*eBands[i+1]-M*eBands[i];
949       if (encode)
950          tell = ec_enc_tell((ec_enc*)ec, BITRES);
951       else
952          tell = ec_dec_tell((ec_dec*)ec, BITRES);
953
954       /* Compute how many bits we want to allocate to this band */
955       if (i != start)
956          balance -= tell;
957       remaining_bits = (total_bits<<BITRES)-tell-1;
958       if (i <= codedBands-1)
959       {
960          curr_balance = (codedBands-i);
961          if (curr_balance > 3)
962             curr_balance = 3;
963          curr_balance = balance / curr_balance;
964          b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
965          if (b<0)
966             b = 0;
967       } else {
968          b = 0;
969       }
970       /* Prevents ridiculous bit depths */
971       if (b > C*16*N<<BITRES)
972          b = C*16*N<<BITRES;
973
974       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
975             lowband = norm+M*eBands[i]-N;
976
977       tf_change = tf_res[i];
978       if (i>=m->effEBands)
979       {
980          X=norm;
981          if (_Y!=NULL)
982             Y = norm;
983       }
984
985       quant_band(encode, m, i, X, Y, N, b, fold, B, tf_change,
986             lowband, resynth, ec, &remaining_bits, LM,
987             norm+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch);
988
989       balance += pulses[i] + tell;
990
991       /* Update the folding position only as long as we have 2 bit/sample depth */
992       update_lowband = (b>>BITRES)>2*N;
993    }
994    RESTORE_STACK;
995 }
996