Entropy-coding the new split parameter.
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48
49 #ifdef FIXED_POINT
50 /* Compute the amplitude (sqrt energy) in each of the bands */
51 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int _C, int M)
52 {
53    int i, c, N;
54    const celt_int16 *eBands = m->eBands;
55    const int C = CHANNELS(_C);
56    N = M*m->eBands[m->nbEBands+1];
57    for (c=0;c<C;c++)
58    {
59       for (i=0;i<m->nbEBands;i++)
60       {
61          int j;
62          celt_word32 maxval=0;
63          celt_word32 sum = 0;
64          
65          j=M*eBands[i]; do {
66             maxval = MAX32(maxval, X[j+c*N]);
67             maxval = MAX32(maxval, -X[j+c*N]);
68          } while (++j<M*eBands[i+1]);
69          
70          if (maxval > 0)
71          {
72             int shift = celt_ilog2(maxval)-10;
73             j=M*eBands[i]; do {
74                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
75                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
76             } while (++j<M*eBands[i+1]);
77             /* We're adding one here to make damn sure we never end up with a pitch vector that's
78                larger than unity norm */
79             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
80          } else {
81             bank[i+c*m->nbEBands] = EPSILON;
82          }
83          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
84       }
85    }
86    /*printf ("\n");*/
87 }
88
89 /* Normalise each band such that the energy is one. */
90 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int _C, int M)
91 {
92    int i, c, N;
93    const celt_int16 *eBands = m->eBands;
94    const int C = CHANNELS(_C);
95    N = M*m->eBands[m->nbEBands+1];
96    for (c=0;c<C;c++)
97    {
98       i=0; do {
99          celt_word16 g;
100          int j,shift;
101          celt_word16 E;
102          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
103          E = VSHR32(bank[i+c*m->nbEBands], shift);
104          g = EXTRACT16(celt_rcp(SHL32(E,3)));
105          j=M*eBands[i]; do {
106             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
107          } while (++j<M*eBands[i+1]);
108       } while (++i<m->nbEBands);
109    }
110 }
111
112 #else /* FIXED_POINT */
113 /* Compute the amplitude (sqrt energy) in each of the bands */
114 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int _C, int M)
115 {
116    int i, c, N;
117    const celt_int16 *eBands = m->eBands;
118    const int C = CHANNELS(_C);
119    N = M*m->eBands[m->nbEBands+1];
120    for (c=0;c<C;c++)
121    {
122       for (i=0;i<m->nbEBands;i++)
123       {
124          int j;
125          celt_word32 sum = 1e-10;
126          for (j=M*eBands[i];j<M*eBands[i+1];j++)
127             sum += X[j+c*N]*X[j+c*N];
128          bank[i+c*m->nbEBands] = sqrt(sum);
129          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
130       }
131    }
132    /*printf ("\n");*/
133 }
134
135 #ifdef EXP_PSY
136 void compute_noise_energies(const CELTMode *m, const celt_sig *X, const celt_word16 *tonality, celt_ener *bank, int _C, int M)
137 {
138    int i, c, N;
139    const celt_int16 *eBands = m->eBands;
140    const int C = CHANNELS(_C);
141    N = M*m->eBands[m->nbEBands+1];
142    for (c=0;c<C;c++)
143    {
144       for (i=0;i<m->nbEBands;i++)
145       {
146          int j;
147          celt_word32 sum = 1e-10;
148          for (j=M*eBands[i];j<M*eBands[i+1];j++)
149             sum += X[j*C+c]*X[j+c*N]*tonality[j];
150          bank[i+c*m->nbEBands] = sqrt(sum);
151          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
152       }
153    }
154    /*printf ("\n");*/
155 }
156 #endif
157
158 /* Normalise each band such that the energy is one. */
159 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int _C, int M)
160 {
161    int i, c, N;
162    const celt_int16 *eBands = m->eBands;
163    const int C = CHANNELS(_C);
164    N = M*m->eBands[m->nbEBands+1];
165    for (c=0;c<C;c++)
166    {
167       for (i=0;i<m->nbEBands;i++)
168       {
169          int j;
170          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
171          for (j=M*eBands[i];j<M*eBands[i+1];j++)
172             X[j+c*N] = freq[j+c*N]*g;
173       }
174    }
175 }
176
177 #endif /* FIXED_POINT */
178
179 void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int _C, int M)
180 {
181    int i, c;
182    const celt_int16 *eBands = m->eBands;
183    const int C = CHANNELS(_C);
184    for (c=0;c<C;c++)
185    {
186       i=0; do {
187          renormalise_vector(X+M*eBands[i]+c*M*eBands[m->nbEBands+1], Q15ONE, M*eBands[i+1]-M*eBands[i], 1);
188       } while (++i<m->nbEBands);
189    }
190 }
191
192 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
193 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int _C, int M)
194 {
195    int i, c, N;
196    const celt_int16 *eBands = m->eBands;
197    const int C = CHANNELS(_C);
198    N = M*m->eBands[m->nbEBands+1];
199    if (C>2)
200       celt_fatal("denormalise_bands() not implemented for >2 channels");
201    for (c=0;c<C;c++)
202    {
203       celt_sig * restrict f;
204       const celt_norm * restrict x;
205       f = freq+c*N;
206       x = X+c*N;
207       for (i=0;i<m->nbEBands;i++)
208       {
209          int j, end;
210          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
211          j=M*eBands[i];
212          end = M*eBands[i+1];
213          do {
214             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
215             x++;
216          } while (++j<end);
217       }
218       for (i=M*eBands[m->nbEBands];i<M*eBands[m->nbEBands+1];i++)
219          *f++ = 0;
220    }
221 }
222
223 int compute_pitch_gain(const CELTMode *m, const celt_sig *X, const celt_sig *P, int norm_rate, int *gain_id, int _C, celt_word16 *gain_prod, int M)
224 {
225    int j, c;
226    celt_word16 g;
227    celt_word16 delta;
228    const int C = CHANNELS(_C);
229    celt_word32 Sxy=0, Sxx=0, Syy=0;
230    int len = M*m->pitchEnd;
231    int N = M*m->eBands[m->nbEBands+1];
232 #ifdef FIXED_POINT
233    int shift = 0;
234    celt_word32 maxabs=0;
235
236    for (c=0;c<C;c++)
237    {
238       for (j=0;j<len;j++)
239       {
240          maxabs = MAX32(maxabs, ABS32(X[j+c*N]));
241          maxabs = MAX32(maxabs, ABS32(P[j+c*N]));
242       }
243    }
244    shift = celt_ilog2(maxabs)-12;
245    if (shift<0)
246       shift = 0;
247 #endif
248    delta = PDIV32_16(Q15ONE, len);
249    for (c=0;c<C;c++)
250    {
251       celt_word16 gg = Q15ONE;
252       for (j=0;j<len;j++)
253       {
254          celt_word16 Xj, Pj;
255          Xj = EXTRACT16(SHR32(X[j+c*N], shift));
256          Pj = MULT16_16_P15(gg,EXTRACT16(SHR32(P[j+c*N], shift)));
257          Sxy = MAC16_16(Sxy, Xj, Pj);
258          Sxx = MAC16_16(Sxx, Pj, Pj);
259          Syy = MAC16_16(Syy, Xj, Xj);
260          gg = SUB16(gg, delta);
261       }
262    }
263 #ifdef FIXED_POINT
264    {
265       celt_word32 num, den;
266       celt_word16 fact;
267       fact = MULT16_16(QCONST16(.04f, 14), norm_rate);
268       if (fact < QCONST16(1.f, 14))
269          fact = QCONST16(1.f, 14);
270       num = Sxy;
271       den = EPSILON+Sxx+MULT16_32_Q15(QCONST16(.03f,15),Syy);
272       shift = celt_zlog2(Sxy)-16;
273       if (shift < 0)
274          shift = 0;
275       if (Sxy < MULT16_32_Q15(fact, MULT16_16(celt_sqrt(EPSILON+Sxx),celt_sqrt(EPSILON+Syy))))
276          g = 0;
277       else
278          g = DIV32(SHL32(SHR32(num,shift),14),ADD32(EPSILON,SHR32(den,shift)));
279
280       /* This MUST round down so that we don't over-estimate the gain */
281       *gain_id = EXTRACT16(SHR32(MULT16_16(20,(g-QCONST16(.5f,14))),14));
282    }
283 #else
284    {
285       float fact = .04f*norm_rate;
286       if (fact < 1)
287          fact = 1;
288       g = Sxy/(.1f+Sxx+.03f*Syy);
289       if (Sxy < .5f*fact*celt_sqrt(1+Sxx*Syy))
290          g = 0;
291       /* This MUST round down so that we don't over-estimate the gain */
292       *gain_id = floor(20*(g-.5f));
293    }
294 #endif
295    /* This prevents the pitch gain from being above 1.0 for too long by bounding the 
296       maximum error amplification factor to 2.0 */
297    g = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),*gain_id));
298    *gain_prod = MAX16(QCONST32(1.f, 13), MULT16_16_Q14(*gain_prod,g));
299    if (*gain_prod>QCONST32(2.f, 13))
300    {
301       *gain_id=9;
302       *gain_prod = QCONST32(2.f, 13);
303    }
304
305    if (*gain_id < 0)
306    {
307       *gain_id = 0;
308       return 0;
309    } else {
310       if (*gain_id > 15)
311          *gain_id = 15;
312       return 1;
313    }
314 }
315
316 void apply_pitch(const CELTMode *m, celt_sig *X, const celt_sig *P, int gain_id, int pred, int _C, int M)
317 {
318    int j, c, N;
319    celt_word16 gain;
320    celt_word16 delta;
321    const int C = CHANNELS(_C);
322    int len = M*m->pitchEnd;
323
324    N = M*m->eBands[m->nbEBands+1];
325    gain = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),gain_id));
326    delta = PDIV32_16(gain, len);
327    if (pred)
328       gain = -gain;
329    else
330       delta = -delta;
331    for (c=0;c<C;c++)
332    {
333       celt_word16 gg = gain;
334       for (j=0;j<len;j++)
335       {
336          X[j+c*N] += SHL32(MULT16_32_Q15(gg,P[j+c*N]),1);
337          gg = ADD16(gg, delta);
338       }
339    }
340 }
341
342 #ifndef DISABLE_STEREO
343
344 static void stereo_band_mix(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int stereo_mode, int bandID, int dir, int M)
345 {
346    int i = bandID;
347    const celt_int16 *eBands = m->eBands;
348    int j;
349    celt_word16 a1, a2;
350    if (stereo_mode==0)
351    {
352       /* Do mid-side when not doing intensity stereo */
353       a1 = QCONST16(.70711f,14);
354       a2 = dir*QCONST16(.70711f,14);
355    } else {
356       celt_word16 left, right;
357       celt_word16 norm;
358 #ifdef FIXED_POINT
359       int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
360 #endif
361       left = VSHR32(bank[i],shift);
362       right = VSHR32(bank[i+m->nbEBands],shift);
363       norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
364       a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
365       a2 = dir*DIV32_16(SHL32(EXTEND32(right),14),norm);
366    }
367    for (j=0;j<M*eBands[i+1]-M*eBands[i];j++)
368    {
369       celt_norm r, l;
370       l = X[j];
371       r = Y[j];
372       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
373       Y[j] = MULT16_16_Q14(a1,r) - MULT16_16_Q14(a2,l);
374    }
375 }
376
377
378 #endif /* DISABLE_STEREO */
379
380 int folding_decision(const CELTMode *m, celt_norm *X, celt_word16 *average, int *last_decision, int _C, int M)
381 {
382    int i, c, N0;
383    int NR=0;
384    celt_word32 ratio = EPSILON;
385    const int C = CHANNELS(_C);
386    const celt_int16 * restrict eBands = m->eBands;
387    
388    N0 = M*m->eBands[m->nbEBands+1];
389
390    for (c=0;c<C;c++)
391    {
392    for (i=0;i<m->nbEBands;i++)
393    {
394       int j, N;
395       int max_i=0;
396       celt_word16 max_val=EPSILON;
397       celt_word32 floor_ener=EPSILON;
398       celt_norm * restrict x = X+M*eBands[i]+c*N0;
399       N = M*eBands[i+1]-M*eBands[i];
400       for (j=0;j<N;j++)
401       {
402          if (ABS16(x[j])>max_val)
403          {
404             max_val = ABS16(x[j]);
405             max_i = j;
406          }
407       }
408 #if 0
409       for (j=0;j<N;j++)
410       {
411          if (abs(j-max_i)>2)
412             floor_ener += x[j]*x[j];
413       }
414 #else
415       floor_ener = QCONST32(1.,28)-MULT16_16(max_val,max_val);
416       if (max_i < N-1)
417          floor_ener -= MULT16_16(x[(max_i+1)], x[(max_i+1)]);
418       if (max_i < N-2)
419          floor_ener -= MULT16_16(x[(max_i+2)], x[(max_i+2)]);
420       if (max_i > 0)
421          floor_ener -= MULT16_16(x[(max_i-1)], x[(max_i-1)]);
422       if (max_i > 1)
423          floor_ener -= MULT16_16(x[(max_i-2)], x[(max_i-2)]);
424       floor_ener = MAX32(floor_ener, EPSILON);
425 #endif
426       if (N>7)
427       {
428          celt_word16 r;
429          celt_word16 den = celt_sqrt(floor_ener);
430          den = MAX32(QCONST16(.02f, 15), den);
431          r = DIV32_16(SHL32(EXTEND32(max_val),8),den);
432          ratio = ADD32(ratio, EXTEND32(r));
433          NR++;
434       }
435    }
436    }
437    if (NR>0)
438       ratio = DIV32_16(ratio, NR);
439    ratio = ADD32(HALF32(ratio), HALF32(*average));
440    if (!*last_decision)
441    {
442       *last_decision = (ratio < QCONST16(1.8f,8));
443    } else {
444       *last_decision = (ratio < QCONST16(3.f,8));
445    }
446    *average = EXTRACT16(ratio);
447    return *last_decision;
448 }
449
450 void quant_band(const CELTMode *m, int i, celt_norm *X, celt_norm *Y, int N, int b, int spread, celt_norm *lowband, int resynth, ec_enc *enc, celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE)
451 {
452    int q;
453    int curr_bits;
454    int stereo, split;
455    int imid=0, iside=0;
456    int N0=N;
457
458    split = stereo = Y != NULL;
459
460    if (b>(32<<BITRES) && !stereo && LM>0)
461    {
462       N /= 2;
463       Y = X+N;
464       split = 1;
465       LM -= 1;
466    }
467
468    if (split)
469    {
470       int qb;
471       int itheta;
472       int mbits, sbits, delta;
473       int qalloc;
474       celt_word16 mid, side;
475       if (N>1)
476          qb = (b-2*(N-1)*(QTHETA_OFFSET-m->logN[i]-(LM<<BITRES)))/(32*(N-1));
477       else
478          qb = b-2;
479       if (qb > (b>>BITRES)-1)
480          qb = (b>>BITRES)-1;
481       if (qb<0)
482          qb = 0;
483       if (qb>14)
484          qb = 14;
485
486       if (stereo)
487          stereo_band_mix(m, X, Y, bandE, qb==0, i, 1, 1<<LM);
488
489       mid = renormalise_vector(X, Q15ONE, N, 1);
490       side = renormalise_vector(Y, Q15ONE, N, 1);
491       /* 0.63662 = 2/pi */
492 #ifdef FIXED_POINT
493       itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
494 #else
495       itheta = floor(.5f+16384*0.63662f*atan2(side,mid));
496 #endif
497       qalloc = log2_frac((1<<qb)+1,BITRES);
498       if (qb==0)
499       {
500          itheta=0;
501       } else {
502          int shift;
503          shift = 14-qb;
504          itheta = (itheta+(1<<shift>>1))>>shift;
505          if (stereo || qb>9)
506             ec_enc_uint(enc, itheta, (1<<qb)+1);
507          else {
508             int j;
509             int fl=0, fs=1, ft;
510             j=0;
511             while(1)
512             {
513                if (j==itheta)
514                   break;
515                fl+=fs;
516                if (j<(1<<qb>>1))
517                   fs++;
518                else
519                   fs--;
520                j++;
521             }
522             ft = ((1<<qb>>1)+1)*((1<<qb>>1)+1);
523             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
524             ec_encode(enc, fl, fl+fs, ft);
525          }
526          itheta <<= shift;
527       }
528       if (itheta == 0)
529       {
530          imid = 32767;
531          iside = 0;
532          delta = -10000;
533       } else if (itheta == 16384)
534       {
535          imid = 0;
536          iside = 32767;
537          delta = 10000;
538       } else {
539          imid = bitexact_cos(itheta);
540          iside = bitexact_cos(16384-itheta);
541          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
542       }
543 #if 1
544       if (N==2 && stereo)
545       {
546          int c, c2;
547          int sign=1;
548          celt_norm v[2], w[2];
549          celt_norm *x2, *y2;
550          mbits = b-qalloc;
551          sbits = 0;
552          if (itheta != 0 && itheta != 16384)
553             sbits = 1<<BITRES;
554          mbits -= sbits;
555          c = itheta > 8192 ? 1 : 0;
556          c2 = 1-c;
557
558          x2 = X;
559          y2 = Y;
560          if (c==0)
561          {
562             v[0] = x2[0];
563             v[1] = x2[1];
564             w[0] = y2[0];
565             w[1] = y2[1];
566          } else {
567             v[0] = y2[0];
568             v[1] = y2[1];
569             w[0] = x2[0];
570             w[1] = x2[1];
571          }
572          *remaining_bits -= qalloc+sbits;
573          quant_band(m, i, v, NULL, N, mbits, spread, lowband, resynth, enc, remaining_bits, LM, NULL, NULL);
574          if (sbits)
575          {
576             if (v[0]*w[1] - v[1]*w[0] > 0)
577                sign = 1;
578             else
579                sign = -1;
580             ec_enc_bits(enc, sign==1, 1);
581          } else {
582             sign = 1;
583          }
584          w[0] = -sign*v[1];
585          w[1] = sign*v[0];
586          if (c==0)
587          {
588             x2[0] = v[0];
589             x2[1] = v[1];
590             y2[0] = w[0];
591             y2[1] = w[1];
592          } else {
593             x2[0] = w[0];
594             x2[1] = w[1];
595             y2[0] = v[0];
596             y2[1] = v[1];
597          }
598       } else
599 #endif
600       {
601
602          mbits = (b-qalloc/2-delta)/2;
603          if (mbits > b-qalloc)
604             mbits = b-qalloc;
605          if (mbits<0)
606             mbits=0;
607          sbits = b-qalloc-mbits;
608          *remaining_bits -= qalloc;
609          quant_band(m, i, X, NULL, N, mbits, spread, lowband, resynth, enc, remaining_bits, LM, NULL, NULL);
610          if (stereo)
611             quant_band(m, i, Y, NULL, N, sbits, spread, NULL, resynth, enc, remaining_bits, LM, NULL, NULL);
612          else
613             quant_band(m, i, Y, NULL, N, sbits, spread, lowband ? lowband+N : NULL, resynth, enc, remaining_bits, LM, NULL, NULL);
614       }
615
616    } else {
617       q = bits2pulses(m, m->bits[LM][i], N, b);
618       curr_bits = pulses2bits(m->bits[LM][i], N, q);
619       *remaining_bits -= curr_bits;
620       while (*remaining_bits < 0 && q > 0)
621       {
622          *remaining_bits += curr_bits;
623          q--;
624          curr_bits = pulses2bits(m->bits[LM][i], N, q);
625          *remaining_bits -= curr_bits;
626       }
627       alg_quant(X, N, q, spread, lowband, resynth, enc);
628    }
629
630    if (resynth && lowband_out)
631    {
632       int j;
633       celt_word16 n;
634       n = celt_sqrt(SHL32(EXTEND32(N0),22));
635       for (j=0;j<N0;j++)
636          lowband_out[j] = MULT16_16_Q15(n,X[j]);
637    }
638
639    if (split && resynth)
640    {
641       int j;
642       celt_word16 mid, side;
643 #ifdef FIXED_POINT
644       mid = imid;
645       side = iside;
646 #else
647       mid = (1.f/32768)*imid;
648       side = (1.f/32768)*iside;
649 #endif
650       for (j=0;j<N;j++)
651          X[j] = MULT16_16_Q15(X[j], mid);
652       for (j=0;j<N;j++)
653          Y[j] = MULT16_16_Q15(Y[j], side);
654
655    }
656 }
657
658 void unquant_band(const CELTMode *m, int i, celt_norm *X, celt_norm *Y, int N, int b,
659                  int spread, celt_norm *lowband, ec_dec *dec,
660                  celt_int32 *remaining_bits, int LM, celt_norm *lowband_out)
661 {
662    int q;
663    int curr_bits;
664    int stereo, split;
665    int imid=0, iside=0;
666    int N0=N;
667
668    split = stereo = Y != NULL;
669
670    if (b>(32<<BITRES) && !stereo && LM>0)
671    {
672       N /= 2;
673       Y = X+N;
674       split = 1;
675       LM -= 1;
676    }
677
678    if (split)
679    {
680       int itheta;
681       int mbits, sbits, delta;
682       int qalloc, qb;
683       if (N>1)
684          qb = (b-2*(N-1)*(QTHETA_OFFSET-m->logN[i]-(LM<<BITRES)))/(32*(N-1));
685       else
686          qb = b-2;
687       if (qb > (b>>BITRES)-1)
688          qb = (b>>BITRES)-1;
689       if (qb>14)
690          qb = 14;
691       if (qb<0)
692          qb = 0;
693       qalloc = log2_frac((1<<qb)+1,BITRES);
694       if (qb==0)
695       {
696          itheta=0;
697       } else {
698          int shift;
699          shift = 14-qb;
700          if (stereo || qb>9)
701             itheta = ec_dec_uint(dec, (1<<qb)+1);
702          else {
703             int fs=1, fl=0;
704             int j, fm, ft;
705             ft = ((1<<qb>>1)+1)*((1<<qb>>1)+1);
706             fm = ec_decode(dec, ft);
707             j=0;
708             while (1)
709             {
710                if (fm < fl+fs)
711                   break;
712                fl+=fs;
713                if (j<(1<<qb>>1))
714                   fs++;
715                else
716                   fs--;
717                j++;
718             }
719             itheta = j;
720             qalloc = log2_frac(ft,BITRES) - log2_frac(fs,BITRES) + 1;
721             ec_dec_update(dec, fl, fl+fs, ft);
722          }
723          itheta <<= shift;
724       }
725       if (itheta == 0)
726       {
727          imid = 32767;
728          iside = 0;
729          delta = -10000;
730       } else if (itheta == 16384)
731       {
732          imid = 0;
733          iside = 32767;
734          delta = 10000;
735       } else {
736          imid = bitexact_cos(itheta);
737          iside = bitexact_cos(16384-itheta);
738          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
739       }
740
741 #if 1
742       if (N==2 && stereo)
743       {
744          int c;
745          int sign=1;
746          celt_norm v[2], w[2];
747          celt_norm *x2, *y2;
748          mbits = b-qalloc;
749          sbits = 0;
750          if (itheta != 0 && itheta != 16384)
751             sbits = 1<<BITRES;
752          mbits -= sbits;
753          c = itheta > 8192 ? 1 : 0;
754
755          x2 = X;
756          y2 = Y;
757          *remaining_bits -= qalloc+sbits;
758          unquant_band(m, i, v, NULL, N, mbits, spread, lowband, dec, remaining_bits, LM, NULL);
759          if (sbits)
760             sign = 2*ec_dec_bits(dec, 1)-1;
761          else
762             sign = 1;
763          w[0] = -sign*v[1];
764          w[1] = sign*v[0];
765          if (c==0)
766          {
767             x2[0] = v[0];
768             x2[1] = v[1];
769             y2[0] = w[0];
770             y2[1] = w[1];
771          } else {
772             x2[0] = w[0];
773             x2[1] = w[1];
774             y2[0] = v[0];
775             y2[1] = v[1];
776          }
777       } else
778 #endif
779       {
780          mbits = (b-qalloc/2-delta)/2;
781          if (mbits > b-qalloc)
782             mbits = b-qalloc;
783          if (mbits<0)
784             mbits=0;
785          sbits = b-qalloc-mbits;
786          *remaining_bits -= qalloc;
787          unquant_band(m, i, X, NULL, N, mbits, spread, lowband, dec, remaining_bits, LM, NULL);
788          if (stereo)
789             unquant_band(m, i, Y, NULL, N, sbits, spread, NULL, dec, remaining_bits, LM, NULL);
790          else
791             unquant_band(m, i, Y, NULL, N, sbits, spread, lowband ? lowband+N : NULL, dec, remaining_bits, LM, NULL);
792       }
793    } else {
794
795       q = bits2pulses(m, m->bits[LM][i], N, b);
796       curr_bits = pulses2bits(m->bits[LM][i], N, q);
797       *remaining_bits -= curr_bits;
798       while (*remaining_bits < 0 && q > 0)
799       {
800          *remaining_bits += curr_bits;
801          q--;
802          curr_bits = pulses2bits(m->bits[LM][i], N, q);
803          *remaining_bits -= curr_bits;
804       }
805       alg_unquant(X, N, q, spread, lowband, dec);
806    }
807
808    if (lowband_out)
809    {
810       celt_word16 n;
811       int j;
812       n = celt_sqrt(SHL32(EXTEND32(N0),22));
813       for (j=0;j<N0;j++)
814          lowband_out[j] = MULT16_16_Q15(n,X[j]);
815    }
816    if (split)
817    {
818       int j;
819       celt_word16 mid, side;
820 #ifdef FIXED_POINT
821       mid = imid;
822       side = iside;
823 #else
824       mid = (1.f/32768)*imid;
825       side = (1.f/32768)*iside;
826 #endif
827       for (j=0;j<N;j++)
828          X[j] = MULT16_16_Q15(X[j], mid);
829       for (j=0;j<N;j++)
830          Y[j] = MULT16_16_Q15(Y[j], side);
831
832    }
833 }
834
835 /* Quantisation of the residual */
836 void quant_all_bands(const CELTMode *m, int start, celt_norm * restrict X, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int resynth, int total_bits, int encode, void *enc, int LM)
837 {
838    int i, remaining_bits, balance;
839    const celt_int16 * restrict eBands = m->eBands;
840    celt_norm * restrict norm;
841    VARDECL(celt_norm, _norm);
842    int B;
843    int M;
844    int spread;
845    SAVE_STACK;
846
847    M = 1<<LM;
848    B = shortBlocks ? M : 1;
849    spread = fold ? B : 0;
850    ALLOC(_norm, M*eBands[m->nbEBands+1], celt_norm);
851    norm = _norm;
852    /* Just in case the first bands attempts to fold -- shouldn't really happen */
853    for (i=0;i<M;i++)
854       norm[i] = 0;
855
856    balance = 0;
857    for (i=start;i<m->nbEBands;i++)
858    {
859       int tell;
860       int N;
861       int curr_balance;
862       
863       N = M*eBands[i+1]-M*eBands[i];
864
865       tell = ec_enc_tell(enc, BITRES);
866       if (i != start)
867          balance -= tell;
868       remaining_bits = (total_bits<<BITRES)-tell-1;
869       curr_balance = (m->nbEBands-i);
870       if (curr_balance > 3)
871          curr_balance = 3;
872       curr_balance = balance / curr_balance;
873
874       quant_band(m, i, X+M*eBands[i], NULL, N, pulses[i]+curr_balance, spread, norm+M*eBands[start], resynth, enc, &remaining_bits, LM, norm+M*eBands[i], NULL);
875
876       balance += pulses[i] + tell;
877    }
878    RESTORE_STACK;
879 }
880
881 /* Decoding of the residual */
882 void unquant_all_bands(const CELTMode *m, int start, celt_norm * restrict X, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int total_bits, int encode, ec_dec *dec, int LM)
883 {
884    int i, remaining_bits, balance;
885    const celt_int16 * restrict eBands = m->eBands;
886    celt_norm * restrict norm;
887    VARDECL(celt_norm, _norm);
888    int B;
889    int M;
890    int spread;
891    SAVE_STACK;
892
893    M = 1<<LM;
894    B = shortBlocks ? M : 1;
895    spread = fold ? B : 0;
896    ALLOC(_norm, M*eBands[m->nbEBands+1], celt_norm);
897    norm = _norm;
898    /* Just in case the first bands attempts to fold -- shouldn't really happen */
899    for (i=0;i<M;i++)
900       norm[i] = 0;
901
902    balance = 0;
903    for (i=start;i<m->nbEBands;i++)
904    {
905       int tell;
906       int N;
907       int curr_balance;
908
909       N = M*eBands[i+1]-M*eBands[i];
910
911       tell = ec_dec_tell(dec, BITRES);
912       if (i != start)
913          balance -= tell;
914       remaining_bits = (total_bits<<BITRES)-tell-1;
915       curr_balance = (m->nbEBands-i);
916       if (curr_balance > 3)
917          curr_balance = 3;
918       curr_balance = balance / curr_balance;
919
920       unquant_band(m, i, X+M*eBands[i], NULL, N, pulses[i]+curr_balance, spread, norm+M*eBands[start], dec, &remaining_bits, LM, norm+M*eBands[i]);
921
922       balance += pulses[i] + tell;
923    }
924    RESTORE_STACK;
925 }
926
927 #ifndef DISABLE_STEREO
928
929 void quant_bands_stereo(const CELTMode *m, int start, celt_norm *_X, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int resynth, int total_bits, ec_enc *enc, int LM)
930 {
931    int i, remaining_bits, balance;
932    const celt_int16 * restrict eBands = m->eBands;
933    celt_norm * restrict norm;
934    VARDECL(celt_norm, _norm);
935    int B;
936    int M;
937    int spread;
938    SAVE_STACK;
939
940    M = 1<<LM;
941    B = shortBlocks ? M : 1;
942    spread = fold ? B : 0;
943    ALLOC(_norm, M*eBands[m->nbEBands+1], celt_norm);
944    norm = _norm;
945    /* Just in case the first bands attempts to fold -- not that rare for stereo */
946    for (i=0;i<M;i++)
947       norm[i] = 0;
948
949    balance = 0;
950    for (i=start;i<m->nbEBands;i++)
951    {
952       int tell;
953       int b;
954       int N;
955       int curr_balance;
956       celt_norm * restrict X, * restrict Y;
957       
958       X = _X+M*eBands[i];
959       Y = X+M*eBands[m->nbEBands+1];
960
961       N = M*eBands[i+1]-M*eBands[i];
962       tell = ec_enc_tell(enc, BITRES);
963       if (i != start)
964          balance -= tell;
965       remaining_bits = (total_bits<<BITRES)-tell-1;
966       curr_balance = (m->nbEBands-i);
967       if (curr_balance > 3)
968          curr_balance = 3;
969       curr_balance = balance / curr_balance;
970       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
971       if (b<0)
972          b = 0;
973
974       quant_band(m, i, X, Y, N, b, spread, norm+M*eBands[start], resynth, enc, &remaining_bits, LM, norm+M*eBands[i], bandE);
975
976       balance += pulses[i] + tell;
977
978       if (resynth)
979       {
980          stereo_band_mix(m, X, Y, bandE, 0, i, -1, M);
981          renormalise_vector(X, Q15ONE, N, 1);
982          renormalise_vector(Y, Q15ONE, N, 1);
983       }
984    }
985    RESTORE_STACK;
986 }
987 #endif /* DISABLE_STEREO */
988
989
990 #ifndef DISABLE_STEREO
991
992 void unquant_bands_stereo(const CELTMode *m, int start, celt_norm *_X, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int total_bits, ec_dec *dec, int LM)
993 {
994    int i, remaining_bits, balance;
995    const celt_int16 * restrict eBands = m->eBands;
996    celt_norm * restrict norm;
997    VARDECL(celt_norm, _norm);
998    int B;
999    int M;
1000    int spread;
1001    SAVE_STACK;
1002
1003    M = 1<<LM;
1004    B = shortBlocks ? M : 1;
1005    spread = fold ? B : 0;
1006    ALLOC(_norm, M*eBands[m->nbEBands+1], celt_norm);
1007    norm = _norm;
1008    /* Just in case the first bands attempts to fold -- not that rare for stereo */
1009    for (i=0;i<M;i++)
1010       norm[i] = 0;
1011
1012    balance = 0;
1013    for (i=start;i<m->nbEBands;i++)
1014    {
1015       int tell;
1016       int b;
1017       int N;
1018       int curr_balance;
1019       celt_norm * restrict X, * restrict Y;
1020       
1021       X = _X+M*eBands[i];
1022       Y = X+M*eBands[m->nbEBands+1];
1023
1024       N = M*eBands[i+1]-M*eBands[i];
1025       tell = ec_dec_tell(dec, BITRES);
1026       if (i != start)
1027          balance -= tell;
1028       remaining_bits = (total_bits<<BITRES)-tell-1;
1029       curr_balance = (m->nbEBands-i);
1030       if (curr_balance > 3)
1031          curr_balance = 3;
1032       curr_balance = balance / curr_balance;
1033       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
1034       if (b<0)
1035          b = 0;
1036
1037       unquant_band(m, i, X, Y, N, b, spread, norm+M*eBands[start], dec, &remaining_bits, LM, norm+M*eBands[i]);
1038
1039       balance += pulses[i] + tell;
1040       
1041       stereo_band_mix(m, X, Y, bandE, 0, i, -1, M);
1042       renormalise_vector(X, Q15ONE, N, 1);
1043       renormalise_vector(Y, Q15ONE, N, 1);
1044    }
1045    RESTORE_STACK;
1046 }
1047
1048 #endif /* DISABLE_STEREO */