Re-organize spreading/folding constants.
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    c=0; do {
74       for (i=0;i<end;i++)
75       {
76          int j;
77          celt_word32 maxval=0;
78          celt_word32 sum = 0;
79          
80          j=M*eBands[i]; do {
81             maxval = MAX32(maxval, X[j+c*N]);
82             maxval = MAX32(maxval, -X[j+c*N]);
83          } while (++j<M*eBands[i+1]);
84          
85          if (maxval > 0)
86          {
87             int shift = celt_ilog2(maxval)-10;
88             j=M*eBands[i]; do {
89                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
90                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
91             } while (++j<M*eBands[i+1]);
92             /* We're adding one here to make damn sure we never end up with a pitch vector that's
93                larger than unity norm */
94             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
95          } else {
96             bank[i+c*m->nbEBands] = EPSILON;
97          }
98          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
99       }
100    } while (++c<C);
101    /*printf ("\n");*/
102 }
103
104 /* Normalise each band such that the energy is one. */
105 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
106 {
107    int i, c, N;
108    const celt_int16 *eBands = m->eBands;
109    const int C = CHANNELS(_C);
110    N = M*m->shortMdctSize;
111    c=0; do {
112       i=0; do {
113          celt_word16 g;
114          int j,shift;
115          celt_word16 E;
116          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
117          E = VSHR32(bank[i+c*m->nbEBands], shift);
118          g = EXTRACT16(celt_rcp(SHL32(E,3)));
119          j=M*eBands[i]; do {
120             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
121          } while (++j<M*eBands[i+1]);
122       } while (++i<end);
123    } while (++c<C);
124 }
125
126 #else /* FIXED_POINT */
127 /* Compute the amplitude (sqrt energy) in each of the bands */
128 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
129 {
130    int i, c, N;
131    const celt_int16 *eBands = m->eBands;
132    const int C = CHANNELS(_C);
133    N = M*m->shortMdctSize;
134    c=0; do {
135       for (i=0;i<end;i++)
136       {
137          int j;
138          celt_word32 sum = 1e-10f;
139          for (j=M*eBands[i];j<M*eBands[i+1];j++)
140             sum += X[j+c*N]*X[j+c*N];
141          bank[i+c*m->nbEBands] = celt_sqrt(sum);
142          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
143       }
144    } while (++c<C);
145    /*printf ("\n");*/
146 }
147
148 /* Normalise each band such that the energy is one. */
149 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
150 {
151    int i, c, N;
152    const celt_int16 *eBands = m->eBands;
153    const int C = CHANNELS(_C);
154    N = M*m->shortMdctSize;
155    c=0; do {
156       for (i=0;i<end;i++)
157       {
158          int j;
159          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
160          for (j=M*eBands[i];j<M*eBands[i+1];j++)
161             X[j+c*N] = freq[j+c*N]*g;
162       }
163    } while (++c<C);
164 }
165
166 #endif /* FIXED_POINT */
167
168 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
169 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
170 {
171    int i, c, N;
172    const celt_int16 *eBands = m->eBands;
173    const int C = CHANNELS(_C);
174    N = M*m->shortMdctSize;
175    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
176    c=0; do {
177       celt_sig * restrict f;
178       const celt_norm * restrict x;
179       f = freq+c*N;
180       x = X+c*N;
181       for (i=0;i<end;i++)
182       {
183          int j, band_end;
184          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
185          j=M*eBands[i];
186          band_end = M*eBands[i+1];
187          do {
188             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
189             x++;
190          } while (++j<band_end);
191       }
192       for (i=M*eBands[m->nbEBands];i<N;i++)
193          *f++ = 0;
194    } while (++c<C);
195 }
196
197 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
198 {
199    int i = bandID;
200    int j;
201    celt_word16 a1, a2;
202    celt_word16 left, right;
203    celt_word16 norm;
204 #ifdef FIXED_POINT
205    int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
206 #endif
207    left = VSHR32(bank[i],shift);
208    right = VSHR32(bank[i+m->nbEBands],shift);
209    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
210    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
211    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
212    for (j=0;j<N;j++)
213    {
214       celt_norm r, l;
215       l = X[j];
216       r = Y[j];
217       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
218       /* Side is not encoded, no need to calculate */
219    }
220 }
221
222 static void stereo_split(celt_norm *X, celt_norm *Y, int N)
223 {
224    int j;
225    for (j=0;j<N;j++)
226    {
227       celt_norm r, l;
228       l = MULT16_16_Q15(QCONST16(.70711f,15), X[j]);
229       r = MULT16_16_Q15(QCONST16(.70711f,15), Y[j]);
230       X[j] = l+r;
231       Y[j] = r-l;
232    }
233 }
234
235 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, celt_word16 side, int N)
236 {
237    int j;
238    celt_word32 xp=0;
239    celt_word32 El, Er;
240 #ifdef FIXED_POINT
241    int kl, kr;
242 #endif
243    celt_word32 t, lgain, rgain;
244
245    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
246    for (j=0;j<N;j++)
247       xp = MAC16_16(xp, X[j], Y[j]);
248    /* mid and side are in Q15, not Q14 like X and Y */
249    mid = SHR32(mid, 1);
250    side = SHR32(side, 1);
251    El = MULT16_16(mid, mid) + MULT16_16(side, side) - 2*xp;
252    Er = MULT16_16(mid, mid) + MULT16_16(side, side) + 2*xp;
253    if (Er < EPSILON)
254       Er = EPSILON;
255    if (El < EPSILON)
256       El = EPSILON;
257
258 #ifdef FIXED_POINT
259    kl = celt_ilog2(El)>>1;
260    kr = celt_ilog2(Er)>>1;
261 #endif
262    t = VSHR32(El, (kl-7)<<1);
263    lgain = celt_rsqrt_norm(t);
264    t = VSHR32(Er, (kr-7)<<1);
265    rgain = celt_rsqrt_norm(t);
266
267 #ifdef FIXED_POINT
268    if (kl < 7)
269       kl = 7;
270    if (kr < 7)
271       kr = 7;
272 #endif
273
274    for (j=0;j<N;j++)
275    {
276       celt_norm r, l;
277       l = X[j];
278       r = Y[j];
279       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, SUB16(l,r)), kl+1));
280       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, ADD16(l,r)), kr+1));
281    }
282 }
283
284 /* Decide whether we should spread the pulses in the current frame */
285 int spreading_decision(const CELTMode *m, celt_norm *X, int *average, int last_decision, int end, int _C, int M)
286 {
287    int i, c, N0;
288    int sum = 0, nbBands=0;
289    const int C = CHANNELS(_C);
290    const celt_int16 * restrict eBands = m->eBands;
291    int decision;
292    
293    N0 = M*m->shortMdctSize;
294
295    if (M*(eBands[end]-eBands[end-1]) <= 8)
296       return SPREAD_NONE;
297    c=0; do {
298       for (i=0;i<end;i++)
299       {
300          int j, N, tmp=0;
301          int tcount[3] = {0};
302          celt_norm * restrict x = X+M*eBands[i]+c*N0;
303          N = M*(eBands[i+1]-eBands[i]);
304          if (N<=8)
305             continue;
306          /* Compute rough CDF of |x[j]| */
307          for (j=0;j<N;j++)
308          {
309             celt_word32 x2N; /* Q13 */
310
311             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
312             if (x2N < QCONST16(0.25f,13))
313                tcount[0]++;
314             if (x2N < QCONST16(0.0625f,13))
315                tcount[1]++;
316             if (x2N < QCONST16(0.015625f,13))
317                tcount[2]++;
318          }
319
320          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
321          sum += tmp*256;
322          nbBands++;
323       }
324    } while (++c<C);
325    sum /= nbBands;
326    /* Recursive averaging */
327    sum = (sum+*average)>>1;
328    *average = sum;
329    /* Hysteresis */
330    sum = (3*sum + (((3-last_decision)<<7) + 64) + 2)>>2;
331    if (sum < 80)
332    {
333       decision = SPREAD_AGGRESSIVE;
334    } else if (sum < 256)
335    {
336       decision = SPREAD_NORMAL;
337    } else if (sum < 384)
338    {
339       decision = SPREAD_LIGHT;
340    } else {
341       decision = SPREAD_NONE;
342    }
343    return decision;
344 }
345
346 #ifdef MEASURE_NORM_MSE
347
348 float MSE[30] = {0};
349 int nbMSEBands = 0;
350 int MSECount[30] = {0};
351
352 void dump_norm_mse(void)
353 {
354    int i;
355    for (i=0;i<nbMSEBands;i++)
356    {
357       printf ("%g ", MSE[i]/MSECount[i]);
358    }
359    printf ("\n");
360 }
361
362 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
363 {
364    static int init = 0;
365    int i;
366    if (!init)
367    {
368       atexit(dump_norm_mse);
369       init = 1;
370    }
371    for (i=0;i<m->nbEBands;i++)
372    {
373       int j;
374       int c;
375       float g;
376       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
377          continue;
378       c=0; do {
379          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
380          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
381             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
382       } while (++c<C);
383       MSECount[i]+=C;
384    }
385    nbMSEBands = m->nbEBands;
386 }
387
388 #endif
389
390 static void interleave_vector(celt_norm *X, int N0, int stride)
391 {
392    int i,j;
393    VARDECL(celt_norm, tmp);
394    int N;
395    SAVE_STACK;
396    N = N0*stride;
397    ALLOC(tmp, N, celt_norm);
398    for (i=0;i<stride;i++)
399       for (j=0;j<N0;j++)
400          tmp[j*stride+i] = X[i*N0+j];
401    for (j=0;j<N;j++)
402       X[j] = tmp[j];
403    RESTORE_STACK;
404 }
405
406 static void deinterleave_vector(celt_norm *X, int N0, int stride)
407 {
408    int i,j;
409    VARDECL(celt_norm, tmp);
410    int N;
411    SAVE_STACK;
412    N = N0*stride;
413    ALLOC(tmp, N, celt_norm);
414    for (i=0;i<stride;i++)
415       for (j=0;j<N0;j++)
416          tmp[i*N0+j] = X[j*stride+i];
417    for (j=0;j<N;j++)
418       X[j] = tmp[j];
419    RESTORE_STACK;
420 }
421
422 void haar1(celt_norm *X, int N0, int stride)
423 {
424    int i, j;
425    N0 >>= 1;
426    for (i=0;i<stride;i++)
427       for (j=0;j<N0;j++)
428       {
429          celt_norm tmp1, tmp2;
430          tmp1 = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i]);
431          tmp2 = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*(2*j+1)+i]);
432          X[stride*2*j+i] = tmp1 + tmp2;
433          X[stride*(2*j+1)+i] = tmp1 - tmp2;
434       }
435 }
436
437 static int compute_qn(int N, int b, int offset, int stereo)
438 {
439    static const celt_int16 exp2_table8[8] =
440       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
441    int qn, qb;
442    int N2 = 2*N-1;
443    if (stereo && N==2)
444       N2--;
445    qb = (b+N2*offset)/N2;
446    if (qb > (b>>1)-(1<<BITRES))
447       qb = (b>>1)-(1<<BITRES);
448
449    if (qb<0)
450        qb = 0;
451    if (qb>8<<BITRES)
452      qb = 8<<BITRES;
453
454    if (qb<(1<<BITRES>>1)) {
455       qn = 1;
456    } else {
457       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
458       qn = (qn+1)>>1<<1;
459    }
460    celt_assert(qn <= 256);
461    return qn;
462 }
463
464 static celt_uint32 lcg_rand(celt_uint32 seed)
465 {
466    return 1664525 * seed + 1013904223;
467 }
468
469 /* This function is responsible for encoding and decoding a band for both
470    the mono and stereo case. Even in the mono case, it can split the band
471    in two and transmit the energy difference with the two half-bands. It
472    can be called recursively so bands can end up being split in 8 parts. */
473 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
474       int N, int b, int spread, int B, int intensity, int tf_change, celt_norm *lowband, int resynth, void *ec,
475       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
476       celt_int32 *seed, celt_word16 gain, celt_norm *lowband_scratch)
477 {
478    int q;
479    int curr_bits;
480    int stereo, split;
481    int imid=0, iside=0;
482    int N0=N;
483    int N_B=N;
484    int N_B0;
485    int B0=B;
486    int time_divide=0;
487    int recombine=0;
488    int inv = 0;
489    celt_word16 mid=0, side=0;
490
491    N_B /= B;
492    N_B0 = N_B;
493
494    split = stereo = Y != NULL;
495
496    /* Special case for one sample */
497    if (N==1)
498    {
499       int c;
500       celt_norm *x = X;
501       c=0; do {
502          int sign=0;
503          if (*remaining_bits>=1<<BITRES)
504          {
505             if (encode)
506             {
507                sign = x[0]<0;
508                ec_enc_bits((ec_enc*)ec, sign, 1);
509             } else {
510                sign = ec_dec_bits((ec_dec*)ec, 1);
511             }
512             *remaining_bits -= 1<<BITRES;
513             b-=1<<BITRES;
514          }
515          if (resynth)
516             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
517          x = Y;
518       } while (++c<1+stereo);
519       if (lowband_out)
520          lowband_out[0] = SHR16(X[0],4);
521       return;
522    }
523
524    if (!stereo && level == 0)
525    {
526       int k;
527       if (tf_change>0)
528          recombine = tf_change;
529       /* Band recombining to increase frequency resolution */
530
531       if (lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
532       {
533          int j;
534          for (j=0;j<N;j++)
535             lowband_scratch[j] = lowband[j];
536          lowband = lowband_scratch;
537       }
538
539       for (k=0;k<recombine;k++)
540       {
541          B>>=1;
542          N_B<<=1;
543          if (encode)
544             haar1(X, N_B, B);
545          if (lowband)
546             haar1(lowband, N_B, B);
547       }
548
549       /* Increasing the time resolution */
550       while ((N_B&1) == 0 && tf_change<0)
551       {
552          if (encode)
553             haar1(X, N_B, B);
554          if (lowband)
555             haar1(lowband, N_B, B);
556          B <<= 1;
557          N_B >>= 1;
558          time_divide++;
559          tf_change++;
560       }
561       B0=B;
562       N_B0 = N_B;
563
564       /* Reorganize the samples in time order instead of frequency order */
565       if (B0>1)
566       {
567          if (encode)
568             deinterleave_vector(X, N_B, B0);
569          if (lowband)
570             deinterleave_vector(lowband, N_B, B0);
571       }
572    }
573
574    /* If we need more than 32 bits, try splitting the band in two. */
575    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
576    {
577       if (LM>0 || (N&1)==0)
578       {
579          N >>= 1;
580          Y = X+N;
581          split = 1;
582          LM -= 1;
583          B = (B+1)>>1;
584       }
585    }
586
587    if (split)
588    {
589       int qn;
590       int itheta=0;
591       int mbits, sbits, delta;
592       int qalloc;
593       int offset;
594
595       /* Decide on the resolution to give to the split parameter theta */
596       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
597       qn = compute_qn(N, b, offset, stereo);
598       qalloc = 0;
599       if (stereo && i>=intensity)
600          qn = 1;
601       if (encode)
602       {
603          /* theta is the atan() of the ratio between the (normalized)
604             side and mid. With just that parameter, we can re-scale both
605             mid and side because we know that 1) they have unit norm and
606             2) they are orthogonal. */
607          itheta = stereo_itheta(X, Y, stereo, N);
608       }
609       if (qn!=1)
610       {
611          if (encode)
612             itheta = (itheta*qn+8192)>>14;
613
614          /* Entropy coding of the angle. We use a uniform pdf for the
615             first stereo split but a triangular one for the rest. */
616          if (stereo || B>1)
617          {
618             if (encode)
619                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
620             else
621                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
622             qalloc = log2_frac(qn+1,BITRES);
623          } else {
624             int fs=1, ft;
625             ft = ((qn>>1)+1)*((qn>>1)+1);
626             if (encode)
627             {
628                int fl;
629
630                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
631                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
632                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
633
634                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
635             } else {
636                int fl=0;
637                int fm;
638                fm = ec_decode((ec_dec*)ec, ft);
639
640                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
641                {
642                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
643                   fs = itheta + 1;
644                   fl = itheta*(itheta + 1)>>1;
645                }
646                else
647                {
648                   itheta = (2*(qn + 1)
649                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
650                   fs = qn + 1 - itheta;
651                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
652                }
653
654                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
655             }
656             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
657          }
658          itheta = (celt_int32)itheta*16384/qn;
659          if (encode && stereo)
660             stereo_split(X, Y, N);
661          /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
662                   Let's do that at higher complexity */
663       } else if (stereo) {
664          if (encode)
665          {
666             inv = itheta > 8192;
667             if (inv)
668             {
669                int j;
670                for (j=0;j<N;j++)
671                   Y[j] = -Y[j];
672             }
673             intensity_stereo(m, X, Y, bandE, i, N);
674          }
675          if (b>2<<BITRES && *remaining_bits > 2<<BITRES)
676          {
677             if (encode)
678                ec_enc_bit_prob(ec, inv, 16384);
679             else
680                inv = ec_dec_bit_prob(ec, 16384);
681             qalloc = inv ? 16 : 4;
682          } else
683             inv = 0;
684          itheta = 0;
685       }
686
687       if (itheta == 0)
688       {
689          imid = 32767;
690          iside = 0;
691          delta = -10000;
692       } else if (itheta == 16384)
693       {
694          imid = 0;
695          iside = 32767;
696          delta = 10000;
697       } else {
698          imid = bitexact_cos(itheta);
699          iside = bitexact_cos(16384-itheta);
700          /* This is the mid vs side allocation that minimizes squared error
701             in that band. */
702          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
703       }
704
705 #ifdef FIXED_POINT
706       mid = imid;
707       side = iside;
708 #else
709       mid = (1.f/32768)*imid;
710       side = (1.f/32768)*iside;
711 #endif
712
713       /* This is a special case for N=2 that only works for stereo and takes
714          advantage of the fact that mid and side are orthogonal to encode
715          the side with just one bit. */
716       if (N==2 && stereo)
717       {
718          int c;
719          int sign=1;
720          celt_norm *x2, *y2;
721          mbits = b-qalloc;
722          sbits = 0;
723          /* Only need one bit for the side */
724          if (itheta != 0 && itheta != 16384)
725             sbits = 1<<BITRES;
726          mbits -= sbits;
727          c = itheta > 8192;
728          *remaining_bits -= qalloc+sbits;
729
730          x2 = c ? Y : X;
731          y2 = c ? X : Y;
732          if (sbits)
733          {
734             if (encode)
735             {
736                /* Here we only need to encode a sign for the side */
737                sign = x2[0]*y2[1] - x2[1]*y2[0] > 0;
738                ec_enc_bits((ec_enc*)ec, sign, 1);
739             } else {
740                sign = ec_dec_bits((ec_dec*)ec, 1);
741             }
742          }
743          sign = 2*sign - 1;
744          quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, intensity, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level, seed, gain, lowband_scratch);
745          y2[0] = -sign*x2[1];
746          y2[1] = sign*x2[0];
747          if (resynth)
748          {
749             celt_norm tmp;
750             X[0] = MULT16_16_Q15(mid, X[0]);
751             X[1] = MULT16_16_Q15(mid, X[1]);
752             Y[0] = MULT16_16_Q15(side, Y[0]);
753             Y[1] = MULT16_16_Q15(side, Y[1]);
754             tmp = X[0];
755             X[0] = SUB16(tmp,Y[0]);
756             Y[0] = ADD16(tmp,Y[0]);
757             tmp = X[1];
758             X[1] = SUB16(tmp,Y[1]);
759             Y[1] = ADD16(tmp,Y[1]);
760          }
761       } else {
762          /* "Normal" split code */
763          celt_norm *next_lowband2=NULL;
764          celt_norm *next_lowband_out1=NULL;
765          int next_level=0;
766
767          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
768          if (B>1 && !stereo && itheta > 8192)
769             delta -= delta>>(1+level);
770
771          mbits = (b-qalloc-delta)/2;
772          if (mbits > b-qalloc)
773             mbits = b-qalloc;
774          if (mbits<0)
775             mbits=0;
776          sbits = b-qalloc-mbits;
777          *remaining_bits -= qalloc;
778
779          if (lowband && !stereo)
780             next_lowband2 = lowband+N; /* >32-bit split case */
781
782          /* Only stereo needs to pass on lowband_out. Otherwise, it's
783             handled at the end */
784          if (stereo)
785             next_lowband_out1 = lowband_out;
786          else
787             next_level = level+1;
788
789          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, intensity, tf_change,
790                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
791                NULL, next_level, seed, MULT16_16_P15(gain,mid), lowband_scratch);
792          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, intensity, tf_change,
793                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
794                NULL, next_level, seed, MULT16_16_P15(gain,side), NULL);
795       }
796
797    } else {
798       /* This is the basic no-split case */
799       q = bits2pulses(m, i, LM, b);
800       curr_bits = pulses2bits(m, i, LM, q);
801       *remaining_bits -= curr_bits;
802
803       /* Ensures we can never bust the budget */
804       while (*remaining_bits < 0 && q > 0)
805       {
806          *remaining_bits += curr_bits;
807          q--;
808          curr_bits = pulses2bits(m, i, LM, q);
809          *remaining_bits -= curr_bits;
810       }
811
812       if (q!=0)
813       {
814          int K = get_pulses(q);
815
816          /* Finally do the actual quantization */
817          if (encode)
818             alg_quant(X, N, K, spread, B, lowband, resynth, (ec_enc*)ec, seed, gain);
819          else
820             alg_unquant(X, N, K, spread, B, lowband, (ec_dec*)ec, seed, gain);
821       } else {
822          /* If there's no pulse, fill the band anyway */
823          int j;
824          if (lowband != NULL && resynth)
825          {
826             if (spread==SPREAD_AGGRESSIVE && B<=1)
827             {
828                /* Noise */
829                for (j=0;j<N;j++)
830                {
831                   *seed = lcg_rand(*seed);
832                   X[j] = (int)(*seed)>>20;
833                }
834             } else {
835                /* Folded spectrum */
836                for (j=0;j<N;j++)
837                   X[j] = lowband[j];
838             }
839             renormalise_vector(X, N, gain);
840          } else {
841             /* This is important for encoding the side in stereo mode */
842             for (j=0;j<N;j++)
843                X[j] = 0;
844          }
845       }
846    }
847
848    /* This code is used by the decoder and by the resynthesis-enabled encoder */
849    if (resynth)
850    {
851       if (stereo)
852       {
853          if (N!=2)
854             stereo_merge(X, Y, mid, side, N);
855          if (inv)
856          {
857             int j;
858             for (j=0;j<N;j++)
859                Y[j] = -Y[j];
860          }
861       } else if (level == 0)
862       {
863          int k;
864
865          /* Undo the sample reorganization going from time order to frequency order */
866          if (B0>1)
867             interleave_vector(X, N_B, B0);
868
869          /* Undo time-freq changes that we did earlier */
870          N_B = N_B0;
871          B = B0;
872          for (k=0;k<time_divide;k++)
873          {
874             B >>= 1;
875             N_B <<= 1;
876             haar1(X, N_B, B);
877          }
878
879          for (k=0;k<recombine;k++)
880          {
881             haar1(X, N_B, B);
882             N_B>>=1;
883             B <<= 1;
884          }
885
886          /* Scale output for later folding */
887          if (lowband_out)
888          {
889             int j;
890             celt_word16 n;
891             n = celt_sqrt(SHL32(EXTEND32(N0),22));
892             for (j=0;j<N0;j++)
893                lowband_out[j] = MULT16_16_Q15(n,X[j]);
894          }
895       }
896    }
897 }
898
899 void quant_all_bands(int encode, const CELTMode *m, int start, int end,
900       celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses,
901       int shortBlocks, int spread, int dual_stereo, int intensity, int *tf_res, int resynth,
902       int total_bits, void *ec, int LM, int codedBands)
903 {
904    int i, balance;
905    celt_int32 remaining_bits;
906    const celt_int16 * restrict eBands = m->eBands;
907    celt_norm * restrict norm, * restrict norm2;
908    VARDECL(celt_norm, _norm);
909    VARDECL(celt_norm, lowband_scratch);
910    int B;
911    int M;
912    celt_int32 seed;
913    int lowband_offset;
914    int update_lowband = 1;
915    int C = _Y != NULL ? 2 : 1;
916    SAVE_STACK;
917
918    M = 1<<LM;
919    B = shortBlocks ? M : 1;
920    ALLOC(_norm, C*M*eBands[m->nbEBands], celt_norm);
921    ALLOC(lowband_scratch, M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]), celt_norm);
922    norm = _norm;
923    norm2 = norm + M*eBands[m->nbEBands];
924 #if 0
925    if (C==2)
926    {
927       int j;
928       int left = 0;
929       for (j=intensity;j<codedBands;j++)
930       {
931          int tmp = pulses[j]/2;
932          left += tmp;
933          pulses[j] -= tmp;
934       }
935       if (codedBands) {
936          int perband;
937          perband = left/(m->eBands[codedBands]-m->eBands[start]);
938          for (j=start;j<codedBands;j++)
939             pulses[j] += perband*(m->eBands[j+1]-m->eBands[j]);
940          left = left-(m->eBands[codedBands]-m->eBands[start])*perband;
941          for (j=start;j<codedBands;j++)
942          {
943             int tmp = IMIN(left, m->eBands[j+1]-m->eBands[j]);
944             pulses[j] += tmp;
945             left -= tmp;
946          }
947       }
948    }
949 #endif
950    if (encode)
951       seed = ((ec_enc*)ec)->rng;
952    else
953       seed = ((ec_dec*)ec)->rng;
954    balance = 0;
955    lowband_offset = -1;
956    for (i=start;i<end;i++)
957    {
958       int tell;
959       int b;
960       int N;
961       int curr_balance;
962       int effective_lowband=-1;
963       celt_norm * restrict X, * restrict Y;
964       int tf_change=0;
965       
966       X = _X+M*eBands[i];
967       if (_Y!=NULL)
968          Y = _Y+M*eBands[i];
969       else
970          Y = NULL;
971       N = M*eBands[i+1]-M*eBands[i];
972       if (encode)
973          tell = ec_enc_tell((ec_enc*)ec, BITRES);
974       else
975          tell = ec_dec_tell((ec_dec*)ec, BITRES);
976
977       /* Compute how many bits we want to allocate to this band */
978       if (i != start)
979          balance -= tell;
980       remaining_bits = (total_bits<<BITRES)-tell-1;
981       if (i <= codedBands-1)
982       {
983          curr_balance = (codedBands-i);
984          if (curr_balance > 3)
985             curr_balance = 3;
986          curr_balance = balance / curr_balance;
987          b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
988          if (b<0)
989             b = 0;
990       } else {
991          b = 0;
992       }
993       /* Prevents ridiculous bit depths */
994       if (b > C*16*N<<BITRES)
995          b = C*16*N<<BITRES;
996
997       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband_offset==-1))
998             lowband_offset = M*eBands[i];
999
1000       tf_change = tf_res[i];
1001       if (i>=m->effEBands)
1002       {
1003          X=norm;
1004          if (_Y!=NULL)
1005             Y = norm;
1006       }
1007
1008       /* This ensures we never repeat spectral content within one band */
1009       if (lowband_offset != -1)
1010       {
1011          effective_lowband = lowband_offset-N;
1012          if (effective_lowband < M*eBands[start])
1013             effective_lowband = M*eBands[start];
1014       }
1015       if (dual_stereo && i==intensity)
1016       {
1017          int j;
1018
1019          /* Switch off dual stereo to do intensity */
1020          dual_stereo = 0;
1021          for (j=0;j<M*eBands[i];j++)
1022             norm[j] = HALF32(norm[j]+norm2[j]);
1023       }
1024       if (dual_stereo)
1025       {
1026          quant_band(encode, m, i, X, NULL, N, b/2, spread, B, intensity, tf_change,
1027                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1028                norm+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch);
1029          quant_band(encode, m, i, Y, NULL, N, b/2, spread, B, intensity, tf_change,
1030                effective_lowband != -1 ? norm2+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1031                norm2+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch);
1032       } else {
1033          quant_band(encode, m, i, X, Y, N, b, spread, B, intensity, tf_change,
1034                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1035                norm+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch);
1036       }
1037       balance += pulses[i] + tell;
1038
1039       /* Update the folding position only as long as we have 1 bit/sample depth */
1040       update_lowband = (b>>BITRES)>N;
1041    }
1042    RESTORE_STACK;
1043 }
1044