Using random noise in upper bands when signal is "normal"
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    for (c=0;c<C;c++)
74    {
75       for (i=0;i<end;i++)
76       {
77          int j;
78          celt_word32 maxval=0;
79          celt_word32 sum = 0;
80          
81          j=M*eBands[i]; do {
82             maxval = MAX32(maxval, X[j+c*N]);
83             maxval = MAX32(maxval, -X[j+c*N]);
84          } while (++j<M*eBands[i+1]);
85          
86          if (maxval > 0)
87          {
88             int shift = celt_ilog2(maxval)-10;
89             j=M*eBands[i]; do {
90                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
91                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
92             } while (++j<M*eBands[i+1]);
93             /* We're adding one here to make damn sure we never end up with a pitch vector that's
94                larger than unity norm */
95             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
96          } else {
97             bank[i+c*m->nbEBands] = EPSILON;
98          }
99          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
100       }
101    }
102    /*printf ("\n");*/
103 }
104
105 /* Normalise each band such that the energy is one. */
106 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
107 {
108    int i, c, N;
109    const celt_int16 *eBands = m->eBands;
110    const int C = CHANNELS(_C);
111    N = M*m->shortMdctSize;
112    for (c=0;c<C;c++)
113    {
114       i=0; do {
115          celt_word16 g;
116          int j,shift;
117          celt_word16 E;
118          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
119          E = VSHR32(bank[i+c*m->nbEBands], shift);
120          g = EXTRACT16(celt_rcp(SHL32(E,3)));
121          j=M*eBands[i]; do {
122             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
123          } while (++j<M*eBands[i+1]);
124       } while (++i<end);
125    }
126 }
127
128 #else /* FIXED_POINT */
129 /* Compute the amplitude (sqrt energy) in each of the bands */
130 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
131 {
132    int i, c, N;
133    const celt_int16 *eBands = m->eBands;
134    const int C = CHANNELS(_C);
135    N = M*m->shortMdctSize;
136    for (c=0;c<C;c++)
137    {
138       for (i=0;i<end;i++)
139       {
140          int j;
141          celt_word32 sum = 1e-10;
142          for (j=M*eBands[i];j<M*eBands[i+1];j++)
143             sum += X[j+c*N]*X[j+c*N];
144          bank[i+c*m->nbEBands] = sqrt(sum);
145          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
146       }
147    }
148    /*printf ("\n");*/
149 }
150
151 /* Normalise each band such that the energy is one. */
152 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
153 {
154    int i, c, N;
155    const celt_int16 *eBands = m->eBands;
156    const int C = CHANNELS(_C);
157    N = M*m->shortMdctSize;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<end;i++)
161       {
162          int j;
163          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
164          for (j=M*eBands[i];j<M*eBands[i+1];j++)
165             X[j+c*N] = freq[j+c*N]*g;
166       }
167    }
168 }
169
170 #endif /* FIXED_POINT */
171
172 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
173 {
174    int i, c;
175    const celt_int16 *eBands = m->eBands;
176    const int C = CHANNELS(_C);
177    for (c=0;c<C;c++)
178    {
179       i=0; do {
180          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, Q15ONE, M*eBands[i+1]-M*eBands[i], 1);
181       } while (++i<end);
182    }
183 }
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    for (c=0;c<C;c++)
194    {
195       celt_sig * restrict f;
196       const celt_norm * restrict x;
197       f = freq+c*N;
198       x = X+c*N;
199       for (i=0;i<end;i++)
200       {
201          int j, band_end;
202          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
203          j=M*eBands[i];
204          band_end = M*eBands[i+1];
205          do {
206             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
207             x++;
208          } while (++j<band_end);
209       }
210       for (i=M*eBands[m->nbEBands];i<N;i++)
211          *f++ = 0;
212    }
213 }
214
215 int compute_pitch_gain(const CELTMode *m, const celt_sig *X, const celt_sig *P, int norm_rate, int *gain_id, int _C, celt_word16 *gain_prod, int M)
216 {
217    int j, c;
218    celt_word16 g;
219    celt_word16 delta;
220    const int C = CHANNELS(_C);
221    celt_word32 Sxy=0, Sxx=0, Syy=0;
222    int len = M*m->pitchEnd;
223    int N = M*m->shortMdctSize;
224 #ifdef FIXED_POINT
225    int shift = 0;
226    celt_word32 maxabs=0;
227
228    for (c=0;c<C;c++)
229    {
230       for (j=0;j<len;j++)
231       {
232          maxabs = MAX32(maxabs, ABS32(X[j+c*N]));
233          maxabs = MAX32(maxabs, ABS32(P[j+c*N]));
234       }
235    }
236    shift = celt_ilog2(maxabs)-12;
237    if (shift<0)
238       shift = 0;
239 #endif
240    delta = PDIV32_16(Q15ONE, len);
241    for (c=0;c<C;c++)
242    {
243       celt_word16 gg = Q15ONE;
244       for (j=0;j<len;j++)
245       {
246          celt_word16 Xj, Pj;
247          Xj = EXTRACT16(SHR32(X[j+c*N], shift));
248          Pj = MULT16_16_P15(gg,EXTRACT16(SHR32(P[j+c*N], shift)));
249          Sxy = MAC16_16(Sxy, Xj, Pj);
250          Sxx = MAC16_16(Sxx, Pj, Pj);
251          Syy = MAC16_16(Syy, Xj, Xj);
252          gg = SUB16(gg, delta);
253       }
254    }
255 #ifdef FIXED_POINT
256    {
257       celt_word32 num, den;
258       celt_word16 fact;
259       fact = MULT16_16(QCONST16(.04f, 14), norm_rate);
260       if (fact < QCONST16(1.f, 14))
261          fact = QCONST16(1.f, 14);
262       num = Sxy;
263       den = EPSILON+Sxx+MULT16_32_Q15(QCONST16(.03f,15),Syy);
264       shift = celt_zlog2(Sxy)-16;
265       if (shift < 0)
266          shift = 0;
267       if (Sxy < MULT16_32_Q15(fact, MULT16_16(celt_sqrt(EPSILON+Sxx),celt_sqrt(EPSILON+Syy))))
268          g = 0;
269       else
270          g = DIV32(SHL32(SHR32(num,shift),14),ADD32(EPSILON,SHR32(den,shift)));
271
272       /* This MUST round down so that we don't over-estimate the gain */
273       *gain_id = EXTRACT16(SHR32(MULT16_16(20,(g-QCONST16(.5f,14))),14));
274    }
275 #else
276    {
277       float fact = .04f*norm_rate;
278       if (fact < 1)
279          fact = 1;
280       g = Sxy/(.1f+Sxx+.03f*Syy);
281       if (Sxy < .5f*fact*celt_sqrt(1+Sxx*Syy))
282          g = 0;
283       /* This MUST round down so that we don't over-estimate the gain */
284       *gain_id = floor(20*(g-.5f));
285    }
286 #endif
287    /* This prevents the pitch gain from being above 1.0 for too long by bounding the 
288       maximum error amplification factor to 2.0 */
289    g = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),*gain_id));
290    *gain_prod = MAX16(QCONST32(1.f, 13), MULT16_16_Q14(*gain_prod,g));
291    if (*gain_prod>QCONST32(2.f, 13))
292    {
293       *gain_id=9;
294       *gain_prod = QCONST32(2.f, 13);
295    }
296
297    if (*gain_id < 0)
298    {
299       *gain_id = 0;
300       return 0;
301    } else {
302       if (*gain_id > 15)
303          *gain_id = 15;
304       return 1;
305    }
306 }
307
308 void apply_pitch(const CELTMode *m, celt_sig *X, const celt_sig *P, int gain_id, int pred, int _C, int M)
309 {
310    int j, c, N;
311    celt_word16 gain;
312    celt_word16 delta;
313    const int C = CHANNELS(_C);
314    int len = M*m->pitchEnd;
315
316    N = M*m->shortMdctSize;
317    gain = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),gain_id));
318    delta = PDIV32_16(gain, len);
319    if (pred)
320       gain = -gain;
321    else
322       delta = -delta;
323    for (c=0;c<C;c++)
324    {
325       celt_word16 gg = gain;
326       for (j=0;j<len;j++)
327       {
328          X[j+c*N] += SHL32(MULT16_32_Q15(gg,P[j+c*N]),1);
329          gg = ADD16(gg, delta);
330       }
331    }
332 }
333
334 static void stereo_band_mix(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int stereo_mode, int bandID, int dir, int N)
335 {
336    int i = bandID;
337    int j;
338    celt_word16 a1, a2;
339    if (stereo_mode==0)
340    {
341       /* Do mid-side when not doing intensity stereo */
342       a1 = QCONST16(.70711f,14);
343       a2 = dir*QCONST16(.70711f,14);
344    } else {
345       celt_word16 left, right;
346       celt_word16 norm;
347 #ifdef FIXED_POINT
348       int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
349 #endif
350       left = VSHR32(bank[i],shift);
351       right = VSHR32(bank[i+m->nbEBands],shift);
352       norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
353       a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
354       a2 = dir*DIV32_16(SHL32(EXTEND32(right),14),norm);
355    }
356    for (j=0;j<N;j++)
357    {
358       celt_norm r, l;
359       l = X[j];
360       r = Y[j];
361       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
362       Y[j] = MULT16_16_Q14(a1,r) - MULT16_16_Q14(a2,l);
363    }
364 }
365
366
367 int folding_decision(const CELTMode *m, celt_norm *X, celt_word16 *average, int *last_decision, int end, int _C, int M)
368 {
369    int i, c, N0;
370    int NR=0;
371    celt_word32 ratio = EPSILON;
372    const int C = CHANNELS(_C);
373    const celt_int16 * restrict eBands = m->eBands;
374    
375    N0 = M*m->shortMdctSize;
376
377    for (c=0;c<C;c++)
378    {
379    for (i=0;i<end;i++)
380    {
381       int j, N;
382       int max_i=0;
383       celt_word16 max_val=EPSILON;
384       celt_word32 floor_ener=EPSILON;
385       celt_norm * restrict x = X+M*eBands[i]+c*N0;
386       N = M*eBands[i+1]-M*eBands[i];
387       for (j=0;j<N;j++)
388       {
389          if (ABS16(x[j])>max_val)
390          {
391             max_val = ABS16(x[j]);
392             max_i = j;
393          }
394       }
395 #if 0
396       for (j=0;j<N;j++)
397       {
398          if (abs(j-max_i)>2)
399             floor_ener += x[j]*x[j];
400       }
401 #else
402       floor_ener = QCONST32(1.,28)-MULT16_16(max_val,max_val);
403       if (max_i < N-1)
404          floor_ener -= MULT16_16(x[(max_i+1)], x[(max_i+1)]);
405       if (max_i < N-2)
406          floor_ener -= MULT16_16(x[(max_i+2)], x[(max_i+2)]);
407       if (max_i > 0)
408          floor_ener -= MULT16_16(x[(max_i-1)], x[(max_i-1)]);
409       if (max_i > 1)
410          floor_ener -= MULT16_16(x[(max_i-2)], x[(max_i-2)]);
411       floor_ener = MAX32(floor_ener, EPSILON);
412 #endif
413       if (N>7)
414       {
415          celt_word16 r;
416          celt_word16 den = celt_sqrt(floor_ener);
417          den = MAX32(QCONST16(.02f, 15), den);
418          r = DIV32_16(SHL32(EXTEND32(max_val),8),den);
419          ratio = ADD32(ratio, EXTEND32(r));
420          NR++;
421       }
422    }
423    }
424    if (NR>0)
425       ratio = DIV32_16(ratio, NR);
426    ratio = ADD32(HALF32(ratio), HALF32(*average));
427    if (!*last_decision)
428    {
429       *last_decision = (ratio < QCONST16(1.8f,8));
430    } else {
431       *last_decision = (ratio < QCONST16(3.f,8));
432    }
433    *average = EXTRACT16(ratio);
434    return *last_decision;
435 }
436
437 static void interleave_vector(celt_norm *X, int N0, int stride)
438 {
439    int i,j;
440    VARDECL(celt_norm, tmp);
441    int N;
442    SAVE_STACK;
443    N = N0*stride;
444    ALLOC(tmp, N, celt_norm);
445    for (i=0;i<stride;i++)
446       for (j=0;j<N0;j++)
447          tmp[j*stride+i] = X[i*N0+j];
448    for (j=0;j<N;j++)
449       X[j] = tmp[j];
450    RESTORE_STACK;
451 }
452
453 static void deinterleave_vector(celt_norm *X, int N0, int stride)
454 {
455    int i,j;
456    VARDECL(celt_norm, tmp);
457    int N;
458    SAVE_STACK;
459    N = N0*stride;
460    ALLOC(tmp, N, celt_norm);
461    for (i=0;i<stride;i++)
462       for (j=0;j<N0;j++)
463          tmp[i*N0+j] = X[j*stride+i];
464    for (j=0;j<N;j++)
465       X[j] = tmp[j];
466    RESTORE_STACK;
467 }
468
469 static void haar1(celt_norm *X, int N0, int stride)
470 {
471    int i, j;
472    N0 >>= 1;
473    for (i=0;i<stride;i++)
474       for (j=0;j<N0;j++)
475       {
476          celt_norm tmp = X[stride*2*j+i];
477          X[stride*2*j+i] = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i] + X[stride*(2*j+1)+i]);
478          X[stride*(2*j+1)+i] = MULT16_16_Q15(QCONST16(.7070678f,15), tmp - X[stride*(2*j+1)+i]);
479       }
480 }
481
482 /* This function is responsible for encoding and decoding a band for both
483    the mono and stereo case. Even in the mono case, it can split the band
484    in two and transmit the energy difference with the two half-bands. It
485    can be called recursively so bands can end up being split in 8 parts. */
486 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
487       int N, int b, int spread, int tf_change, celt_norm *lowband, int resynth, void *ec,
488       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level, celt_int32 *seed)
489 {
490    int q;
491    int curr_bits;
492    int stereo, split;
493    int imid=0, iside=0;
494    int N0=N;
495    int N_B=N;
496    int N_B0;
497    int spread0=spread;
498    int time_divide=0;
499    int recombine=0;
500
501    if (spread)
502       N_B /= spread;
503    N_B0 = N_B;
504
505    split = stereo = Y != NULL;
506
507    /* Special case for one sample */
508    if (N==1)
509    {
510       int c;
511       celt_norm *x = X;
512       for (c=0;c<1+stereo;c++)
513       {
514          int sign=0;
515          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
516          {
517             if (encode)
518             {
519                sign = x[0]<0;
520                ec_enc_bits((ec_enc*)ec, sign, 1);
521             } else {
522                sign = ec_dec_bits((ec_dec*)ec, 1);
523             }
524             *remaining_bits -= 1<<BITRES;
525             b-=1<<BITRES;
526          }
527          if (resynth)
528             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
529          x = Y;
530       }
531       if (lowband_out)
532          lowband_out[0] = X[0];
533       return;
534    }
535
536    /* Band recombining to increase frequency resolution */
537    if (!stereo && spread > 1 && level == 0 && tf_change>0)
538    {
539       while (spread>1 && tf_change>0)
540       {
541          spread>>=1;
542          N_B<<=1;
543          if (encode)
544             haar1(X, N_B, spread);
545          if (lowband)
546             haar1(lowband, N_B, spread);
547          recombine++;
548          tf_change--;
549       }
550       spread0=spread;
551       N_B0 = N_B;
552    }
553
554    /* Increasing the time resolution */
555    if (!stereo && level==0)
556    {
557       while ((N_B&1) == 0 && tf_change<0 && spread <= (1<<LM))
558       {
559          if (encode)
560             haar1(X, N_B, spread);
561          if (lowband)
562             haar1(lowband, N_B, spread);
563          spread <<= 1;
564          N_B >>= 1;
565          time_divide++;
566          tf_change++;
567       }
568       spread0 = spread;
569       N_B0 = N_B;
570    }
571
572    /* Reorganize the samples in time order instead of frequency order */
573    if (!stereo && spread0>1 && level==0)
574    {
575       if (encode)
576          deinterleave_vector(X, N_B, spread0);
577       if (lowband)
578          deinterleave_vector(lowband, N_B, spread0);
579    }
580
581    /* If we need more than 32 bits, try splitting the band in two. */
582    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
583    {
584       if (LM>0 || (N&1)==0)
585       {
586          N >>= 1;
587          Y = X+N;
588          split = 1;
589          LM -= 1;
590          spread = (spread+1)>>1;
591       }
592    }
593
594    if (split)
595    {
596       int qb;
597       int itheta=0;
598       int mbits, sbits, delta;
599       int qalloc;
600       celt_word16 mid, side;
601       int offset, N2;
602       offset = m->logN[i]+(LM<<BITRES)-QTHETA_OFFSET;
603
604       /* Decide on the resolution to give to the split parameter theta */
605       N2 = 2*N-1;
606       if (stereo && N>2)
607          N2--;
608       qb = (b+N2*offset)/(N2<<BITRES);
609       if (qb > (b>>(BITRES+1))-1)
610          qb = (b>>(BITRES+1))-1;
611
612       if (qb<0)
613          qb = 0;
614       if (qb>14)
615          qb = 14;
616
617       qalloc = 0;
618       if (qb!=0)
619       {
620          int shift;
621          shift = 14-qb;
622
623          if (encode)
624          {
625             if (stereo)
626                stereo_band_mix(m, X, Y, bandE, qb==0, i, 1, N);
627
628             mid = renormalise_vector(X, Q15ONE, N, 1);
629             side = renormalise_vector(Y, Q15ONE, N, 1);
630
631             /* theta is the atan() of the ration between the (normalized)
632                side and mid. With just that parameter, we can re-scale both
633                mid and side because we know that 1) they have unit norm and
634                2) they are orthogonal. */
635    #ifdef FIXED_POINT
636             /* 0.63662 = 2/pi */
637             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
638    #else
639             itheta = floor(.5f+16384*0.63662f*atan2(side,mid));
640    #endif
641
642             itheta = (itheta+(1<<shift>>1))>>shift;
643          }
644
645          /* Entropy coding of the angle. We use a uniform pdf for the
646             first stereo split but a triangular one for the rest. */
647          if (stereo || qb>9 || spread>1)
648          {
649             if (encode)
650                ec_enc_uint((ec_enc*)ec, itheta, (1<<qb)+1);
651             else
652                itheta = ec_dec_uint((ec_dec*)ec, (1<<qb)+1);
653             qalloc = log2_frac((1<<qb)+1,BITRES);
654          } else {
655             int fs=1, ft;
656             ft = ((1<<qb>>1)+1)*((1<<qb>>1)+1);
657             if (encode)
658             {
659                int j;
660                int fl=0;
661                j=0;
662                while(1)
663                {
664                   if (j==itheta)
665                      break;
666                   fl+=fs;
667                   if (j<(1<<qb>>1))
668                      fs++;
669                   else
670                      fs--;
671                   j++;
672                }
673                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
674             } else {
675                int fl=0;
676                int j, fm;
677                fm = ec_decode((ec_dec*)ec, ft);
678                j=0;
679                while (1)
680                {
681                   if (fm < fl+fs)
682                      break;
683                   fl+=fs;
684                   if (j<(1<<qb>>1))
685                      fs++;
686                   else
687                      fs--;
688                   j++;
689                }
690                itheta = j;
691                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
692             }
693             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
694          }
695          itheta <<= shift;
696       }
697
698       if (itheta == 0)
699       {
700          imid = 32767;
701          iside = 0;
702          delta = -10000;
703       } else if (itheta == 16384)
704       {
705          imid = 0;
706          iside = 32767;
707          delta = 10000;
708       } else {
709          imid = bitexact_cos(itheta);
710          iside = bitexact_cos(16384-itheta);
711          /* This is the mid vs side allocation that minimizes squared error
712             in that band. */
713          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
714       }
715
716       /* This is a special case for N=2 that only works for stereo and takes
717          advantage of the fact that mid and side are orthogonal to encode
718          the side with just one bit. */
719       if (N==2 && stereo)
720       {
721          int c, c2;
722          int sign=1;
723          celt_norm v[2], w[2];
724          celt_norm *x2, *y2;
725          mbits = b-qalloc;
726          sbits = 0;
727          if (itheta != 0 && itheta != 16384)
728             sbits = 1<<BITRES;
729          mbits -= sbits;
730          c = itheta > 8192 ? 1 : 0;
731          *remaining_bits -= qalloc+sbits;
732
733          x2 = X;
734          y2 = Y;
735          if (encode)
736          {
737             c2 = 1-c;
738
739             if (c==0)
740             {
741                v[0] = x2[0];
742                v[1] = x2[1];
743                w[0] = y2[0];
744                w[1] = y2[1];
745             } else {
746                v[0] = y2[0];
747                v[1] = y2[1];
748                w[0] = x2[0];
749                w[1] = x2[1];
750             }
751             /* Here we only need to encode a sign for the side */
752             if (v[0]*w[1] - v[1]*w[0] > 0)
753                sign = 1;
754             else
755                sign = -1;
756          }
757          quant_band(encode, m, i, v, NULL, N, mbits, spread, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level+1, seed);
758          if (sbits)
759          {
760             if (encode)
761             {
762                ec_enc_bits((ec_enc*)ec, sign==1, 1);
763             } else {
764                sign = 2*ec_dec_bits((ec_dec*)ec, 1)-1;
765             }
766          } else {
767             sign = 1;
768          }
769          w[0] = -sign*v[1];
770          w[1] = sign*v[0];
771          if (c==0)
772          {
773             x2[0] = v[0];
774             x2[1] = v[1];
775             y2[0] = w[0];
776             y2[1] = w[1];
777          } else {
778             x2[0] = w[0];
779             x2[1] = w[1];
780             y2[0] = v[0];
781             y2[1] = v[1];
782          }
783       } else
784       {
785          /* "Normal" split code */
786          celt_norm *next_lowband2=NULL;
787          celt_norm *next_lowband_out1=NULL;
788          int next_level=0;
789
790          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
791          if (spread>1 && !stereo)
792             delta >>= 1;
793
794          mbits = (b-qalloc/2-delta)/2;
795          if (mbits > b-qalloc)
796             mbits = b-qalloc;
797          if (mbits<0)
798             mbits=0;
799          sbits = b-qalloc-mbits;
800          *remaining_bits -= qalloc;
801
802          if (lowband && !stereo)
803             next_lowband2 = lowband+N;
804          if (stereo)
805             next_lowband_out1 = lowband_out;
806          else
807             next_level = level+1;
808
809          quant_band(encode, m, i, X, NULL, N, mbits, spread, tf_change, lowband, resynth, ec, remaining_bits, LM, next_lowband_out1, NULL, next_level, seed);
810          quant_band(encode, m, i, Y, NULL, N, sbits, spread, tf_change, next_lowband2, resynth, ec, remaining_bits, LM, NULL, NULL, level, seed);
811       }
812
813    } else {
814       /* This is the basic no-split case */
815       q = bits2pulses(m, m->bits[LM][i], N, b);
816       curr_bits = pulses2bits(m->bits[LM][i], N, q);
817       *remaining_bits -= curr_bits;
818
819       /* Ensures we can never bust the budget */
820       while (*remaining_bits < 0 && q > 0)
821       {
822          *remaining_bits += curr_bits;
823          q--;
824          curr_bits = pulses2bits(m->bits[LM][i], N, q);
825          *remaining_bits -= curr_bits;
826       }
827
828       if (encode)
829          alg_quant(X, N, q, spread, lowband, resynth, (ec_enc*)ec, seed);
830       else
831          alg_unquant(X, N, q, spread, lowband, (ec_dec*)ec, seed);
832    }
833
834    /* This code is used by the decoder and by the resynthesis-enabled encoder */
835    if (resynth)
836    {
837       int k;
838
839       if (split)
840       {
841          int j;
842          celt_word16 mid, side;
843 #ifdef FIXED_POINT
844          mid = imid;
845          side = iside;
846 #else
847          mid = (1.f/32768)*imid;
848          side = (1.f/32768)*iside;
849 #endif
850          for (j=0;j<N;j++)
851             X[j] = MULT16_16_Q15(X[j], mid);
852          for (j=0;j<N;j++)
853             Y[j] = MULT16_16_Q15(Y[j], side);
854       }
855
856       if (!stereo && spread0>1 && level==0)
857       {
858          interleave_vector(X, N_B, spread0);
859          if (lowband)
860             interleave_vector(lowband, N_B, spread0);
861       }
862
863       /* Undo time-freq changes that we did earlier */
864       N_B = N_B0;
865       spread = spread0;
866       for (k=0;k<time_divide;k++)
867       {
868          spread >>= 1;
869          N_B <<= 1;
870          haar1(X, N_B, spread);
871          if (lowband)
872             haar1(lowband, N_B, spread);
873       }
874
875       for (k=0;k<recombine;k++)
876       {
877          haar1(X, N_B, spread);
878          if (lowband)
879             haar1(lowband, N_B, spread);
880          N_B>>=1;
881          spread <<= 1;
882       }
883
884       if (lowband_out && !stereo)
885       {
886          int j;
887          celt_word16 n;
888          n = celt_sqrt(SHL32(EXTEND32(N0),22));
889          for (j=0;j<N0;j++)
890             lowband_out[j] = MULT16_16_Q15(n,X[j]);
891       }
892
893       if (stereo)
894       {
895          stereo_band_mix(m, X, Y, bandE, 0, i, -1, N);
896          renormalise_vector(X, Q15ONE, N, 1);
897          renormalise_vector(Y, Q15ONE, N, 1);
898       }
899    }
900 }
901
902 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM)
903 {
904    int i, balance;
905    celt_int32 remaining_bits;
906    const celt_int16 * restrict eBands = m->eBands;
907    celt_norm * restrict norm;
908    VARDECL(celt_norm, _norm);
909    int B;
910    int M;
911    int spread;
912    celt_int32 seed;
913    celt_norm *lowband;
914    int update_lowband = 1;
915    int C = _Y != NULL ? 2 : 1;
916    SAVE_STACK;
917
918    M = 1<<LM;
919    B = shortBlocks ? M : 1;
920    spread = fold ? B : 0;
921    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
922    norm = _norm;
923
924    if (encode)
925       seed = ((ec_enc*)ec)->rng;
926    else
927       seed = ((ec_dec*)ec)->rng;
928    balance = 0;
929    lowband = NULL;
930    for (i=start;i<end;i++)
931    {
932       int tell;
933       int b;
934       int N;
935       int curr_balance;
936       celt_norm * restrict X, * restrict Y;
937       int tf_change=0;
938       celt_norm *effective_lowband;
939       
940       X = _X+M*eBands[i];
941       if (_Y!=NULL)
942          Y = _Y+M*eBands[i];
943       else
944          Y = NULL;
945       N = M*eBands[i+1]-M*eBands[i];
946       if (encode)
947          tell = ec_enc_tell((ec_enc*)ec, BITRES);
948       else
949          tell = ec_dec_tell((ec_dec*)ec, BITRES);
950
951       if (i != start)
952          balance -= tell;
953       remaining_bits = (total_bits<<BITRES)-tell-1;
954       curr_balance = (end-i);
955       if (curr_balance > 3)
956          curr_balance = 3;
957       curr_balance = balance / curr_balance;
958       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
959       if (b<0)
960          b = 0;
961       /* Prevents ridiculous bit depths */
962       if (b > C*16*N<<BITRES)
963          b = C*16*N<<BITRES;
964
965       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
966             lowband = norm+M*eBands[i]-N;
967
968       tf_change = tf_res[i];
969       if (i>=m->effEBands)
970       {
971          X=norm;
972          if (_Y!=NULL)
973             Y = norm;
974       }
975
976       if (tf_change==0 && !shortBlocks && fold)
977          effective_lowband = NULL;
978       else
979          effective_lowband = lowband;
980       quant_band(encode, m, i, X, Y, N, b, spread, tf_change, effective_lowband, resynth, ec, &remaining_bits, LM, norm+M*eBands[i], bandE, 0, &seed);
981
982       balance += pulses[i] + tell;
983
984       /* Update the folding position only as long as we have 2 bit/sample depth */
985       update_lowband = (b>>BITRES)>2*N;
986    }
987    RESTORE_STACK;
988 }
989