Simplifies vector renormalisation (and using it less)
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    for (c=0;c<C;c++)
74    {
75       for (i=0;i<end;i++)
76       {
77          int j;
78          celt_word32 maxval=0;
79          celt_word32 sum = 0;
80          
81          j=M*eBands[i]; do {
82             maxval = MAX32(maxval, X[j+c*N]);
83             maxval = MAX32(maxval, -X[j+c*N]);
84          } while (++j<M*eBands[i+1]);
85          
86          if (maxval > 0)
87          {
88             int shift = celt_ilog2(maxval)-10;
89             j=M*eBands[i]; do {
90                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
91                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
92             } while (++j<M*eBands[i+1]);
93             /* We're adding one here to make damn sure we never end up with a pitch vector that's
94                larger than unity norm */
95             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
96          } else {
97             bank[i+c*m->nbEBands] = EPSILON;
98          }
99          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
100       }
101    }
102    /*printf ("\n");*/
103 }
104
105 /* Normalise each band such that the energy is one. */
106 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
107 {
108    int i, c, N;
109    const celt_int16 *eBands = m->eBands;
110    const int C = CHANNELS(_C);
111    N = M*m->shortMdctSize;
112    for (c=0;c<C;c++)
113    {
114       i=0; do {
115          celt_word16 g;
116          int j,shift;
117          celt_word16 E;
118          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
119          E = VSHR32(bank[i+c*m->nbEBands], shift);
120          g = EXTRACT16(celt_rcp(SHL32(E,3)));
121          j=M*eBands[i]; do {
122             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
123          } while (++j<M*eBands[i+1]);
124       } while (++i<end);
125    }
126 }
127
128 #else /* FIXED_POINT */
129 /* Compute the amplitude (sqrt energy) in each of the bands */
130 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
131 {
132    int i, c, N;
133    const celt_int16 *eBands = m->eBands;
134    const int C = CHANNELS(_C);
135    N = M*m->shortMdctSize;
136    for (c=0;c<C;c++)
137    {
138       for (i=0;i<end;i++)
139       {
140          int j;
141          celt_word32 sum = 1e-10f;
142          for (j=M*eBands[i];j<M*eBands[i+1];j++)
143             sum += X[j+c*N]*X[j+c*N];
144          bank[i+c*m->nbEBands] = celt_sqrt(sum);
145          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
146       }
147    }
148    /*printf ("\n");*/
149 }
150
151 /* Normalise each band such that the energy is one. */
152 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
153 {
154    int i, c, N;
155    const celt_int16 *eBands = m->eBands;
156    const int C = CHANNELS(_C);
157    N = M*m->shortMdctSize;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<end;i++)
161       {
162          int j;
163          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
164          for (j=M*eBands[i];j<M*eBands[i+1];j++)
165             X[j+c*N] = freq[j+c*N]*g;
166       }
167    }
168 }
169
170 #endif /* FIXED_POINT */
171
172 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
173 {
174    int i, c;
175    const celt_int16 *eBands = m->eBands;
176    const int C = CHANNELS(_C);
177    for (c=0;c<C;c++)
178    {
179       i=0; do {
180          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, M*eBands[i+1]-M*eBands[i]);
181       } while (++i<end);
182    }
183 }
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    for (c=0;c<C;c++)
194    {
195       celt_sig * restrict f;
196       const celt_norm * restrict x;
197       f = freq+c*N;
198       x = X+c*N;
199       for (i=0;i<end;i++)
200       {
201          int j, band_end;
202          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
203          j=M*eBands[i];
204          band_end = M*eBands[i+1];
205          do {
206             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
207             x++;
208          } while (++j<band_end);
209       }
210       for (i=M*eBands[m->nbEBands];i<N;i++)
211          *f++ = 0;
212    }
213 }
214
215 static void stereo_band_mix(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int stereo_mode, int bandID, int dir, int N)
216 {
217    int i = bandID;
218    int j;
219    celt_word16 a1, a2;
220    if (stereo_mode==0)
221    {
222       /* Do mid-side when not doing intensity stereo */
223       a1 = QCONST16(.70711f,14);
224       a2 = dir*QCONST16(.70711f,14);
225    } else {
226       celt_word16 left, right;
227       celt_word16 norm;
228 #ifdef FIXED_POINT
229       int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
230 #endif
231       left = VSHR32(bank[i],shift);
232       right = VSHR32(bank[i+m->nbEBands],shift);
233       norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
234       a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
235       a2 = dir*DIV32_16(SHL32(EXTEND32(right),14),norm);
236    }
237    for (j=0;j<N;j++)
238    {
239       celt_norm r, l;
240       l = X[j];
241       r = Y[j];
242       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
243       Y[j] = MULT16_16_Q14(a1,r) - MULT16_16_Q14(a2,l);
244    }
245 }
246
247 /* Decide whether we should spread the pulses in the current frame */
248 int folding_decision(const CELTMode *m, celt_norm *X, int *average, int *last_decision, int end, int _C, int M)
249 {
250    int i, c, N0;
251    int sum = 0, nbBands=0;
252    const int C = CHANNELS(_C);
253    const celt_int16 * restrict eBands = m->eBands;
254    int decision;
255    
256    N0 = M*m->shortMdctSize;
257
258    if (M*(eBands[end]-eBands[end-1]) <= 8)
259       return 0;
260    for (c=0;c<C;c++)
261    {
262       for (i=0;i<end;i++)
263       {
264          int j, N, tmp=0;
265          int tcount[3] = {0};
266          celt_norm * restrict x = X+M*eBands[i]+c*N0;
267          N = M*(eBands[i+1]-eBands[i]);
268          if (N<=8)
269             continue;
270          /* Compute rough CDF of |x[j]| */
271          for (j=0;j<N;j++)
272          {
273             celt_word32 x2N; /* Q13 */
274
275             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
276             if (x2N < QCONST16(0.25f,13))
277                tcount[0]++;
278             if (x2N < QCONST16(0.0625f,13))
279                tcount[1]++;
280             if (x2N < QCONST16(0.015625f,13))
281                tcount[2]++;
282          }
283
284          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
285          sum += tmp*256;
286          nbBands++;
287       }
288    }
289    sum /= nbBands;
290    /* Recursive averaging */
291    sum = (sum+*average)>>1;
292    *average = sum;
293    /* Hysteresis */
294    sum = (3*sum + ((*last_decision<<7) + 64) + 2)>>2;
295    /* decision and last_decision do not use the same encoding */
296    if (sum < 128)
297    {
298       decision = 2;
299       *last_decision = 0;
300    } else if (sum < 256)
301    {
302       decision = 1;
303       *last_decision = 1;
304    } else if (sum < 384)
305    {
306       decision = 3;
307       *last_decision = 2;
308    } else {
309       decision = 0;
310       *last_decision = 3;
311    }
312    return decision;
313 }
314
315 #ifdef MEASURE_NORM_MSE
316
317 float MSE[30] = {0};
318 int nbMSEBands = 0;
319 int MSECount[30] = {0};
320
321 void dump_norm_mse(void)
322 {
323    int i;
324    for (i=0;i<nbMSEBands;i++)
325    {
326       printf ("%g ", MSE[i]/MSECount[i]);
327    }
328    printf ("\n");
329 }
330
331 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
332 {
333    static int init = 0;
334    int i;
335    if (!init)
336    {
337       atexit(dump_norm_mse);
338       init = 1;
339    }
340    for (i=0;i<m->nbEBands;i++)
341    {
342       int j;
343       int c;
344       float g;
345       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
346          continue;
347       for (c=0;c<C;c++)
348       {
349          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
350          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
351             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
352       }
353       MSECount[i]+=C;
354    }
355    nbMSEBands = m->nbEBands;
356 }
357
358 #endif
359
360 static void interleave_vector(celt_norm *X, int N0, int stride)
361 {
362    int i,j;
363    VARDECL(celt_norm, tmp);
364    int N;
365    SAVE_STACK;
366    N = N0*stride;
367    ALLOC(tmp, N, celt_norm);
368    for (i=0;i<stride;i++)
369       for (j=0;j<N0;j++)
370          tmp[j*stride+i] = X[i*N0+j];
371    for (j=0;j<N;j++)
372       X[j] = tmp[j];
373    RESTORE_STACK;
374 }
375
376 static void deinterleave_vector(celt_norm *X, int N0, int stride)
377 {
378    int i,j;
379    VARDECL(celt_norm, tmp);
380    int N;
381    SAVE_STACK;
382    N = N0*stride;
383    ALLOC(tmp, N, celt_norm);
384    for (i=0;i<stride;i++)
385       for (j=0;j<N0;j++)
386          tmp[i*N0+j] = X[j*stride+i];
387    for (j=0;j<N;j++)
388       X[j] = tmp[j];
389    RESTORE_STACK;
390 }
391
392 static void haar1(celt_norm *X, int N0, int stride)
393 {
394    int i, j;
395    N0 >>= 1;
396    for (i=0;i<stride;i++)
397       for (j=0;j<N0;j++)
398       {
399          celt_norm tmp = X[stride*2*j+i];
400          X[stride*2*j+i] = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i] + X[stride*(2*j+1)+i]);
401          X[stride*(2*j+1)+i] = MULT16_16_Q15(QCONST16(.7070678f,15), tmp - X[stride*(2*j+1)+i]);
402       }
403 }
404
405 static int compute_qn(int N, int b, int offset, int stereo)
406 {
407    static const celt_int16 exp2_table8[8] =
408       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
409    int qn, qb;
410    int N2 = 2*N-1;
411    if (stereo && N==2)
412       N2--;
413    qb = (b+N2*offset)/N2;
414    if (qb > (b>>1)-(1<<BITRES))
415       qb = (b>>1)-(1<<BITRES);
416
417    if (qb<0)
418        qb = 0;
419    if (qb>14<<BITRES)
420      qb = 14<<BITRES;
421
422    if (qb<(1<<BITRES>>1)) {
423       qn = 1;
424    } else {
425       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
426       qn = (qn+1)>>1<<1;
427       if (qn>1024)
428          qn = 1024;
429    }
430    return qn;
431 }
432
433
434 /* This function is responsible for encoding and decoding a band for both
435    the mono and stereo case. Even in the mono case, it can split the band
436    in two and transmit the energy difference with the two half-bands. It
437    can be called recursively so bands can end up being split in 8 parts. */
438 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
439       int N, int b, int spread, int B, int tf_change, celt_norm *lowband, int resynth, void *ec,
440       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level, celt_int32 *seed)
441 {
442    int q;
443    int curr_bits;
444    int stereo, split;
445    int imid=0, iside=0;
446    int N0=N;
447    int N_B=N;
448    int N_B0;
449    int B0=B;
450    int time_divide=0;
451    int recombine=0;
452
453    N_B /= B;
454    N_B0 = N_B;
455
456    split = stereo = Y != NULL;
457
458    /* Special case for one sample */
459    if (N==1)
460    {
461       int c;
462       celt_norm *x = X;
463       for (c=0;c<1+stereo;c++)
464       {
465          int sign=0;
466          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
467          {
468             if (encode)
469             {
470                sign = x[0]<0;
471                ec_enc_bits((ec_enc*)ec, sign, 1);
472             } else {
473                sign = ec_dec_bits((ec_dec*)ec, 1);
474             }
475             *remaining_bits -= 1<<BITRES;
476             b-=1<<BITRES;
477          }
478          if (resynth)
479             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
480          x = Y;
481       }
482       if (lowband_out)
483          lowband_out[0] = X[0];
484       return;
485    }
486
487    /* Band recombining to increase frequency resolution */
488    if (!stereo && B > 1 && level == 0 && tf_change>0)
489    {
490       while (B>1 && tf_change>0)
491       {
492          B>>=1;
493          N_B<<=1;
494          if (encode)
495             haar1(X, N_B, B);
496          if (lowband)
497             haar1(lowband, N_B, B);
498          recombine++;
499          tf_change--;
500       }
501       B0=B;
502       N_B0 = N_B;
503    }
504
505    /* Increasing the time resolution */
506    if (!stereo && level==0)
507    {
508       while ((N_B&1) == 0 && tf_change<0 && B <= (1<<LM))
509       {
510          if (encode)
511             haar1(X, N_B, B);
512          if (lowband)
513             haar1(lowband, N_B, B);
514          B <<= 1;
515          N_B >>= 1;
516          time_divide++;
517          tf_change++;
518       }
519       B0 = B;
520       N_B0 = N_B;
521    }
522
523    /* Reorganize the samples in time order instead of frequency order */
524    if (!stereo && B0>1 && level==0)
525    {
526       if (encode)
527          deinterleave_vector(X, N_B, B0);
528       if (lowband)
529          deinterleave_vector(lowband, N_B, B0);
530    }
531
532    /* If we need more than 32 bits, try splitting the band in two. */
533    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
534    {
535       if (LM>0 || (N&1)==0)
536       {
537          N >>= 1;
538          Y = X+N;
539          split = 1;
540          LM -= 1;
541          B = (B+1)>>1;
542       }
543    }
544
545    if (split)
546    {
547       int qn;
548       int itheta=0;
549       int mbits, sbits, delta;
550       int qalloc;
551       celt_word16 mid, side;
552       int offset;
553
554       /* Decide on the resolution to give to the split parameter theta */
555       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
556       qn = compute_qn(N, b, offset, stereo);
557
558       qalloc = 0;
559       if (qn!=1)
560       {
561          if (encode)
562          {
563             if (stereo)
564                stereo_band_mix(m, X, Y, bandE, 0, i, 1, N);
565
566             mid = vector_norm(X, N);
567             side = vector_norm(Y, N);
568             /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
569                      Let's do that at higher complexity */
570             /*mid = renormalise_vector(X, Q15ONE, N, 1);
571             side = renormalise_vector(Y, Q15ONE, N, 1);*/
572
573             /* theta is the atan() of the ratio between the (normalized)
574                side and mid. With just that parameter, we can re-scale both
575                mid and side because we know that 1) they have unit norm and
576                2) they are orthogonal. */
577    #ifdef FIXED_POINT
578             /* 0.63662 = 2/pi */
579             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
580    #else
581             itheta = (int)floor(.5f+16384*0.63662f*atan2(side,mid));
582    #endif
583
584             itheta = (itheta*qn+8192)>>14;
585          }
586
587          /* Entropy coding of the angle. We use a uniform pdf for the
588             first stereo split but a triangular one for the rest. */
589          if (stereo || qn>256 || B>1)
590          {
591             if (encode)
592                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
593             else
594                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
595             qalloc = log2_frac(qn+1,BITRES);
596          } else {
597             int fs=1, ft;
598             ft = ((qn>>1)+1)*((qn>>1)+1);
599             if (encode)
600             {
601                int fl;
602
603                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
604                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
605                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
606
607                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
608             } else {
609                int fl=0;
610                int fm;
611                fm = ec_decode((ec_dec*)ec, ft);
612
613                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
614                {
615                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
616                   fs = itheta + 1;
617                   fl = itheta*(itheta + 1)>>1;
618                }
619                else
620                {
621                   itheta = (2*(qn + 1)
622                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
623                   fs = qn + 1 - itheta;
624                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
625                }
626
627                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
628             }
629             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
630          }
631          itheta = itheta*16384/qn;
632       } else {
633          if (stereo && encode)
634             stereo_band_mix(m, X, Y, bandE, 1, i, 1, N);
635       }
636
637       if (itheta == 0)
638       {
639          imid = 32767;
640          iside = 0;
641          delta = -10000;
642       } else if (itheta == 16384)
643       {
644          imid = 0;
645          iside = 32767;
646          delta = 10000;
647       } else {
648          imid = bitexact_cos(itheta);
649          iside = bitexact_cos(16384-itheta);
650          /* This is the mid vs side allocation that minimizes squared error
651             in that band. */
652          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
653       }
654
655       /* This is a special case for N=2 that only works for stereo and takes
656          advantage of the fact that mid and side are orthogonal to encode
657          the side with just one bit. */
658       if (N==2 && stereo)
659       {
660          int c;
661          int sign=1;
662          celt_norm *x2, *y2;
663          mbits = b-qalloc;
664          sbits = 0;
665          /* Only need one bit for the side */
666          if (itheta != 0 && itheta != 16384)
667             sbits = 1<<BITRES;
668          mbits -= sbits;
669          c = itheta > 8192;
670          *remaining_bits -= qalloc+sbits;
671
672          x2 = c ? Y : X;
673          y2 = c ? X : Y;
674          if (sbits)
675          {
676             if (encode)
677             {
678                /* Here we only need to encode a sign for the side */
679                sign = x2[0]*y2[1] - x2[1]*y2[0] > 0;
680                ec_enc_bits((ec_enc*)ec, sign, 1);
681             } else {
682                sign = ec_dec_bits((ec_dec*)ec, 1);
683             }
684          }
685          sign = 2*sign - 1;
686          quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level+1, seed);
687          y2[0] = -sign*x2[1];
688          y2[1] = sign*x2[0];
689       } else {
690          /* "Normal" split code */
691          celt_norm *next_lowband2=NULL;
692          celt_norm *next_lowband_out1=NULL;
693          int next_level=0;
694
695          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
696          if (B>1 && !stereo)
697             delta >>= 1;
698
699          mbits = (b-qalloc-delta)/2;
700          if (mbits > b-qalloc)
701             mbits = b-qalloc;
702          if (mbits<0)
703             mbits=0;
704          sbits = b-qalloc-mbits;
705          *remaining_bits -= qalloc;
706
707          if (lowband && !stereo)
708             next_lowband2 = lowband+N; /* >32-bit split case */
709
710          /* Only stereo needs to pass on lowband_out. Otherwise, it's handled at the end */
711          if (stereo)
712             next_lowband_out1 = lowband_out;
713          else
714             next_level = level+1;
715
716          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, next_lowband_out1, NULL, next_level, seed);
717          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, tf_change, next_lowband2, resynth, ec, remaining_bits, LM, NULL, NULL, next_level, seed);
718       }
719
720    } else {
721       /* This is the basic no-split case */
722       q = bits2pulses(m, i, LM, b);
723       curr_bits = pulses2bits(m, i, LM, q);
724       *remaining_bits -= curr_bits;
725
726       /* Ensures we can never bust the budget */
727       while (*remaining_bits < 0 && q > 0)
728       {
729          *remaining_bits += curr_bits;
730          q--;
731          curr_bits = pulses2bits(m, i, LM, q);
732          *remaining_bits -= curr_bits;
733       }
734
735       /* Finally do the actual quantization */
736       if (encode)
737          alg_quant(X, N, q, spread, B, lowband, resynth, (ec_enc*)ec, seed);
738       else
739          alg_unquant(X, N, q, spread, B, lowband, (ec_dec*)ec, seed);
740    }
741
742    /* This code is used by the decoder and by the resynthesis-enabled encoder */
743    if (resynth)
744    {
745       int k;
746
747       if (split)
748       {
749          int j;
750          celt_word16 mid, side;
751 #ifdef FIXED_POINT
752          mid = imid;
753          side = iside;
754 #else
755          mid = (1.f/32768)*imid;
756          side = (1.f/32768)*iside;
757 #endif
758          for (j=0;j<N;j++)
759             X[j] = MULT16_16_Q15(X[j], mid);
760          for (j=0;j<N;j++)
761             Y[j] = MULT16_16_Q15(Y[j], side);
762       }
763
764       /* Undo the sample reorganization going from time order to frequency order */
765       if (!stereo && B0>1 && level==0)
766       {
767          interleave_vector(X, N_B, B0);
768          if (lowband)
769             interleave_vector(lowband, N_B, B0);
770       }
771
772       /* Undo time-freq changes that we did earlier */
773       N_B = N_B0;
774       B = B0;
775       for (k=0;k<time_divide;k++)
776       {
777          B >>= 1;
778          N_B <<= 1;
779          haar1(X, N_B, B);
780          if (lowband)
781             haar1(lowband, N_B, B);
782       }
783
784       for (k=0;k<recombine;k++)
785       {
786          haar1(X, N_B, B);
787          if (lowband)
788             haar1(lowband, N_B, B);
789          N_B>>=1;
790          B <<= 1;
791       }
792
793       /* Scale output for later folding */
794       if (lowband_out && !stereo)
795       {
796          int j;
797          celt_word16 n;
798          n = celt_sqrt(SHL32(EXTEND32(N0),22));
799          for (j=0;j<N0;j++)
800             lowband_out[j] = MULT16_16_Q15(n,X[j]);
801       }
802
803       if (stereo)
804       {
805          stereo_band_mix(m, X, Y, bandE, 0, i, -1, N);
806          /* We only need to renormalize because quantization may not
807             have preserved orthogonality of mid and side */
808          renormalise_vector(X, N);
809          renormalise_vector(Y, N);
810       }
811    }
812 }
813
814 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM)
815 {
816    int i, balance;
817    celt_int32 remaining_bits;
818    const celt_int16 * restrict eBands = m->eBands;
819    celt_norm * restrict norm;
820    VARDECL(celt_norm, _norm);
821    int B;
822    int M;
823    celt_int32 seed;
824    celt_norm *lowband;
825    int update_lowband = 1;
826    int C = _Y != NULL ? 2 : 1;
827    SAVE_STACK;
828
829    M = 1<<LM;
830    B = shortBlocks ? M : 1;
831    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
832    norm = _norm;
833
834    if (encode)
835       seed = ((ec_enc*)ec)->rng;
836    else
837       seed = ((ec_dec*)ec)->rng;
838    balance = 0;
839    lowband = NULL;
840    for (i=start;i<end;i++)
841    {
842       int tell;
843       int b;
844       int N;
845       int curr_balance;
846       celt_norm * restrict X, * restrict Y;
847       int tf_change=0;
848       celt_norm *effective_lowband;
849       
850       X = _X+M*eBands[i];
851       if (_Y!=NULL)
852          Y = _Y+M*eBands[i];
853       else
854          Y = NULL;
855       N = M*eBands[i+1]-M*eBands[i];
856       if (encode)
857          tell = ec_enc_tell((ec_enc*)ec, BITRES);
858       else
859          tell = ec_dec_tell((ec_dec*)ec, BITRES);
860
861       /* Compute how many bits we want to allocate to this band */
862       if (i != start)
863          balance -= tell;
864       remaining_bits = (total_bits<<BITRES)-tell-1;
865       curr_balance = (end-i);
866       if (curr_balance > 3)
867          curr_balance = 3;
868       curr_balance = balance / curr_balance;
869       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
870       if (b<0)
871          b = 0;
872       /* Prevents ridiculous bit depths */
873       if (b > C*16*N<<BITRES)
874          b = C*16*N<<BITRES;
875
876       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
877             lowband = norm+M*eBands[i]-N;
878
879       tf_change = tf_res[i];
880       if (i>=m->effEBands)
881       {
882          X=norm;
883          if (_Y!=NULL)
884             Y = norm;
885       }
886
887       if (tf_change==0 && !shortBlocks && fold)
888          effective_lowband = NULL;
889       else
890          effective_lowband = lowband;
891       quant_band(encode, m, i, X, Y, N, b, fold, B, tf_change, effective_lowband, resynth, ec, &remaining_bits, LM, norm+M*eBands[i], bandE, 0, &seed);
892
893       balance += pulses[i] + tell;
894
895       /* Update the folding position only as long as we have 2 bit/sample depth */
896       update_lowband = (b>>BITRES)>2*N;
897    }
898    RESTORE_STACK;
899 }
900