Allowing fractional bits for splitting angle (theta)
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
49    with this approximation is important because it has an impact on the bit allocation */
50 static celt_int16 bitexact_cos(celt_int16 x)
51 {
52    celt_int32 tmp;
53    celt_int16 x2;
54    tmp = (4096+((celt_int32)(x)*(x)))>>13;
55    if (tmp > 32767)
56       tmp = 32767;
57    x2 = tmp;
58    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
59    if (x2 > 32766)
60       x2 = 32766;
61    return 1+x2;
62 }
63
64
65 #ifdef FIXED_POINT
66 /* Compute the amplitude (sqrt energy) in each of the bands */
67 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
68 {
69    int i, c, N;
70    const celt_int16 *eBands = m->eBands;
71    const int C = CHANNELS(_C);
72    N = M*m->shortMdctSize;
73    for (c=0;c<C;c++)
74    {
75       for (i=0;i<end;i++)
76       {
77          int j;
78          celt_word32 maxval=0;
79          celt_word32 sum = 0;
80          
81          j=M*eBands[i]; do {
82             maxval = MAX32(maxval, X[j+c*N]);
83             maxval = MAX32(maxval, -X[j+c*N]);
84          } while (++j<M*eBands[i+1]);
85          
86          if (maxval > 0)
87          {
88             int shift = celt_ilog2(maxval)-10;
89             j=M*eBands[i]; do {
90                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
91                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
92             } while (++j<M*eBands[i+1]);
93             /* We're adding one here to make damn sure we never end up with a pitch vector that's
94                larger than unity norm */
95             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
96          } else {
97             bank[i+c*m->nbEBands] = EPSILON;
98          }
99          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
100       }
101    }
102    /*printf ("\n");*/
103 }
104
105 /* Normalise each band such that the energy is one. */
106 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
107 {
108    int i, c, N;
109    const celt_int16 *eBands = m->eBands;
110    const int C = CHANNELS(_C);
111    N = M*m->shortMdctSize;
112    for (c=0;c<C;c++)
113    {
114       i=0; do {
115          celt_word16 g;
116          int j,shift;
117          celt_word16 E;
118          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
119          E = VSHR32(bank[i+c*m->nbEBands], shift);
120          g = EXTRACT16(celt_rcp(SHL32(E,3)));
121          j=M*eBands[i]; do {
122             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
123          } while (++j<M*eBands[i+1]);
124       } while (++i<end);
125    }
126 }
127
128 #else /* FIXED_POINT */
129 /* Compute the amplitude (sqrt energy) in each of the bands */
130 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
131 {
132    int i, c, N;
133    const celt_int16 *eBands = m->eBands;
134    const int C = CHANNELS(_C);
135    N = M*m->shortMdctSize;
136    for (c=0;c<C;c++)
137    {
138       for (i=0;i<end;i++)
139       {
140          int j;
141          celt_word32 sum = 1e-10;
142          for (j=M*eBands[i];j<M*eBands[i+1];j++)
143             sum += X[j+c*N]*X[j+c*N];
144          bank[i+c*m->nbEBands] = sqrt(sum);
145          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
146       }
147    }
148    /*printf ("\n");*/
149 }
150
151 /* Normalise each band such that the energy is one. */
152 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
153 {
154    int i, c, N;
155    const celt_int16 *eBands = m->eBands;
156    const int C = CHANNELS(_C);
157    N = M*m->shortMdctSize;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<end;i++)
161       {
162          int j;
163          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
164          for (j=M*eBands[i];j<M*eBands[i+1];j++)
165             X[j+c*N] = freq[j+c*N]*g;
166       }
167    }
168 }
169
170 #endif /* FIXED_POINT */
171
172 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
173 {
174    int i, c;
175    const celt_int16 *eBands = m->eBands;
176    const int C = CHANNELS(_C);
177    for (c=0;c<C;c++)
178    {
179       i=0; do {
180          renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, Q15ONE, M*eBands[i+1]-M*eBands[i], 1);
181       } while (++i<end);
182    }
183 }
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    for (c=0;c<C;c++)
194    {
195       celt_sig * restrict f;
196       const celt_norm * restrict x;
197       f = freq+c*N;
198       x = X+c*N;
199       for (i=0;i<end;i++)
200       {
201          int j, band_end;
202          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
203          j=M*eBands[i];
204          band_end = M*eBands[i+1];
205          do {
206             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
207             x++;
208          } while (++j<band_end);
209       }
210       for (i=M*eBands[m->nbEBands];i<N;i++)
211          *f++ = 0;
212    }
213 }
214
215 int compute_pitch_gain(const CELTMode *m, const celt_sig *X, const celt_sig *P, int norm_rate, int *gain_id, int _C, celt_word16 *gain_prod, int M)
216 {
217    int j, c;
218    celt_word16 g;
219    celt_word16 delta;
220    const int C = CHANNELS(_C);
221    celt_word32 Sxy=0, Sxx=0, Syy=0;
222    int len = M*m->pitchEnd;
223    int N = M*m->shortMdctSize;
224 #ifdef FIXED_POINT
225    int shift = 0;
226    celt_word32 maxabs=0;
227
228    for (c=0;c<C;c++)
229    {
230       for (j=0;j<len;j++)
231       {
232          maxabs = MAX32(maxabs, ABS32(X[j+c*N]));
233          maxabs = MAX32(maxabs, ABS32(P[j+c*N]));
234       }
235    }
236    shift = celt_ilog2(maxabs)-12;
237    if (shift<0)
238       shift = 0;
239 #endif
240    delta = PDIV32_16(Q15ONE, len);
241    for (c=0;c<C;c++)
242    {
243       celt_word16 gg = Q15ONE;
244       for (j=0;j<len;j++)
245       {
246          celt_word16 Xj, Pj;
247          Xj = EXTRACT16(SHR32(X[j+c*N], shift));
248          Pj = MULT16_16_P15(gg,EXTRACT16(SHR32(P[j+c*N], shift)));
249          Sxy = MAC16_16(Sxy, Xj, Pj);
250          Sxx = MAC16_16(Sxx, Pj, Pj);
251          Syy = MAC16_16(Syy, Xj, Xj);
252          gg = SUB16(gg, delta);
253       }
254    }
255 #ifdef FIXED_POINT
256    {
257       celt_word32 num, den;
258       celt_word16 fact;
259       fact = MULT16_16(QCONST16(.04f, 14), norm_rate);
260       if (fact < QCONST16(1.f, 14))
261          fact = QCONST16(1.f, 14);
262       num = Sxy;
263       den = EPSILON+Sxx+MULT16_32_Q15(QCONST16(.03f,15),Syy);
264       shift = celt_zlog2(Sxy)-16;
265       if (shift < 0)
266          shift = 0;
267       if (Sxy < MULT16_32_Q15(fact, MULT16_16(celt_sqrt(EPSILON+Sxx),celt_sqrt(EPSILON+Syy))))
268          g = 0;
269       else
270          g = DIV32(SHL32(SHR32(num,shift),14),ADD32(EPSILON,SHR32(den,shift)));
271
272       /* This MUST round down so that we don't over-estimate the gain */
273       *gain_id = EXTRACT16(SHR32(MULT16_16(20,(g-QCONST16(.5f,14))),14));
274    }
275 #else
276    {
277       float fact = .04f*norm_rate;
278       if (fact < 1)
279          fact = 1;
280       g = Sxy/(.1f+Sxx+.03f*Syy);
281       if (Sxy < .5f*fact*celt_sqrt(1+Sxx*Syy))
282          g = 0;
283       /* This MUST round down so that we don't over-estimate the gain */
284       *gain_id = floor(20*(g-.5f));
285    }
286 #endif
287    /* This prevents the pitch gain from being above 1.0 for too long by bounding the 
288       maximum error amplification factor to 2.0 */
289    g = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),*gain_id));
290    *gain_prod = MAX16(QCONST32(1.f, 13), MULT16_16_Q14(*gain_prod,g));
291    if (*gain_prod>QCONST32(2.f, 13))
292    {
293       *gain_id=9;
294       *gain_prod = QCONST32(2.f, 13);
295    }
296
297    if (*gain_id < 0)
298    {
299       *gain_id = 0;
300       return 0;
301    } else {
302       if (*gain_id > 15)
303          *gain_id = 15;
304       return 1;
305    }
306 }
307
308 void apply_pitch(const CELTMode *m, celt_sig *X, const celt_sig *P, int gain_id, int pred, int _C, int M)
309 {
310    int j, c, N;
311    celt_word16 gain;
312    celt_word16 delta;
313    const int C = CHANNELS(_C);
314    int len = M*m->pitchEnd;
315
316    N = M*m->shortMdctSize;
317    gain = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),gain_id));
318    delta = PDIV32_16(gain, len);
319    if (pred)
320       gain = -gain;
321    else
322       delta = -delta;
323    for (c=0;c<C;c++)
324    {
325       celt_word16 gg = gain;
326       for (j=0;j<len;j++)
327       {
328          X[j+c*N] += SHL32(MULT16_32_Q15(gg,P[j+c*N]),1);
329          gg = ADD16(gg, delta);
330       }
331    }
332 }
333
334 static void stereo_band_mix(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int stereo_mode, int bandID, int dir, int N)
335 {
336    int i = bandID;
337    int j;
338    celt_word16 a1, a2;
339    if (stereo_mode==0)
340    {
341       /* Do mid-side when not doing intensity stereo */
342       a1 = QCONST16(.70711f,14);
343       a2 = dir*QCONST16(.70711f,14);
344    } else {
345       celt_word16 left, right;
346       celt_word16 norm;
347 #ifdef FIXED_POINT
348       int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
349 #endif
350       left = VSHR32(bank[i],shift);
351       right = VSHR32(bank[i+m->nbEBands],shift);
352       norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
353       a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
354       a2 = dir*DIV32_16(SHL32(EXTEND32(right),14),norm);
355    }
356    for (j=0;j<N;j++)
357    {
358       celt_norm r, l;
359       l = X[j];
360       r = Y[j];
361       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
362       Y[j] = MULT16_16_Q14(a1,r) - MULT16_16_Q14(a2,l);
363    }
364 }
365
366 /* Decide whether we should spread the pulses in the current frame */
367 int folding_decision(const CELTMode *m, celt_norm *X, celt_word16 *average, int *last_decision, int end, int _C, int M)
368 {
369    int i, c, N0;
370    int NR=0;
371    celt_word32 ratio = EPSILON;
372    const int C = CHANNELS(_C);
373    const celt_int16 * restrict eBands = m->eBands;
374    
375    N0 = M*m->shortMdctSize;
376
377    for (c=0;c<C;c++)
378    {
379    for (i=0;i<end;i++)
380    {
381       int j, N;
382       int max_i=0;
383       celt_word16 max_val=EPSILON;
384       celt_word32 floor_ener=EPSILON;
385       celt_norm * restrict x = X+M*eBands[i]+c*N0;
386       N = M*eBands[i+1]-M*eBands[i];
387       for (j=0;j<N;j++)
388       {
389          if (ABS16(x[j])>max_val)
390          {
391             max_val = ABS16(x[j]);
392             max_i = j;
393          }
394       }
395 #if 0
396       for (j=0;j<N;j++)
397       {
398          if (abs(j-max_i)>2)
399             floor_ener += x[j]*x[j];
400       }
401 #else
402       floor_ener = QCONST32(1.,28)-MULT16_16(max_val,max_val);
403       if (max_i < N-1)
404          floor_ener -= MULT16_16(x[(max_i+1)], x[(max_i+1)]);
405       if (max_i < N-2)
406          floor_ener -= MULT16_16(x[(max_i+2)], x[(max_i+2)]);
407       if (max_i > 0)
408          floor_ener -= MULT16_16(x[(max_i-1)], x[(max_i-1)]);
409       if (max_i > 1)
410          floor_ener -= MULT16_16(x[(max_i-2)], x[(max_i-2)]);
411       floor_ener = MAX32(floor_ener, EPSILON);
412 #endif
413       if (N>7)
414       {
415          celt_word16 r;
416          celt_word16 den = celt_sqrt(floor_ener);
417          den = MAX32(QCONST16(.02f, 15), den);
418          r = DIV32_16(SHL32(EXTEND32(max_val),8),den);
419          ratio = ADD32(ratio, EXTEND32(r));
420          NR++;
421       }
422    }
423    }
424    if (NR>0)
425       ratio = DIV32_16(ratio, NR);
426    ratio = ADD32(HALF32(ratio), HALF32(*average));
427    if (!*last_decision)
428    {
429       *last_decision = (ratio < QCONST16(1.8f,8));
430    } else {
431       *last_decision = (ratio < QCONST16(3.f,8));
432    }
433    *average = EXTRACT16(ratio);
434    return *last_decision;
435 }
436
437 #ifdef MEASURE_NORM_MSE
438
439 float MSE[30] = {0};
440 int nbMSEBands = 0;
441 int MSECount[30] = {0};
442
443 void dump_norm_mse(void)
444 {
445    int i;
446    for (i=0;i<nbMSEBands;i++)
447    {
448       printf ("%f ", MSE[i]/MSECount[i]);
449    }
450    printf ("\n");
451 }
452
453 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
454 {
455    static int init = 0;
456    int i;
457    if (!init)
458    {
459       atexit(dump_norm_mse);
460       init = 1;
461    }
462    for (i=0;i<m->nbEBands;i++)
463    {
464       int j;
465       int c;
466       float g;
467       if (bandE0[i]<1 || (C==2 && bandE0[i+m->nbEBands]<1))
468          continue;
469       for (c=0;c<C;c++)
470       {
471          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
472          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
473             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
474       }
475       MSECount[i]+=C;
476    }
477    nbMSEBands = m->nbEBands;
478 }
479
480 #endif
481
482 static void interleave_vector(celt_norm *X, int N0, int stride)
483 {
484    int i,j;
485    VARDECL(celt_norm, tmp);
486    int N;
487    SAVE_STACK;
488    N = N0*stride;
489    ALLOC(tmp, N, celt_norm);
490    for (i=0;i<stride;i++)
491       for (j=0;j<N0;j++)
492          tmp[j*stride+i] = X[i*N0+j];
493    for (j=0;j<N;j++)
494       X[j] = tmp[j];
495    RESTORE_STACK;
496 }
497
498 static void deinterleave_vector(celt_norm *X, int N0, int stride)
499 {
500    int i,j;
501    VARDECL(celt_norm, tmp);
502    int N;
503    SAVE_STACK;
504    N = N0*stride;
505    ALLOC(tmp, N, celt_norm);
506    for (i=0;i<stride;i++)
507       for (j=0;j<N0;j++)
508          tmp[i*N0+j] = X[j*stride+i];
509    for (j=0;j<N;j++)
510       X[j] = tmp[j];
511    RESTORE_STACK;
512 }
513
514 static void haar1(celt_norm *X, int N0, int stride)
515 {
516    int i, j;
517    N0 >>= 1;
518    for (i=0;i<stride;i++)
519       for (j=0;j<N0;j++)
520       {
521          celt_norm tmp = X[stride*2*j+i];
522          X[stride*2*j+i] = MULT16_16_Q15(QCONST16(.7070678f,15), X[stride*2*j+i] + X[stride*(2*j+1)+i]);
523          X[stride*(2*j+1)+i] = MULT16_16_Q15(QCONST16(.7070678f,15), tmp - X[stride*(2*j+1)+i]);
524       }
525 }
526
527 static int compute_qn(int N, int b, int offset, int stereo)
528 {
529    static const celt_int16 exp2_table8[8] =
530       {16384, 17867, 19484, 21247, 23170, 25268, 27554, 30048};
531    int qn, qb;
532    int N2 = 2*N-1;
533    if (stereo && N==2)
534       N2--;
535    qb = (b+N2*offset)/(N2);
536    if (qb > (b>>1)-(1<<BITRES))
537       qb = (b>>1)-(1<<BITRES);
538
539    if (qb<0)
540        qb = 0;
541    if (qb>14<<BITRES)
542      qb = 14<<BITRES;
543
544    if (qb<(1<<BITRES>>1)) {
545       qn = 1;
546    } else {
547       qn = ((1<<(qb>>BITRES))*exp2_table8[qb&0x7] + (1<<14))>>14;
548       qn = qn>>1<<1;
549       if (qn>1024)
550          qn = 1024;
551    }
552    return qn;
553 }
554
555
556 /* This function is responsible for encoding and decoding a band for both
557    the mono and stereo case. Even in the mono case, it can split the band
558    in two and transmit the energy difference with the two half-bands. It
559    can be called recursively so bands can end up being split in 8 parts. */
560 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
561       int N, int b, int spread, int B, int tf_change, celt_norm *lowband, int resynth, void *ec,
562       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level, celt_int32 *seed)
563 {
564    int q;
565    int curr_bits;
566    int stereo, split;
567    int imid=0, iside=0;
568    int N0=N;
569    int N_B=N;
570    int N_B0;
571    int B0=B;
572    int time_divide=0;
573    int recombine=0;
574
575    N_B /= B;
576    N_B0 = N_B;
577
578    split = stereo = Y != NULL;
579
580    /* Special case for one sample */
581    if (N==1)
582    {
583       int c;
584       celt_norm *x = X;
585       for (c=0;c<1+stereo;c++)
586       {
587          int sign=0;
588          if (b>=1<<BITRES && *remaining_bits>=1<<BITRES)
589          {
590             if (encode)
591             {
592                sign = x[0]<0;
593                ec_enc_bits((ec_enc*)ec, sign, 1);
594             } else {
595                sign = ec_dec_bits((ec_dec*)ec, 1);
596             }
597             *remaining_bits -= 1<<BITRES;
598             b-=1<<BITRES;
599          }
600          if (resynth)
601             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
602          x = Y;
603       }
604       if (lowband_out)
605          lowband_out[0] = X[0];
606       return;
607    }
608
609    /* Band recombining to increase frequency resolution */
610    if (!stereo && B > 1 && level == 0 && tf_change>0)
611    {
612       while (B>1 && tf_change>0)
613       {
614          B>>=1;
615          N_B<<=1;
616          if (encode)
617             haar1(X, N_B, B);
618          if (lowband)
619             haar1(lowband, N_B, B);
620          recombine++;
621          tf_change--;
622       }
623       B0=B;
624       N_B0 = N_B;
625    }
626
627    /* Increasing the time resolution */
628    if (!stereo && level==0)
629    {
630       while ((N_B&1) == 0 && tf_change<0 && B <= (1<<LM))
631       {
632          if (encode)
633             haar1(X, N_B, B);
634          if (lowband)
635             haar1(lowband, N_B, B);
636          B <<= 1;
637          N_B >>= 1;
638          time_divide++;
639          tf_change++;
640       }
641       B0 = B;
642       N_B0 = N_B;
643    }
644
645    /* Reorganize the samples in time order instead of frequency order */
646    if (!stereo && B0>1 && level==0)
647    {
648       if (encode)
649          deinterleave_vector(X, N_B, B0);
650       if (lowband)
651          deinterleave_vector(lowband, N_B, B0);
652    }
653
654    /* If we need more than 32 bits, try splitting the band in two. */
655    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
656    {
657       if (LM>0 || (N&1)==0)
658       {
659          N >>= 1;
660          Y = X+N;
661          split = 1;
662          LM -= 1;
663          B = (B+1)>>1;
664       }
665    }
666
667    if (split)
668    {
669       int qn;
670       int itheta=0;
671       int mbits, sbits, delta;
672       int qalloc;
673       celt_word16 mid, side;
674       int offset;
675
676       /* Decide on the resolution to give to the split parameter theta */
677       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
678       qn = compute_qn(N, b, offset, stereo);
679
680       qalloc = 0;
681       if (qn!=1)
682       {
683          if (encode)
684          {
685             if (stereo)
686                stereo_band_mix(m, X, Y, bandE, 0, i, 1, N);
687
688             mid = renormalise_vector(X, Q15ONE, N, 1);
689             side = renormalise_vector(Y, Q15ONE, N, 1);
690
691             /* theta is the atan() of the ratio between the (normalized)
692                side and mid. With just that parameter, we can re-scale both
693                mid and side because we know that 1) they have unit norm and
694                2) they are orthogonal. */
695    #ifdef FIXED_POINT
696             /* 0.63662 = 2/pi */
697             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
698    #else
699             itheta = floor(.5f+16384*0.63662f*atan2(side,mid));
700    #endif
701
702             itheta = (itheta*qn+8192)>>14;
703          }
704
705          /* Entropy coding of the angle. We use a uniform pdf for the
706             first stereo split but a triangular one for the rest. */
707          if (stereo || qn>256 || B>1)
708          {
709             if (encode)
710                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
711             else
712                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
713             qalloc = log2_frac(qn+1,BITRES);
714          } else {
715             int fs=1, ft;
716             ft = ((qn>>1)+1)*((qn>>1)+1);
717             if (encode)
718             {
719                int fl;
720
721                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
722                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
723                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
724
725                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
726             } else {
727                int fl=0;
728                int fm;
729                fm = ec_decode((ec_dec*)ec, ft);
730
731                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
732                {
733                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
734                   fs = itheta + 1;
735                   fl = itheta*(itheta + 1)>>1;
736                }
737                else
738                {
739                   itheta = (2*(qn + 1)
740                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
741                   fs = qn + 1 - itheta;
742                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
743                }
744
745                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
746             }
747             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
748          }
749          itheta = itheta*16384/qn;
750       } else {
751          if (stereo && encode)
752             stereo_band_mix(m, X, Y, bandE, 1, i, 1, N);
753       }
754
755       if (itheta == 0)
756       {
757          imid = 32767;
758          iside = 0;
759          delta = -10000;
760       } else if (itheta == 16384)
761       {
762          imid = 0;
763          iside = 32767;
764          delta = 10000;
765       } else {
766          imid = bitexact_cos(itheta);
767          iside = bitexact_cos(16384-itheta);
768          /* This is the mid vs side allocation that minimizes squared error
769             in that band. */
770          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
771       }
772
773       /* This is a special case for N=2 that only works for stereo and takes
774          advantage of the fact that mid and side are orthogonal to encode
775          the side with just one bit. */
776       if (N==2 && stereo)
777       {
778          int c, c2;
779          int sign=1;
780          celt_norm v[2], w[2];
781          celt_norm *x2, *y2;
782          mbits = b-qalloc;
783          sbits = 0;
784          /* Only need one bit for the side */
785          if (itheta != 0 && itheta != 16384)
786             sbits = 1<<BITRES;
787          mbits -= sbits;
788          c = itheta > 8192 ? 1 : 0;
789          *remaining_bits -= qalloc+sbits;
790
791          x2 = X;
792          y2 = Y;
793          if (encode)
794          {
795             c2 = 1-c;
796
797             /* v is the largest vector between mid and side. w is the other */
798             if (c==0)
799             {
800                v[0] = x2[0];
801                v[1] = x2[1];
802                w[0] = y2[0];
803                w[1] = y2[1];
804             } else {
805                v[0] = y2[0];
806                v[1] = y2[1];
807                w[0] = x2[0];
808                w[1] = x2[1];
809             }
810             /* Here we only need to encode a sign for the side */
811             if (v[0]*w[1] - v[1]*w[0] > 0)
812                sign = 1;
813             else
814                sign = -1;
815          }
816          quant_band(encode, m, i, v, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level+1, seed);
817          if (sbits)
818          {
819             if (encode)
820             {
821                ec_enc_bits((ec_enc*)ec, sign==1, 1);
822             } else {
823                sign = 2*ec_dec_bits((ec_dec*)ec, 1)-1;
824             }
825          } else {
826             sign = 1;
827          }
828          w[0] = -sign*v[1];
829          w[1] = sign*v[0];
830          if (c==0)
831          {
832             x2[0] = v[0];
833             x2[1] = v[1];
834             y2[0] = w[0];
835             y2[1] = w[1];
836          } else {
837             x2[0] = w[0];
838             x2[1] = w[1];
839             y2[0] = v[0];
840             y2[1] = v[1];
841          }
842       } else
843       {
844          /* "Normal" split code */
845          celt_norm *next_lowband2=NULL;
846          celt_norm *next_lowband_out1=NULL;
847          int next_level=0;
848
849          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
850          if (B>1 && !stereo)
851             delta >>= 1;
852
853          mbits = (b-qalloc-delta)/2;
854          if (mbits > b-qalloc)
855             mbits = b-qalloc;
856          if (mbits<0)
857             mbits=0;
858          sbits = b-qalloc-mbits;
859          *remaining_bits -= qalloc;
860
861          if (lowband && !stereo)
862             next_lowband2 = lowband+N; /* >32-bit split case */
863
864          /* Only stereo needs to pass on lowband_out. Otherwise, it's handled at the end */
865          if (stereo)
866             next_lowband_out1 = lowband_out;
867          else
868             next_level = level+1;
869
870          quant_band(encode, m, i, X, NULL, N, mbits, spread, B, tf_change, lowband, resynth, ec, remaining_bits, LM, next_lowband_out1, NULL, next_level, seed);
871          quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, tf_change, next_lowband2, resynth, ec, remaining_bits, LM, NULL, NULL, level, seed);
872       }
873
874    } else {
875       /* This is the basic no-split case */
876       q = bits2pulses(m, m->bits[LM][i], N, b);
877       curr_bits = pulses2bits(m->bits[LM][i], N, q);
878       *remaining_bits -= curr_bits;
879
880       /* Ensures we can never bust the budget */
881       while (*remaining_bits < 0 && q > 0)
882       {
883          *remaining_bits += curr_bits;
884          q--;
885          curr_bits = pulses2bits(m->bits[LM][i], N, q);
886          *remaining_bits -= curr_bits;
887       }
888
889       /* Finally do the actual quantization */
890       if (encode)
891          alg_quant(X, N, q, spread, B, lowband, resynth, (ec_enc*)ec, seed);
892       else
893          alg_unquant(X, N, q, spread, B, lowband, (ec_dec*)ec, seed);
894    }
895
896    /* This code is used by the decoder and by the resynthesis-enabled encoder */
897    if (resynth)
898    {
899       int k;
900
901       if (split)
902       {
903          int j;
904          celt_word16 mid, side;
905 #ifdef FIXED_POINT
906          mid = imid;
907          side = iside;
908 #else
909          mid = (1.f/32768)*imid;
910          side = (1.f/32768)*iside;
911 #endif
912          for (j=0;j<N;j++)
913             X[j] = MULT16_16_Q15(X[j], mid);
914          for (j=0;j<N;j++)
915             Y[j] = MULT16_16_Q15(Y[j], side);
916       }
917
918       /* Undo the sample reorganization going from time order to frequency order */
919       if (!stereo && B0>1 && level==0)
920       {
921          interleave_vector(X, N_B, B0);
922          if (lowband)
923             interleave_vector(lowband, N_B, B0);
924       }
925
926       /* Undo time-freq changes that we did earlier */
927       N_B = N_B0;
928       B = B0;
929       for (k=0;k<time_divide;k++)
930       {
931          B >>= 1;
932          N_B <<= 1;
933          haar1(X, N_B, B);
934          if (lowband)
935             haar1(lowband, N_B, B);
936       }
937
938       for (k=0;k<recombine;k++)
939       {
940          haar1(X, N_B, B);
941          if (lowband)
942             haar1(lowband, N_B, B);
943          N_B>>=1;
944          B <<= 1;
945       }
946
947       /* Scale output for later folding */
948       if (lowband_out && !stereo)
949       {
950          int j;
951          celt_word16 n;
952          n = celt_sqrt(SHL32(EXTEND32(N0),22));
953          for (j=0;j<N0;j++)
954             lowband_out[j] = MULT16_16_Q15(n,X[j]);
955       }
956
957       if (stereo)
958       {
959          stereo_band_mix(m, X, Y, bandE, 0, i, -1, N);
960          /* We only need to renormalize because quantization may not
961             have preserved orthogonality of mid and side */
962          renormalise_vector(X, Q15ONE, N, 1);
963          renormalise_vector(Y, Q15ONE, N, 1);
964       }
965    }
966 }
967
968 void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM)
969 {
970    int i, balance;
971    celt_int32 remaining_bits;
972    const celt_int16 * restrict eBands = m->eBands;
973    celt_norm * restrict norm;
974    VARDECL(celt_norm, _norm);
975    int B;
976    int M;
977    int spread;
978    celt_int32 seed;
979    celt_norm *lowband;
980    int update_lowband = 1;
981    int C = _Y != NULL ? 2 : 1;
982    SAVE_STACK;
983
984    M = 1<<LM;
985    B = shortBlocks ? M : 1;
986    spread = fold ? B : 0;
987    ALLOC(_norm, M*eBands[m->nbEBands], celt_norm);
988    norm = _norm;
989
990    if (encode)
991       seed = ((ec_enc*)ec)->rng;
992    else
993       seed = ((ec_dec*)ec)->rng;
994    balance = 0;
995    lowband = NULL;
996    for (i=start;i<end;i++)
997    {
998       int tell;
999       int b;
1000       int N;
1001       int curr_balance;
1002       celt_norm * restrict X, * restrict Y;
1003       int tf_change=0;
1004       celt_norm *effective_lowband;
1005       
1006       X = _X+M*eBands[i];
1007       if (_Y!=NULL)
1008          Y = _Y+M*eBands[i];
1009       else
1010          Y = NULL;
1011       N = M*eBands[i+1]-M*eBands[i];
1012       if (encode)
1013          tell = ec_enc_tell((ec_enc*)ec, BITRES);
1014       else
1015          tell = ec_dec_tell((ec_dec*)ec, BITRES);
1016
1017       /* Compute how many bits we want to allocate to this band */
1018       if (i != start)
1019          balance -= tell;
1020       remaining_bits = (total_bits<<BITRES)-tell-1;
1021       curr_balance = (end-i);
1022       if (curr_balance > 3)
1023          curr_balance = 3;
1024       curr_balance = balance / curr_balance;
1025       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
1026       if (b<0)
1027          b = 0;
1028       /* Prevents ridiculous bit depths */
1029       if (b > C*16*N<<BITRES)
1030          b = C*16*N<<BITRES;
1031
1032       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband==NULL))
1033             lowband = norm+M*eBands[i]-N;
1034
1035       tf_change = tf_res[i];
1036       if (i>=m->effEBands)
1037       {
1038          X=norm;
1039          if (_Y!=NULL)
1040             Y = norm;
1041       }
1042
1043       if (tf_change==0 && !shortBlocks && fold)
1044          effective_lowband = NULL;
1045       else
1046          effective_lowband = lowband;
1047       quant_band(encode, m, i, X, Y, N, b, fold, B, tf_change, effective_lowband, resynth, ec, &remaining_bits, LM, norm+M*eBands[i], bandE, 0, &seed);
1048
1049       balance += pulses[i] + tell;
1050
1051       /* Update the folding position only as long as we have 2 bit/sample depth */
1052       update_lowband = (b>>BITRES)>2*N;
1053    }
1054    RESTORE_STACK;
1055 }
1056