Switch iteration over channels to the do{}while(); construct in order to inform the...
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    c=0; do {
74       for (i=0;i<end;i++)
75       {
76          int j;
77          celt_word32 maxval=0;
78          celt_word32 sum = 0;
79          
80          j=M*eBands[i]; do {
81             maxval = MAX32(maxval, X[j+c*N]);
82             maxval = MAX32(maxval, -X[j+c*N]);
83          } while (++j<M*eBands[i+1]);
84          
85          if (maxval > 0)
86          {
87             int shift = celt_ilog2(maxval)-10;
88             j=M*eBands[i]; do {
89                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
90                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
91             } while (++j<M*eBands[i+1]);
92             /* We're adding one here to make damn sure we never end up with a pitch vector that's
93                larger than unity norm */
94             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
95          } else {
96             bank[i+c*m->nbEBands] = EPSILON;
97          }
98          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
99       }
100    } while (++c<C);
101    /*printf ("\n");*/
102 }
103
104 /* Normalise each band such that the energy is one. */
105 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
106 {
107    int i, c, N;
108    const celt_int16 *eBands = m->eBands;
109    const int C = CHANNELS(_C);
110    N = M*m->shortMdctSize;
111    c=0; do {
112       i=0; do {
113          celt_word16 g;
114          int j,shift;
115          celt_word16 E;
116          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
117          E = VSHR32(bank[i+c*m->nbEBands], shift);
118          g = EXTRACT16(celt_rcp(SHL32(E,3)));
119          j=M*eBands[i]; do {
120             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
121          } while (++j<M*eBands[i+1]);
122       } while (++i<end);
123    } while (++c<C);
124 }
125
126 #else /* FIXED_POINT */
127 /* Compute the amplitude (sqrt energy) in each of the bands */
128 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
129 {
130    int i, c, N;
131    const celt_int16 *eBands = m->eBands;
132    const int C = CHANNELS(_C);
133    N = M*m->shortMdctSize;
134    c=0; do {
135       for (i=0;i<end;i++)
136       {
137          int j;
138          celt_word32 sum = 1e-10f;
139          for (j=M*eBands[i];j<M*eBands[i+1];j++)
140             sum += X[j+c*N]*X[j+c*N];
141          bank[i+c*m->nbEBands] = celt_sqrt(sum);
142          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
143       }
144    } while (++c<C);
145    /*printf ("\n");*/
146 }
147
148 /* Normalise each band such that the energy is one. */
149 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
150 {
151    int i, c, N;
152    const celt_int16 *eBands = m->eBands;
153    const int C = CHANNELS(_C);
154    N = M*m->shortMdctSize;
155    c=0; do {
156       for (i=0;i<end;i++)
157       {
158          int j;
159          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
160          for (j=M*eBands[i];j<M*eBands[i+1];j++)
161             X[j+c*N] = freq[j+c*N]*g;
162       }
163    } while (++c<C);
164 }
165
166 #endif /* FIXED_POINT */
167
168 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
169 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
170 {
171    int i, c, N;
172    const celt_int16 *eBands = m->eBands;
173    const int C = CHANNELS(_C);
174    N = M*m->shortMdctSize;
175    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
176    c=0; do {
177       celt_sig * restrict f;
178       const celt_norm * restrict x;
179       f = freq+c*N;
180       x = X+c*N;
181       for (i=0;i<end;i++)
182       {
183          int j, band_end;
184          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
185          j=M*eBands[i];
186          band_end = M*eBands[i+1];
187          do {
188             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
189             x++;
190          } while (++j<band_end);
191       }
192       for (i=M*eBands[m->nbEBands];i<N;i++)
193          *f++ = 0;
194    } while (++c<C);
195 }
196
197 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
198 {
199    int i = bandID;
200    int j;
201    celt_word16 a1, a2;
202    celt_word16 left, right;
203    celt_word16 norm;
204 #ifdef FIXED_POINT
205    int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
206 #endif
207    left = VSHR32(bank[i],shift);
208    right = VSHR32(bank[i+m->nbEBands],shift);
209    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
210    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
211    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
212    for (j=0;j<N;j++)
213    {
214       celt_norm r, l;
215       l = X[j];
216       r = Y[j];
217       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
218       /* Side is not encoded, no need to calculate */
219    }
220 }
221
222 static void stereo_split(celt_norm *X, celt_norm *Y, int N)
223 {
224    int j;
225    for (j=0;j<N;j++)
226    {
227       celt_norm r, l;
228       l = MULT16_16_Q15(QCONST16(.70711f,15), X[j]);
229       r = MULT16_16_Q15(QCONST16(.70711f,15), Y[j]);
230       X[j] = l+r;
231       Y[j] = r-l;
232    }
233 }
234
235 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, celt_word16 side, int N)
236 {
237    int j;
238    celt_word32 xp=0;
239    celt_word32 El, Er;
240 #ifdef FIXED_POINT
241    int kl, kr;
242 #endif
243    celt_word32 t, lgain, rgain;
244
245    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
246    for (j=0;j<N;j++)
247       xp = MAC16_16(xp, X[j], Y[j]);
248    /* mid and side are in Q15, not Q14 like X and Y */
249    mid = SHR32(mid, 1);
250    side = SHR32(side, 1);
251    El = MULT16_16(mid, mid) + MULT16_16(side, side) - 2*xp;
252    Er = MULT16_16(mid, mid) + MULT16_16(side, side) + 2*xp;
253    if (Er < EPSILON)
254       Er = EPSILON;
255    if (El < EPSILON)
256       El = EPSILON;
257
258 #ifdef FIXED_POINT
259    kl = celt_ilog2(El)>>1;
260    kr = celt_ilog2(Er)>>1;
261 #endif
262    t = VSHR32(El, (kl-7)<<1);
263    lgain = celt_rsqrt_norm(t);
264    t = VSHR32(Er, (kr-7)<<1);
265    rgain = celt_rsqrt_norm(t);
266
267 #ifdef FIXED_POINT
268    if (kl < 7)
269       kl = 7;
270    if (kr < 7)
271       kr = 7;
272 #endif
273
274    for (j=0;j<N;j++)
275    {
276       celt_norm r, l;
277       l = X[j];
278       r = Y[j];
279       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, SUB16(l,r)), kl+1));
280       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, ADD16(l,r)), kr+1));
281    }
282 }
283
284 /* Decide whether we should spread the pulses in the current frame */
285 int folding_decision(const CELTMode *m, celt_norm *X, int *average, int *last_decision, int end, int _C, int M)
286 {
287    int i, c, N0;
288    int sum = 0, nbBands=0;
289    const int C = CHANNELS(_C);
290    const celt_int16 * restrict eBands = m->eBands;
291    int decision;
292    
293    N0 = M*m->shortMdctSize;
294
295    if (M*(eBands[end]-eBands[end-1]) <= 8)
296       return 0;
297    c=0; do {
298       for (i=0;i<end;i++)
299       {
300          int j, N, tmp=0;
301          int tcount[3] = {0};
302          celt_norm * restrict x = X+M*eBands[i]+c*N0;
303          N = M*(eBands[i+1]-eBands[i]);
304          if (N<=8)
305             continue;
306          /* Compute rough CDF of |x[j]| */
307          for (j=0;j<N;j++)
308          {
309             celt_word32 x2N; /* Q13 */
310
311             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
312             if (x2N < QCONST16(0.25f,13))
313                tcount[0]++;
314             if (x2N < QCONST16(0.0625f,13))
315                tcount[1]++;
316             if (x2N < QCONST16(0.015625f,13))
317                tcount[2]++;
318          }
319
320          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
321          sum += tmp*256;
322          nbBands++;
323       }
324    } while (++c<C);
325    sum /= nbBands;
326    /* Recursive averaging */
327    sum = (sum+*average)>>1;
328    *average = sum;
329    /* Hysteresis */
330    sum = (3*sum + ((*last_decision<<7) + 64) + 2)>>2;
331    /* decision and last_decision do not use the same encoding */
332    if (sum < 80)
333    {
334       decision = 2;
335       *last_decision = 0;
336    } else if (sum < 256)
337    {
338       decision = 1;
339       *last_decision = 1;
340    } else if (sum < 384)
341    {
342       decision = 3;
343       *last_decision = 2;
344    } else {
345       decision = 0;
346       *last_decision = 3;
347    }
348    return decision;
349 }
350
351 #ifdef MEASURE_NORM_MSE
352
353 float MSE[30] = {0};
354 int nbMSEBands = 0;
355 int MSECount[30] = {0};
356
357 void dump_norm_mse(void)
358 {
359    int i;
360    for (i=0;i<nbMSEBands;i++)
361    {
362       printf ("%g ", MSE[i]/MSECount[i]);
363    }
364    printf ("\n");
365 }
366
367 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
368 {
369    static int init = 0;
370    int i;
371    if (!init)
372    {
373       atexit(dump_norm_mse);
374       init = 1;
375    }
376    for (i=0;i<m->nbEBands;i++)
377    {
378       int j;
379       int c;
380       float g;
381       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
382          continue;
383       c=0; do {
384          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
385          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
386             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
387       } while (++c<C);
388       MSECount[i]+=C;
389    }
390    nbMSEBands = m->nbEBands;
391 }
392
393 #endif
394
395 static void interleave_vector(celt_norm *X, int N0, int stride)
396 {
397    int i,j;
398    VARDECL(celt_norm, tmp);
399    int N;
400    SAVE_STACK;
401    N = N0*stride;
402    ALLOC(tmp, N, celt_norm);
403    for (i=0;i<stride;i++)
404       for (j=0;j<N0;j++)
405          tmp[j*stride+i] = X[i*N0+j];
406    for (j=0;j<N;j++)
407       X[j] = tmp[j];
408    RESTORE_STACK;
409 }
410
411 static void deinterleave_vector(celt_norm *X, int N0, int stride)
412 {
413    int i,j;
414    VARDECL(celt_norm, tmp);
415    int N;
416    SAVE_STACK;
417    N = N0*stride;
418    ALLOC(tmp, N, celt_norm);
419    for (i=0;i<stride;i++)
420       for (j=0;j<N0;j++)
421          tmp[i*N0+j] = X[j*stride+i];
422    for (j=0;j<N;j++)
423       X[j] = tmp[j];
424    RESTORE_STACK;
425 }
426
427 void haar1(celt_norm *X, int N0, int stride)
428 {
429    int i, j;
430    N0 >>= 1;
431    for (i=0;i<stride;i++)
432       for (j=0;j<N0;j++)
433       {
434          celt_norm tmp1, tmp2;
435          tmp1 = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i]);
436          tmp2 = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*(2*j+1)+i]);
437          X[stride*2*j+i] = tmp1 + tmp2;
438          X[stride*(2*j+1)+i] = tmp1 - tmp2;
439       }
440 }
441
442 static int compute_qn(int N, int b, int offset, int stereo)
443 {
444    static const celt_int16 exp2_table8[8] =
445       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
446    int qn, qb;
447    int N2 = 2*N-1;
448    if (stereo && N==2)
449       N2--;
450    qb = (b+N2*offset)/N2;
451    if (qb > (b>>1)-(1<<BITRES))
452       qb = (b>>1)-(1<<BITRES);
453
454    if (qb<0)
455        qb = 0;
456    if (qb>8<<BITRES)
457      qb = 8<<BITRES;
458
459    if (qb<(1<<BITRES>>1)) {
460       qn = 1;
461    } else {
462       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
463       qn = (qn+1)>>1<<1;
464    }
465    celt_assert(qn <= 256);
466    return qn;
467 }
468
469 static celt_uint32 lcg_rand(celt_uint32 seed)
470 {
471    return 1664525 * seed + 1013904223;
472 }
473
474 /* This function is responsible for encoding and decoding a band for both
475    the mono and stereo case. Even in the mono case, it can split the band
476    in two and transmit the energy difference with the two half-bands. It
477    can be called recursively so bands can end up being split in 8 parts. */
478 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
479       int N, int b, int spread, int B, int tf_change, celt_norm *lowband, int resynth, void *ec,
480       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
481       celt_int32 *seed, celt_word16 gain, celt_norm *lowband_scratch)
482 {
483    int q;
484    int curr_bits;
485    int stereo, split;
486    int imid=0, iside=0;
487    int N0=N;
488    int N_B=N;
489    int N_B0;
490    int B0=B;
491    int time_divide=0;
492    int recombine=0;
493    celt_word16 mid=0, side=0;
494
495    N_B /= B;
496    N_B0 = N_B;
497
498    split = stereo = Y != NULL;
499
500    /* Special case for one sample */
501    if (N==1)
502    {
503       int c;
504       celt_norm *x = X;
505       c=0; do {
506          int sign=0;
507          if (*remaining_bits>=1<<BITRES)
508          {
509             if (encode)
510             {
511                sign = x[0]<0;
512                ec_enc_bits((ec_enc*)ec, sign, 1);
513             } else {
514                sign = ec_dec_bits((ec_dec*)ec, 1);
515             }
516             *remaining_bits -= 1<<BITRES;
517             b-=1<<BITRES;
518          }
519          if (resynth)
520             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
521          x = Y;
522       } while (++c<1+stereo);
523       if (lowband_out)
524          lowband_out[0] = SHR16(X[0],4);
525       return;
526    }
527
528    if (!stereo && level == 0)
529    {
530       int k;
531       if (tf_change>0)
532          recombine = tf_change;
533       /* Band recombining to increase frequency resolution */
534
535       if (lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
536       {
537          int j;
538          for (j=0;j<N;j++)
539             lowband_scratch[j] = lowband[j];
540          lowband = lowband_scratch;
541       }
542
543       for (k=0;k<recombine;k++)
544       {
545          B>>=1;
546          N_B<<=1;
547          if (encode)
548             haar1(X, N_B, B);
549          if (lowband)
550             haar1(lowband, N_B, B);
551       }
552
553       /* Increasing the time resolution */
554       while ((N_B&1) == 0 && tf_change<0)
555       {
556          if (encode)
557             haar1(X, N_B, B);
558          if (lowband)
559             haar1(lowband, N_B, B);
560          B <<= 1;
561          N_B >>= 1;
562          time_divide++;
563          tf_change++;
564       }
565       B0=B;
566       N_B0 = N_B;
567
568       /* Reorganize the samples in time order instead of frequency order */
569       if (B0>1)
570       {
571          if (encode)
572             deinterleave_vector(X, N_B, B0);
573          if (lowband)
574             deinterleave_vector(lowband, N_B, B0);
575       }
576    }
577
578    /* If we need more than 32 bits, try splitting the band in two. */
579    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
580    {
581       if (LM>0 || (N&1)==0)
582       {
583          N >>= 1;
584          Y = X+N;
585          split = 1;
586          LM -= 1;
587          B = (B+1)>>1;
588       }
589    }
590
591    if (split)
592    {
593       int qn;
594       int itheta=0;
595       int mbits, sbits, delta;
596       int qalloc;
597       int offset;
598
599       /* Decide on the resolution to give to the split parameter theta */
600       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
601       qn = compute_qn(N, b, offset, stereo);
602
603       qalloc = 0;
604       if (qn!=1)
605       {
606          if (encode)
607          {
608             if (stereo)
609                stereo_split(X, Y, N);
610
611             mid = vector_norm(X, N);
612             side = vector_norm(Y, N);
613             /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
614                      Let's do that at higher complexity */
615             /*mid = renormalise_vector(X, Q15ONE, N, 1);
616             side = renormalise_vector(Y, Q15ONE, N, 1);*/
617
618             /* theta is the atan() of the ratio between the (normalized)
619                side and mid. With just that parameter, we can re-scale both
620                mid and side because we know that 1) they have unit norm and
621                2) they are orthogonal. */
622    #ifdef FIXED_POINT
623             /* 0.63662 = 2/pi */
624             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
625    #else
626             itheta = (int)floor(.5f+16384*0.63662f*atan2(side,mid));
627    #endif
628
629             itheta = (itheta*qn+8192)>>14;
630          }
631
632          /* Entropy coding of the angle. We use a uniform pdf for the
633             first stereo split but a triangular one for the rest. */
634          if (stereo || B>1)
635          {
636             if (encode)
637                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
638             else
639                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
640             qalloc = log2_frac(qn+1,BITRES);
641          } else {
642             int fs=1, ft;
643             ft = ((qn>>1)+1)*((qn>>1)+1);
644             if (encode)
645             {
646                int fl;
647
648                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
649                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
650                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
651
652                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
653             } else {
654                int fl=0;
655                int fm;
656                fm = ec_decode((ec_dec*)ec, ft);
657
658                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
659                {
660                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
661                   fs = itheta + 1;
662                   fl = itheta*(itheta + 1)>>1;
663                }
664                else
665                {
666                   itheta = (2*(qn + 1)
667                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
668                   fs = qn + 1 - itheta;
669                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
670                }
671
672                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
673             }
674             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
675          }
676          itheta = (celt_int32)itheta*16384/qn;
677       } else {
678          if (stereo && encode)
679             intensity_stereo(m, X, Y, bandE, i, N);
680       }
681
682       if (itheta == 0)
683       {
684          imid = 32767;
685          iside = 0;
686          delta = -10000;
687       } else if (itheta == 16384)
688       {
689          imid = 0;
690          iside = 32767;
691          delta = 10000;
692       } else {
693          imid = bitexact_cos(itheta);
694          iside = bitexact_cos(16384-itheta);
695          /* This is the mid vs side allocation that minimizes squared error
696             in that band. */
697          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
698       }
699
700 #ifdef FIXED_POINT
701       mid = imid;
702       side = iside;
703 #else
704       mid = (1.f/32768)*imid;
705       side = (1.f/32768)*iside;
706 #endif
707
708       /* This is a special case for N=2 that only works for stereo and takes
709          advantage of the fact that mid and side are orthogonal to encode
710          the side with just one bit. */
711       if (N==2 && stereo)
712       {
713          int c;
714          int sign=1;
715          celt_norm *x2, *y2;
716          mbits = b-qalloc;
717          sbits = 0;
718          /* Only need one bit for the side */
719          if (itheta != 0 && itheta != 16384)
720             sbits = 1<<BITRES;
721          mbits -= sbits;
722          c = itheta > 8192;
723          *remaining_bits -= qalloc+sbits;
724
725          x2 = c ? Y : X;
726          y2 = c ? X : Y;
727          if (sbits)
728          {
729             if (encode)
730             {
731                /* Here we only need to encode a sign for the side */
732                sign = x2[0]*y2[1] - x2[1]*y2[0] > 0;
733                ec_enc_bits((ec_enc*)ec, sign, 1);
734             } else {
735                sign = ec_dec_bits((ec_dec*)ec, 1);
736             }
737          }
738          sign = 2*sign - 1;
739          quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level, seed, gain, lowband_scratch);
740          y2[0] = -sign*x2[1];
741          y2[1] = sign*x2[0];
742          if (resynth)
743          {
744             celt_norm tmp;
745             X[0] = MULT16_16_Q15(mid, X[0]);
746             X[1] = MULT16_16_Q15(mid, X[1]);
747             Y[0] = MULT16_16_Q15(side, Y[0]);
748             Y[1] = MULT16_16_Q15(side, Y[1]);
749             tmp = X[0];
750             X[0] = SUB16(tmp,Y[0]);
751             Y[0] = ADD16(tmp,Y[0]);
752             tmp = X[1];
753             X[1] = SUB16(tmp,Y[1]);
754             Y[1] = ADD16(tmp,Y[1]);
755          }
756       } else {
757          /* "Normal" split code */
758          celt_norm *next_lowband2=NULL;
759          celt_norm *next_lowband_out1=NULL;
760          int next_level=0;
761
762          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
763          if (B>1 && !stereo && itheta > 8192)
764             delta -= delta>>(1+level);
765
766          mbits = (b-qalloc-delta)/2;
767          if (mbits > b-qalloc)
768             mbits = b-qalloc;
769          if (mbits<0)
770             mbits=0;
771          sbits = b-qalloc-mbits;
772          *remaining_bits -= qalloc;
773
774          if (lowband && !stereo)
775             next_lowband2 = lowband+N; /* >32-bit split case */
776
777          /* Only stereo needs to pass on lowband_out. Otherwise, it's
778             handled at the end */
779          if (stereo)
780             next_lowband_out1 = lowband_out;
781          else
782             next_level = level+1;
783
784          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, tf_change,
785                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
786                NULL, next_level, seed, MULT16_16_P15(gain,mid), lowband_scratch);
787          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, tf_change,
788                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
789                NULL, next_level, seed, MULT16_16_P15(gain,side), NULL);
790       }
791
792    } else {
793       /* This is the basic no-split case */
794       q = bits2pulses(m, i, LM, b);
795       curr_bits = pulses2bits(m, i, LM, q);
796       *remaining_bits -= curr_bits;
797
798       /* Ensures we can never bust the budget */
799       while (*remaining_bits < 0 && q > 0)
800       {
801          *remaining_bits += curr_bits;
802          q--;
803          curr_bits = pulses2bits(m, i, LM, q);
804          *remaining_bits -= curr_bits;
805       }
806
807       if (q!=0)
808       {
809          int K = get_pulses(q);
810
811          /* Finally do the actual quantization */
812          if (encode)
813             alg_quant(X, N, K, spread, B, lowband, resynth, (ec_enc*)ec, seed, gain);
814          else
815             alg_unquant(X, N, K, spread, B, lowband, (ec_dec*)ec, seed, gain);
816       } else {
817          /* If there's no pulse, fill the band anyway */
818          int j;
819          if (lowband != NULL && resynth)
820          {
821             if (spread==2 && B<=1)
822             {
823                /* Folded spectrum */
824                for (j=0;j<N;j++)
825                {
826                   *seed = lcg_rand(*seed);
827                   X[j] = (int)(*seed)>>20;
828                }
829             } else {
830                /* Noise */
831                for (j=0;j<N;j++)
832                   X[j] = lowband[j];
833             }
834             renormalise_vector(X, N, gain);
835          } else {
836             /* This is important for encoding the side in stereo mode */
837             for (j=0;j<N;j++)
838                X[j] = 0;
839          }
840       }
841    }
842
843    /* This code is used by the decoder and by the resynthesis-enabled encoder */
844    if (resynth)
845    {
846       if (stereo)
847       {
848          if (N!=2)
849             stereo_merge(X, Y, mid, side, N);
850       } else if (level == 0)
851       {
852          int k;
853
854          /* Undo the sample reorganization going from time order to frequency order */
855          if (B0>1)
856             interleave_vector(X, N_B, B0);
857
858          /* Undo time-freq changes that we did earlier */
859          N_B = N_B0;
860          B = B0;
861          for (k=0;k<time_divide;k++)
862          {
863             B >>= 1;
864             N_B <<= 1;
865             haar1(X, N_B, B);
866          }
867
868          for (k=0;k<recombine;k++)
869          {
870             haar1(X, N_B, B);
871             N_B>>=1;
872             B <<= 1;
873          }
874
875          /* Scale output for later folding */
876          if (lowband_out)
877          {
878             int j;
879             celt_word16 n;
880             n = celt_sqrt(SHL32(EXTEND32(N0),22));
881             for (j=0;j<N0;j++)
882                lowband_out[j] = MULT16_16_Q15(n,X[j]);
883          }
884       }
885    }
886 }
887
888 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM, int codedBands)
889 {
890    int i, balance;
891    celt_int32 remaining_bits;
892    const celt_int16 * restrict eBands = m->eBands;
893    celt_norm * restrict norm;
894    VARDECL(celt_norm, _norm);
895    VARDECL(celt_norm, lowband_scratch);
896    int B;
897    int M;
898    celt_int32 seed;
899    celt_norm *lowband;
900    int update_lowband = 1;
901    int C = _Y != NULL ? 2 : 1;
902    SAVE_STACK;
903
904    M = 1<<LM;
905    B = shortBlocks ? M : 1;
906    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
907    ALLOC(lowband_scratch, M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]), celt_norm);
908    norm = _norm;
909
910    if (encode)
911       seed = ((ec_enc*)ec)->rng;
912    else
913       seed = ((ec_dec*)ec)->rng;
914    balance = 0;
915    lowband = NULL;
916    for (i=start;i<end;i++)
917    {
918       int tell;
919       int b;
920       int N;
921       int curr_balance;
922       celt_norm *effective_lowband=NULL;
923       celt_norm * restrict X, * restrict Y;
924       int tf_change=0;
925       
926       X = _X+M*eBands[i];
927       if (_Y!=NULL)
928          Y = _Y+M*eBands[i];
929       else
930          Y = NULL;
931       N = M*eBands[i+1]-M*eBands[i];
932       if (encode)
933          tell = ec_enc_tell((ec_enc*)ec, BITRES);
934       else
935          tell = ec_dec_tell((ec_dec*)ec, BITRES);
936
937       /* Compute how many bits we want to allocate to this band */
938       if (i != start)
939          balance -= tell;
940       remaining_bits = (total_bits<<BITRES)-tell-1;
941       if (i <= codedBands-1)
942       {
943          curr_balance = (codedBands-i);
944          if (curr_balance > 3)
945             curr_balance = 3;
946          curr_balance = balance / curr_balance;
947          b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
948          if (b<0)
949             b = 0;
950       } else {
951          b = 0;
952       }
953       /* Prevents ridiculous bit depths */
954       if (b > C*16*N<<BITRES)
955          b = C*16*N<<BITRES;
956
957       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
958             lowband = norm+M*eBands[i];
959
960       tf_change = tf_res[i];
961       if (i>=m->effEBands)
962       {
963          X=norm;
964          if (_Y!=NULL)
965             Y = norm;
966       }
967
968       /* This ensures we never repeat spectral content within one band */
969       if (lowband != NULL)
970       {
971          effective_lowband = lowband-N;
972          if (effective_lowband < norm+M*eBands[start])
973             effective_lowband = norm+M*eBands[start];
974       }
975       quant_band(encode, m, i, X, Y, N, b, fold, B, tf_change,
976             effective_lowband, resynth, ec, &remaining_bits, LM,
977             norm+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch);
978
979       balance += pulses[i] + tell;
980
981       /* Update the folding position only as long as we have 1 bit/sample depth */
982       update_lowband = (b>>BITRES)>N;
983    }
984    RESTORE_STACK;
985 }
986