More recombining "infrastructure"
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48
49 #ifdef FIXED_POINT
50 /* Compute the amplitude (sqrt energy) in each of the bands */
51 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int _C, int M)
52 {
53    int i, c, N;
54    const celt_int16 *eBands = m->eBands;
55    const int C = CHANNELS(_C);
56    N = M*m->eBands[m->nbEBands+1];
57    for (c=0;c<C;c++)
58    {
59       for (i=0;i<m->nbEBands;i++)
60       {
61          int j;
62          celt_word32 maxval=0;
63          celt_word32 sum = 0;
64          
65          j=M*eBands[i]; do {
66             maxval = MAX32(maxval, X[j+c*N]);
67             maxval = MAX32(maxval, -X[j+c*N]);
68          } while (++j<M*eBands[i+1]);
69          
70          if (maxval > 0)
71          {
72             int shift = celt_ilog2(maxval)-10;
73             j=M*eBands[i]; do {
74                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
75                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
76             } while (++j<M*eBands[i+1]);
77             /* We're adding one here to make damn sure we never end up with a pitch vector that's
78                larger than unity norm */
79             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
80          } else {
81             bank[i+c*m->nbEBands] = EPSILON;
82          }
83          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
84       }
85    }
86    /*printf ("\n");*/
87 }
88
89 /* Normalise each band such that the energy is one. */
90 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int _C, int M)
91 {
92    int i, c, N;
93    const celt_int16 *eBands = m->eBands;
94    const int C = CHANNELS(_C);
95    N = M*m->eBands[m->nbEBands+1];
96    for (c=0;c<C;c++)
97    {
98       i=0; do {
99          celt_word16 g;
100          int j,shift;
101          celt_word16 E;
102          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
103          E = VSHR32(bank[i+c*m->nbEBands], shift);
104          g = EXTRACT16(celt_rcp(SHL32(E,3)));
105          j=M*eBands[i]; do {
106             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
107          } while (++j<M*eBands[i+1]);
108       } while (++i<m->nbEBands);
109    }
110 }
111
112 #else /* FIXED_POINT */
113 /* Compute the amplitude (sqrt energy) in each of the bands */
114 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int _C, int M)
115 {
116    int i, c, N;
117    const celt_int16 *eBands = m->eBands;
118    const int C = CHANNELS(_C);
119    N = M*m->eBands[m->nbEBands+1];
120    for (c=0;c<C;c++)
121    {
122       for (i=0;i<m->nbEBands;i++)
123       {
124          int j;
125          celt_word32 sum = 1e-10;
126          for (j=M*eBands[i];j<M*eBands[i+1];j++)
127             sum += X[j+c*N]*X[j+c*N];
128          bank[i+c*m->nbEBands] = sqrt(sum);
129          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
130       }
131    }
132    /*printf ("\n");*/
133 }
134
135 /* Normalise each band such that the energy is one. */
136 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int _C, int M)
137 {
138    int i, c, N;
139    const celt_int16 *eBands = m->eBands;
140    const int C = CHANNELS(_C);
141    N = M*m->eBands[m->nbEBands+1];
142    for (c=0;c<C;c++)
143    {
144       for (i=0;i<m->nbEBands;i++)
145       {
146          int j;
147          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
148          for (j=M*eBands[i];j<M*eBands[i+1];j++)
149             X[j+c*N] = freq[j+c*N]*g;
150       }
151    }
152 }
153
154 #endif /* FIXED_POINT */
155
156 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int _C, int M)
157 {
158    int i, c;
159    const celt_int16 *eBands = m->eBands;
160    const int C = CHANNELS(_C);
161    for (c=0;c<C;c++)
162    {
163       i=0; do {
164          renormalise_vector(X+M*eBands[i]+c*M*eBands[m->nbEBands+1], Q15ONE, M*eBands[i+1]-M*eBands[i], 1);
165       } while (++i<m->nbEBands);
166    }
167 }
168
169 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
170 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int _C, int M)
171 {
172    int i, c, N;
173    const celt_int16 *eBands = m->eBands;
174    const int C = CHANNELS(_C);
175    N = M*m->eBands[m->nbEBands+1];
176    if (C>2)
177       celt_fatal("denormalise_bands() not implemented for >2 channels");
178    for (c=0;c<C;c++)
179    {
180       celt_sig * restrict f;
181       const celt_norm * restrict x;
182       f = freq+c*N;
183       x = X+c*N;
184       for (i=0;i<m->nbEBands;i++)
185       {
186          int j, end;
187          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
188          j=M*eBands[i];
189          end = M*eBands[i+1];
190          do {
191             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
192             x++;
193          } while (++j<end);
194       }
195       for (i=M*eBands[m->nbEBands];i<M*eBands[m->nbEBands+1];i++)
196          *f++ = 0;
197    }
198 }
199
200 int compute_pitch_gain(const CELTMode *m, const celt_sig *X, const celt_sig *P, int norm_rate, int *gain_id, int _C, celt_word16 *gain_prod, int M)
201 {
202    int j, c;
203    celt_word16 g;
204    celt_word16 delta;
205    const int C = CHANNELS(_C);
206    celt_word32 Sxy=0, Sxx=0, Syy=0;
207    int len = M*m->pitchEnd;
208    int N = M*m->eBands[m->nbEBands+1];
209 #ifdef FIXED_POINT
210    int shift = 0;
211    celt_word32 maxabs=0;
212
213    for (c=0;c<C;c++)
214    {
215       for (j=0;j<len;j++)
216       {
217          maxabs = MAX32(maxabs, ABS32(X[j+c*N]));
218          maxabs = MAX32(maxabs, ABS32(P[j+c*N]));
219       }
220    }
221    shift = celt_ilog2(maxabs)-12;
222    if (shift<0)
223       shift = 0;
224 #endif
225    delta = PDIV32_16(Q15ONE, len);
226    for (c=0;c<C;c++)
227    {
228       celt_word16 gg = Q15ONE;
229       for (j=0;j<len;j++)
230       {
231          celt_word16 Xj, Pj;
232          Xj = EXTRACT16(SHR32(X[j+c*N], shift));
233          Pj = MULT16_16_P15(gg,EXTRACT16(SHR32(P[j+c*N], shift)));
234          Sxy = MAC16_16(Sxy, Xj, Pj);
235          Sxx = MAC16_16(Sxx, Pj, Pj);
236          Syy = MAC16_16(Syy, Xj, Xj);
237          gg = SUB16(gg, delta);
238       }
239    }
240 #ifdef FIXED_POINT
241    {
242       celt_word32 num, den;
243       celt_word16 fact;
244       fact = MULT16_16(QCONST16(.04f, 14), norm_rate);
245       if (fact < QCONST16(1.f, 14))
246          fact = QCONST16(1.f, 14);
247       num = Sxy;
248       den = EPSILON+Sxx+MULT16_32_Q15(QCONST16(.03f,15),Syy);
249       shift = celt_zlog2(Sxy)-16;
250       if (shift < 0)
251          shift = 0;
252       if (Sxy < MULT16_32_Q15(fact, MULT16_16(celt_sqrt(EPSILON+Sxx),celt_sqrt(EPSILON+Syy))))
253          g = 0;
254       else
255          g = DIV32(SHL32(SHR32(num,shift),14),ADD32(EPSILON,SHR32(den,shift)));
256
257       /* This MUST round down so that we don't over-estimate the gain */
258       *gain_id = EXTRACT16(SHR32(MULT16_16(20,(g-QCONST16(.5f,14))),14));
259    }
260 #else
261    {
262       float fact = .04f*norm_rate;
263       if (fact < 1)
264          fact = 1;
265       g = Sxy/(.1f+Sxx+.03f*Syy);
266       if (Sxy < .5f*fact*celt_sqrt(1+Sxx*Syy))
267          g = 0;
268       /* This MUST round down so that we don't over-estimate the gain */
269       *gain_id = floor(20*(g-.5f));
270    }
271 #endif
272    /* This prevents the pitch gain from being above 1.0 for too long by bounding the 
273       maximum error amplification factor to 2.0 */
274    g = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),*gain_id));
275    *gain_prod = MAX16(QCONST32(1.f, 13), MULT16_16_Q14(*gain_prod,g));
276    if (*gain_prod>QCONST32(2.f, 13))
277    {
278       *gain_id=9;
279       *gain_prod = QCONST32(2.f, 13);
280    }
281
282    if (*gain_id < 0)
283    {
284       *gain_id = 0;
285       return 0;
286    } else {
287       if (*gain_id > 15)
288          *gain_id = 15;
289       return 1;
290    }
291 }
292
293 void apply_pitch(const CELTMode *m, celt_sig *X, const celt_sig *P, int gain_id, int pred, int _C, int M)
294 {
295    int j, c, N;
296    celt_word16 gain;
297    celt_word16 delta;
298    const int C = CHANNELS(_C);
299    int len = M*m->pitchEnd;
300
301    N = M*m->eBands[m->nbEBands+1];
302    gain = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),gain_id));
303    delta = PDIV32_16(gain, len);
304    if (pred)
305       gain = -gain;
306    else
307       delta = -delta;
308    for (c=0;c<C;c++)
309    {
310       celt_word16 gg = gain;
311       for (j=0;j<len;j++)
312       {
313          X[j+c*N] += SHL32(MULT16_32_Q15(gg,P[j+c*N]),1);
314          gg = ADD16(gg, delta);
315       }
316    }
317 }
318
319 static void stereo_band_mix(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int stereo_mode, int bandID, int dir, int N)
320 {
321    int i = bandID;
322    int j;
323    celt_word16 a1, a2;
324    if (stereo_mode==0)
325    {
326       /* Do mid-side when not doing intensity stereo */
327       a1 = QCONST16(.70711f,14);
328       a2 = dir*QCONST16(.70711f,14);
329    } else {
330       celt_word16 left, right;
331       celt_word16 norm;
332 #ifdef FIXED_POINT
333       int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
334 #endif
335       left = VSHR32(bank[i],shift);
336       right = VSHR32(bank[i+m->nbEBands],shift);
337       norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
338       a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
339       a2 = dir*DIV32_16(SHL32(EXTEND32(right),14),norm);
340    }
341    for (j=0;j<N;j++)
342    {
343       celt_norm r, l;
344       l = X[j];
345       r = Y[j];
346       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
347       Y[j] = MULT16_16_Q14(a1,r) - MULT16_16_Q14(a2,l);
348    }
349 }
350
351
352 int folding_decision(const CELTMode *m, celt_norm *X, celt_word16 *average, int *last_decision, int _C, int M)
353 {
354    int i, c, N0;
355    int NR=0;
356    celt_word32 ratio = EPSILON;
357    const int C = CHANNELS(_C);
358    const celt_int16 * restrict eBands = m->eBands;
359    
360    N0 = M*m->eBands[m->nbEBands+1];
361
362    for (c=0;c<C;c++)
363    {
364    for (i=0;i<m->nbEBands;i++)
365    {
366       int j, N;
367       int max_i=0;
368       celt_word16 max_val=EPSILON;
369       celt_word32 floor_ener=EPSILON;
370       celt_norm * restrict x = X+M*eBands[i]+c*N0;
371       N = M*eBands[i+1]-M*eBands[i];
372       for (j=0;j<N;j++)
373       {
374          if (ABS16(x[j])>max_val)
375          {
376             max_val = ABS16(x[j]);
377             max_i = j;
378          }
379       }
380 #if 0
381       for (j=0;j<N;j++)
382       {
383          if (abs(j-max_i)>2)
384             floor_ener += x[j]*x[j];
385       }
386 #else
387       floor_ener = QCONST32(1.,28)-MULT16_16(max_val,max_val);
388       if (max_i < N-1)
389          floor_ener -= MULT16_16(x[(max_i+1)], x[(max_i+1)]);
390       if (max_i < N-2)
391          floor_ener -= MULT16_16(x[(max_i+2)], x[(max_i+2)]);
392       if (max_i > 0)
393          floor_ener -= MULT16_16(x[(max_i-1)], x[(max_i-1)]);
394       if (max_i > 1)
395          floor_ener -= MULT16_16(x[(max_i-2)], x[(max_i-2)]);
396       floor_ener = MAX32(floor_ener, EPSILON);
397 #endif
398       if (N>7)
399       {
400          celt_word16 r;
401          celt_word16 den = celt_sqrt(floor_ener);
402          den = MAX32(QCONST16(.02f, 15), den);
403          r = DIV32_16(SHL32(EXTEND32(max_val),8),den);
404          ratio = ADD32(ratio, EXTEND32(r));
405          NR++;
406       }
407    }
408    }
409    if (NR>0)
410       ratio = DIV32_16(ratio, NR);
411    ratio = ADD32(HALF32(ratio), HALF32(*average));
412    if (!*last_decision)
413    {
414       *last_decision = (ratio < QCONST16(1.8f,8));
415    } else {
416       *last_decision = (ratio < QCONST16(3.f,8));
417    }
418    *average = EXTRACT16(ratio);
419    return *last_decision;
420 }
421
422 static void interleave_vector(celt_norm *X, int N0, int stride)
423 {
424    int i,j;
425    VARDECL(celt_norm, tmp);
426    int N;
427    SAVE_STACK;
428    N = N0*stride;
429    ALLOC(tmp, N, celt_norm);
430    for (i=0;i<stride;i++)
431       for (j=0;j<N0;j++)
432          tmp[j*stride+i] = X[i*N0+j];
433    for (j=0;j<N;j++)
434       X[j] = tmp[j];
435    RESTORE_STACK;
436 }
437
438 static void deinterleave_vector(celt_norm *X, int N0, int stride)
439 {
440    int i,j;
441    VARDECL(celt_norm, tmp);
442    int N;
443    SAVE_STACK;
444    N = N0*stride;
445    ALLOC(tmp, N, celt_norm);
446    for (i=0;i<stride;i++)
447       for (j=0;j<N0;j++)
448          tmp[i*N0+j] = X[j*stride+i];
449    for (j=0;j<N;j++)
450       X[j] = tmp[j];
451    RESTORE_STACK;
452 }
453
454 static void haar1(celt_norm *X, int N0, int stride)
455 {
456    int i, j;
457    N0 >>= 1;
458    for (i=0;i<stride;i++)
459       for (j=0;j<N0;j++)
460       {
461          celt_norm tmp = X[stride*2*j+i];
462          X[stride*2*j+i] = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i] + X[stride*(2*j+1)+i]);
463          X[stride*(2*j+1)+i] = MULT16_16_Q15(QCONST16(.7070678f,15), tmp - X[stride*(2*j+1)+i]);
464       }
465 }
466
467 /* This function is responsible for encoding and decoding a band for both
468    the mono and stereo case. Even in the mono case, it can split the band
469    in two and transmit the energy difference with the two half-bands. It
470    can be called recursively so bands can end up being split in 8 parts. */
471 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
472       int N, int b, int spread, celt_norm *lowband, int resynth, ec_enc *ec,
473       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level)
474 {
475    int q;
476    int curr_bits;
477    int stereo, split;
478    int imid=0, iside=0;
479    int N0=N;
480    int N_B=N;
481    int N_B0;
482    int spread0=spread;
483    int time_divide=0;
484    int recombine=0;
485    int tf_change=-1;
486
487    if (spread)
488       N_B /= spread;
489    N_B0 = N_B;
490
491    split = stereo = Y != NULL;
492
493    if (!stereo && spread > 1 && level == 0 && tf_change>0)
494    {
495       while (spread>1 && tf_change>0)
496       {
497          spread>>=1;
498          N_B<<=1;
499          if (encode)
500             haar1(X, N_B, spread);
501          if (lowband)
502             haar1(lowband, N_B, spread);
503          recombine++;
504          tf_change--;
505       }
506       spread0=spread;
507       N_B0 = N_B;
508    }
509
510    if (!stereo && spread>1 && level==0)
511    {
512       while ((N_B&1) == 0 && tf_change<0 && spread <= (1<<LM))
513       {
514          if (encode)
515             haar1(X, N_B, spread);
516          if (lowband)
517             haar1(lowband, N_B, spread);
518          spread <<= 1;
519          N_B >>= 1;
520          time_divide++;
521          tf_change++;
522       }
523       spread0 = spread;
524       N_B0 = N_B;
525       if (spread0>1)
526       {
527          if (encode)
528             deinterleave_vector(X, N_B, spread0);
529          if (lowband)
530             deinterleave_vector(lowband, N_B, spread0);
531       }
532    }
533
534    /* If we need more than 32 bits, try splitting the band in two. */
535    if (!stereo && LM != -1 && b > 32<<BITRES)
536    {
537       if (LM>0 || (N&1)==0)
538       {
539          N >>= 1;
540          Y = X+N;
541          split = 1;
542          LM -= 1;
543          spread = (spread+1)>>1;
544       }
545    }
546
547    if (split)
548    {
549       int qb;
550       int itheta;
551       int mbits, sbits, delta;
552       int qalloc;
553       celt_word16 mid, side;
554       if (N>1)
555       {
556          qb = (b-2*(N-1)*(QTHETA_OFFSET-m->logN[i]-(LM<<BITRES)))/(2*(N-1)<<BITRES);
557          if (qb > (b>>(BITRES+1))-1)
558             qb = (b>>(BITRES+1))-1;
559       } else {
560          /* For N==1, allocate two bits for the signs and the rest goes to qb */
561          qb = b-(2<<BITRES);
562       }
563       if (qb<0)
564          qb = 0;
565       if (qb>14)
566          qb = 14;
567
568       if (encode)
569       {
570          if (stereo)
571             stereo_band_mix(m, X, Y, bandE, qb==0, i, 1, N);
572
573          mid = renormalise_vector(X, Q15ONE, N, 1);
574          side = renormalise_vector(Y, Q15ONE, N, 1);
575
576          /* 0.63662 = 2/pi */
577 #ifdef FIXED_POINT
578          itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
579 #else
580          itheta = floor(.5f+16384*0.63662f*atan2(side,mid));
581 #endif
582       }
583
584       qalloc = 0;
585       if (qb==0)
586       {
587          itheta=0;
588       } else {
589          int shift;
590          shift = 14-qb;
591
592          /* Entropy coding of the angle. We use a uniform pdf for the
593             first stereo split but a triangular one for the rest. */
594          if (encode)
595             itheta = (itheta+(1<<shift>>1))>>shift;
596          if (stereo || qb>9 || spread>1)
597          {
598             if (encode)
599                ec_enc_uint(ec, itheta, (1<<qb)+1);
600             else
601                itheta = ec_dec_uint((ec_dec*)ec, (1<<qb)+1);
602             qalloc = log2_frac((1<<qb)+1,BITRES);
603          } else {
604             int fs=1, ft;
605             ft = ((1<<qb>>1)+1)*((1<<qb>>1)+1);
606             if (encode)
607             {
608                int j;
609                int fl=0;
610                j=0;
611                while(1)
612                {
613                   if (j==itheta)
614                      break;
615                   fl+=fs;
616                   if (j<(1<<qb>>1))
617                      fs++;
618                   else
619                      fs--;
620                   j++;
621                }
622                ec_encode(ec, fl, fl+fs, ft);
623             } else {
624                int fl=0;
625                int j, fm;
626                fm = ec_decode((ec_dec*)ec, ft);
627                j=0;
628                while (1)
629                {
630                   if (fm < fl+fs)
631                      break;
632                   fl+=fs;
633                   if (j<(1<<qb>>1))
634                      fs++;
635                   else
636                      fs--;
637                   j++;
638                }
639                itheta = j;
640                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
641             }
642             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
643          }
644          itheta <<= shift;
645       }
646
647       if (itheta == 0)
648       {
649          imid = 32767;
650          iside = 0;
651          delta = -10000;
652       } else if (itheta == 16384)
653       {
654          imid = 0;
655          iside = 32767;
656          delta = 10000;
657       } else {
658          imid = bitexact_cos(itheta);
659          iside = bitexact_cos(16384-itheta);
660          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
661       }
662
663       /* This is a special case for N=2 that only works for stereo and takes
664          advantage of the fact that mid and side are orthogonal to encode
665          the side with just one bit. */
666       if (N==2 && stereo)
667       {
668          int c, c2;
669          int sign=1;
670          celt_norm v[2], w[2];
671          celt_norm *x2, *y2;
672          mbits = b-qalloc;
673          sbits = 0;
674          if (itheta != 0 && itheta != 16384)
675             sbits = 1<<BITRES;
676          mbits -= sbits;
677          c = itheta > 8192 ? 1 : 0;
678          *remaining_bits -= qalloc+sbits;
679
680          x2 = X;
681          y2 = Y;
682          if (encode)
683          {
684             c2 = 1-c;
685
686             if (c==0)
687             {
688                v[0] = x2[0];
689                v[1] = x2[1];
690                w[0] = y2[0];
691                w[1] = y2[1];
692             } else {
693                v[0] = y2[0];
694                v[1] = y2[1];
695                w[0] = x2[0];
696                w[1] = x2[1];
697             }
698             /* Here we only need to encode a sign for the side */
699             if (v[0]*w[1] - v[1]*w[0] > 0)
700                sign = 1;
701             else
702                sign = -1;
703          }
704          quant_band(encode, m, i, v, NULL, N, mbits, spread, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level+1);
705          if (sbits)
706          {
707             if (encode)
708             {
709                ec_enc_bits(ec, sign==1, 1);
710             } else {
711                sign = 2*ec_dec_bits((ec_dec*)ec, 1)-1;
712             }
713          } else {
714             sign = 1;
715          }
716          w[0] = -sign*v[1];
717          w[1] = sign*v[0];
718          if (c==0)
719          {
720             x2[0] = v[0];
721             x2[1] = v[1];
722             y2[0] = w[0];
723             y2[1] = w[1];
724          } else {
725             x2[0] = w[0];
726             x2[1] = w[1];
727             y2[0] = v[0];
728             y2[1] = v[1];
729          }
730       } else
731       {
732          /* "Normal" split code */
733          celt_norm *next_lowband2=NULL;
734          celt_norm *next_lowband_out1=NULL;
735          int next_level=0;
736
737          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
738          if (spread>1 && !stereo)
739             delta >>= 1;
740
741          mbits = (b-qalloc/2-delta)/2;
742          if (mbits > b-qalloc)
743             mbits = b-qalloc;
744          if (mbits<0)
745             mbits=0;
746          sbits = b-qalloc-mbits;
747          *remaining_bits -= qalloc;
748
749          if (lowband && !stereo)
750             next_lowband2 = lowband+N;
751          if (stereo)
752             next_lowband_out1 = lowband_out;
753          else
754             next_level = level+1;
755
756          quant_band(encode, m, i, X, NULL, N, mbits, spread, lowband, resynth, ec, remaining_bits, LM, next_lowband_out1, NULL, next_level);
757          quant_band(encode, m, i, Y, NULL, N, sbits, spread, next_lowband2, resynth, ec, remaining_bits, LM, NULL, NULL, level);
758       }
759
760    } else {
761       /* This is the basic no-split case */
762       q = bits2pulses(m, m->bits[LM][i], N, b);
763       curr_bits = pulses2bits(m->bits[LM][i], N, q);
764       *remaining_bits -= curr_bits;
765
766       /* Ensures we can never bust the budget */
767       while (*remaining_bits < 0 && q > 0)
768       {
769          *remaining_bits += curr_bits;
770          q--;
771          curr_bits = pulses2bits(m->bits[LM][i], N, q);
772          *remaining_bits -= curr_bits;
773       }
774
775       if (encode)
776          alg_quant(X, N, q, spread, lowband, resynth, ec);
777       else
778          alg_unquant(X, N, q, spread, lowband, (ec_dec*)ec);
779    }
780
781    if (resynth)
782    {
783       if (split)
784       {
785          int j;
786          celt_word16 mid, side;
787 #ifdef FIXED_POINT
788          mid = imid;
789          side = iside;
790 #else
791          mid = (1.f/32768)*imid;
792          side = (1.f/32768)*iside;
793 #endif
794          for (j=0;j<N;j++)
795             X[j] = MULT16_16_Q15(X[j], mid);
796          for (j=0;j<N;j++)
797             Y[j] = MULT16_16_Q15(Y[j], side);
798       }
799
800
801       if (!stereo && spread0>1 && level==0)
802       {
803          int k;
804          interleave_vector(X, N_B, spread0);
805          if (lowband)
806             interleave_vector(lowband, N_B, spread0);
807          N_B = N_B0;
808          spread = spread0;
809          for (k=0;k<time_divide;k++)
810          {
811             spread >>= 1;
812             N_B <<= 1;
813             haar1(X, N_B, spread);
814             if (lowband)
815                haar1(lowband, N_B, spread);
816          }
817       }
818
819       if (!stereo && level == 0)
820       {
821          int k;
822          spread = spread0;
823          N_B = N_B0;
824          for (k=0;k<recombine;k++)
825          {
826             haar1(X, N_B, spread);
827             if (lowband)
828                haar1(lowband, N_B, spread);
829             N_B>>=1;
830             spread <<= 1;
831          }
832       }
833
834       if (lowband_out && !stereo)
835       {
836          int j;
837          celt_word16 n;
838          n = celt_sqrt(SHL32(EXTEND32(N0),22));
839          for (j=0;j<N0;j++)
840             lowband_out[j] = MULT16_16_Q15(n,X[j]);
841       }
842
843       if (stereo)
844       {
845          stereo_band_mix(m, X, Y, bandE, 0, i, -1, N);
846          renormalise_vector(X, Q15ONE, N, 1);
847          renormalise_vector(Y, Q15ONE, N, 1);
848       }
849    }
850 }
851
852 void quant_all_bands(int encode, const CELTMode *m, int start, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int resynth, int total_bits, ec_enc *ec, int LM)
853 {
854    int i, remaining_bits, balance;
855    const celt_int16 * restrict eBands = m->eBands;
856    celt_norm * restrict norm;
857    VARDECL(celt_norm, _norm);
858    int B;
859    int M;
860    int spread;
861    SAVE_STACK;
862
863    M = 1<<LM;
864    B = shortBlocks ? M : 1;
865    spread = fold ? B : 0;
866    ALLOC(_norm, M*eBands[m->nbEBands+1], celt_norm);
867    norm = _norm;
868
869    balance = 0;
870    for (i=start;i<m->nbEBands;i++)
871    {
872       int tell;
873       int b;
874       int N;
875       int curr_balance;
876       celt_norm * restrict X, * restrict Y;
877       celt_norm *lowband;
878       
879       X = _X+M*eBands[i];
880       if (_Y!=NULL)
881          Y = _Y+M*eBands[i];
882       else
883          Y = NULL;
884       N = M*eBands[i+1]-M*eBands[i];
885       if (encode)
886          tell = ec_enc_tell(ec, BITRES);
887       else
888          tell = ec_dec_tell((ec_dec*)ec, BITRES);
889
890       if (i != start)
891          balance -= tell;
892       remaining_bits = (total_bits<<BITRES)-tell-1;
893       curr_balance = (m->nbEBands-i);
894       if (curr_balance > 3)
895          curr_balance = 3;
896       curr_balance = balance / curr_balance;
897       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
898       if (b<0)
899          b = 0;
900
901       if (M*eBands[i]-N >= M*eBands[start])
902          lowband = norm+M*eBands[i]-N;
903       else
904          lowband = NULL;
905       quant_band(encode, m, i, X, Y, N, b, spread, lowband, resynth, ec, &remaining_bits, LM, norm+M*eBands[i], bandE, 0);
906
907       balance += pulses[i] + tell;
908    }
909    RESTORE_STACK;
910 }
911