Fix collapse_masks overflow for mono.
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 static celt_uint32 lcg_rand(celt_uint32 seed)
49 {
50    return 1664525 * seed + 1013904223;
51 }
52
53 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
54    with this approximation is important because it has an impact on the bit allocation */
55 static celt_int16 bitexact_cos(celt_int16 x)
56 {
57    celt_int32 tmp;
58    celt_int16 x2;
59    tmp = (4096+((celt_int32)(x)*(x)))>>13;
60    if (tmp > 32767)
61       tmp = 32767;
62    x2 = tmp;
63    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
64    if (x2 > 32766)
65       x2 = 32766;
66    return 1+x2;
67 }
68
69 static int bitexact_log2tan(int isin,int icos)
70 {
71    int lc;
72    int ls;
73    lc=EC_ILOG(icos);
74    ls=EC_ILOG(isin);
75    icos<<=15-lc;
76    isin<<=15-ls;
77    return (ls-lc<<11)
78          +FRAC_MUL16(isin, FRAC_MUL16(isin, -2597) + 7932)
79          -FRAC_MUL16(icos, FRAC_MUL16(icos, -2597) + 7932);
80 }
81
82 #ifdef FIXED_POINT
83 /* Compute the amplitude (sqrt energy) in each of the bands */
84 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
85 {
86    int i, c, N;
87    const celt_int16 *eBands = m->eBands;
88    const int C = CHANNELS(_C);
89    N = M*m->shortMdctSize;
90    c=0; do {
91       for (i=0;i<end;i++)
92       {
93          int j;
94          celt_word32 maxval=0;
95          celt_word32 sum = 0;
96          
97          j=M*eBands[i]; do {
98             maxval = MAX32(maxval, X[j+c*N]);
99             maxval = MAX32(maxval, -X[j+c*N]);
100          } while (++j<M*eBands[i+1]);
101          
102          if (maxval > 0)
103          {
104             int shift = celt_ilog2(maxval)-10;
105             j=M*eBands[i]; do {
106                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
107                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
108             } while (++j<M*eBands[i+1]);
109             /* We're adding one here to make damn sure we never end up with a pitch vector that's
110                larger than unity norm */
111             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
112          } else {
113             bank[i+c*m->nbEBands] = EPSILON;
114          }
115          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
116       }
117    } while (++c<C);
118    /*printf ("\n");*/
119 }
120
121 /* Normalise each band such that the energy is one. */
122 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
123 {
124    int i, c, N;
125    const celt_int16 *eBands = m->eBands;
126    const int C = CHANNELS(_C);
127    N = M*m->shortMdctSize;
128    c=0; do {
129       i=0; do {
130          celt_word16 g;
131          int j,shift;
132          celt_word16 E;
133          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
134          E = VSHR32(bank[i+c*m->nbEBands], shift);
135          g = EXTRACT16(celt_rcp(SHL32(E,3)));
136          j=M*eBands[i]; do {
137             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
138          } while (++j<M*eBands[i+1]);
139       } while (++i<end);
140    } while (++c<C);
141 }
142
143 #else /* FIXED_POINT */
144 /* Compute the amplitude (sqrt energy) in each of the bands */
145 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
146 {
147    int i, c, N;
148    const celt_int16 *eBands = m->eBands;
149    const int C = CHANNELS(_C);
150    N = M*m->shortMdctSize;
151    c=0; do {
152       for (i=0;i<end;i++)
153       {
154          int j;
155          celt_word32 sum = 1e-10f;
156          for (j=M*eBands[i];j<M*eBands[i+1];j++)
157             sum += X[j+c*N]*X[j+c*N];
158          bank[i+c*m->nbEBands] = celt_sqrt(sum);
159          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
160       }
161    } while (++c<C);
162    /*printf ("\n");*/
163 }
164
165 /* Normalise each band such that the energy is one. */
166 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
167 {
168    int i, c, N;
169    const celt_int16 *eBands = m->eBands;
170    const int C = CHANNELS(_C);
171    N = M*m->shortMdctSize;
172    c=0; do {
173       for (i=0;i<end;i++)
174       {
175          int j;
176          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
177          for (j=M*eBands[i];j<M*eBands[i+1];j++)
178             X[j+c*N] = freq[j+c*N]*g;
179       }
180    } while (++c<C);
181 }
182
183 #endif /* FIXED_POINT */
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    c=0; do {
194       celt_sig * restrict f;
195       const celt_norm * restrict x;
196       f = freq+c*N;
197       x = X+c*N;
198       for (i=0;i<end;i++)
199       {
200          int j, band_end;
201          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
202          j=M*eBands[i];
203          band_end = M*eBands[i+1];
204          do {
205             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
206             x++;
207          } while (++j<band_end);
208       }
209       for (i=M*eBands[m->nbEBands];i<N;i++)
210          *f++ = 0;
211    } while (++c<C);
212 }
213
214 /* This prevents energy collapse for transients with multiple short MDCTs */
215 void anti_collapse(const CELTMode *m, celt_norm *_X, unsigned char *collapse_masks, int LM, int C, int size,
216       int start, int end, celt_word16 *logE, celt_word16 *prev1logE,
217       celt_word16 *prev2logE, int *pulses, celt_uint32 seed)
218 {
219    int c, i, j, k;
220    for (i=start;i<end;i++)
221    {
222       int N0;
223       celt_word16 thresh, sqrt_1;
224       int depth;
225 #ifdef FIXED_POINT
226       int shift;
227 #endif
228
229       N0 = m->eBands[i+1]-m->eBands[i];
230       depth = (1+(pulses[i]>>BITRES))/(m->eBands[i+1]-m->eBands[i]<<LM);
231
232 #ifdef FIXED_POINT
233       thresh = MULT16_32_Q15(QCONST16(0.3f, 15), MIN32(32767,SHR32(celt_exp2(-SHL16(depth, 11)),1) ));
234       {
235          celt_word32 t;
236          t = N0<<LM;
237          shift = celt_ilog2(t)>>1;
238          t = SHL32(t, (7-shift)<<1);
239          sqrt_1 = celt_rsqrt_norm(t);
240       }
241 #else
242       thresh = .3f*celt_exp2(-depth);
243       sqrt_1 = celt_rsqrt(N0<<LM);
244 #endif
245
246       c=0; do
247       {
248          celt_norm *X;
249          celt_word16 Ediff;
250          celt_word16 r;
251          Ediff = logE[c*m->nbEBands+i]-MIN16(prev1logE[c*m->nbEBands+i],prev2logE[c*m->nbEBands+i]);
252          Ediff = MAX16(0, Ediff);
253
254 #ifdef FIXED_POINT
255          if (Ediff < 16384)
256             r = 2*MIN16(16383,SHR32(celt_exp2(-SHL16(Ediff, 11-DB_SHIFT)),1));
257          else
258             r = 0;
259          r = SHR16(MIN16(thresh, r),1);
260          r = SHR32(MULT16_16_Q15(sqrt_1, r),shift);
261 #else
262          r = 2.f*celt_exp2(-Ediff);
263          r = MIN16(thresh, r);
264          r = r*sqrt_1;
265 #endif
266          X = _X+c*size+(m->eBands[i]<<LM);
267          for (k=0;k<1<<LM;k++)
268          {
269             /* Detect collapse */
270             if (!(collapse_masks[i*C+c]&1<<k))
271             {
272                /* Fill with noise */
273                for (j=0;j<N0;j++)
274                {
275                   seed = lcg_rand(seed);
276                   X[(j<<LM)+k] = (seed&0x8000 ? r : -r);
277                }
278             }
279          }
280          /* We just added some energy, so we need to renormalise */
281          renormalise_vector(X, N0<<LM, Q15ONE);
282       } while (++c<C);
283    }
284
285 }
286
287
288 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
289 {
290    int i = bandID;
291    int j;
292    celt_word16 a1, a2;
293    celt_word16 left, right;
294    celt_word16 norm;
295 #ifdef FIXED_POINT
296    int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
297 #endif
298    left = VSHR32(bank[i],shift);
299    right = VSHR32(bank[i+m->nbEBands],shift);
300    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
301    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
302    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
303    for (j=0;j<N;j++)
304    {
305       celt_norm r, l;
306       l = X[j];
307       r = Y[j];
308       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
309       /* Side is not encoded, no need to calculate */
310    }
311 }
312
313 static void stereo_split(celt_norm *X, celt_norm *Y, int N)
314 {
315    int j;
316    for (j=0;j<N;j++)
317    {
318       celt_norm r, l;
319       l = MULT16_16_Q15(QCONST16(.70710678f,15), X[j]);
320       r = MULT16_16_Q15(QCONST16(.70710678f,15), Y[j]);
321       X[j] = l+r;
322       Y[j] = r-l;
323    }
324 }
325
326 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, int N)
327 {
328    int j;
329    celt_word32 xp=0, side=0;
330    celt_word32 El, Er;
331    celt_word16 mid2;
332 #ifdef FIXED_POINT
333    int kl, kr;
334 #endif
335    celt_word32 t, lgain, rgain;
336
337    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
338    for (j=0;j<N;j++)
339    {
340       xp = MAC16_16(xp, X[j], Y[j]);
341       side = MAC16_16(side, Y[j], Y[j]);
342    }
343    /* Compensating for the mid normalization */
344    xp = MULT16_32_Q15(mid, xp);
345    /* mid and side are in Q15, not Q14 like X and Y */
346    mid2 = SHR32(mid, 1);
347    El = MULT16_16(mid2, mid2) + side - 2*xp;
348    Er = MULT16_16(mid2, mid2) + side + 2*xp;
349    if (Er < EPSILON)
350       Er = EPSILON;
351    if (El < EPSILON)
352       El = EPSILON;
353
354 #ifdef FIXED_POINT
355    kl = celt_ilog2(El)>>1;
356    kr = celt_ilog2(Er)>>1;
357 #endif
358    t = VSHR32(El, (kl-7)<<1);
359    lgain = celt_rsqrt_norm(t);
360    t = VSHR32(Er, (kr-7)<<1);
361    rgain = celt_rsqrt_norm(t);
362
363 #ifdef FIXED_POINT
364    if (kl < 7)
365       kl = 7;
366    if (kr < 7)
367       kr = 7;
368 #endif
369
370    for (j=0;j<N;j++)
371    {
372       celt_norm r, l;
373       /* Apply mid scaling (side is already scaled) */
374       l = MULT16_16_Q15(mid, X[j]);
375       r = Y[j];
376       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, SUB16(l,r)), kl+1));
377       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, ADD16(l,r)), kr+1));
378    }
379 }
380
381 /* Decide whether we should spread the pulses in the current frame */
382 int spreading_decision(const CELTMode *m, celt_norm *X, int *average,
383       int last_decision, int *hf_average, int *tapset_decision, int update_hf,
384       int end, int _C, int M)
385 {
386    int i, c, N0;
387    int sum = 0, nbBands=0;
388    const int C = CHANNELS(_C);
389    const celt_int16 * restrict eBands = m->eBands;
390    int decision;
391    int hf_sum=0;
392    
393    N0 = M*m->shortMdctSize;
394
395    if (M*(eBands[end]-eBands[end-1]) <= 8)
396       return SPREAD_NONE;
397    c=0; do {
398       for (i=0;i<end;i++)
399       {
400          int j, N, tmp=0;
401          int tcount[3] = {0};
402          celt_norm * restrict x = X+M*eBands[i]+c*N0;
403          N = M*(eBands[i+1]-eBands[i]);
404          if (N<=8)
405             continue;
406          /* Compute rough CDF of |x[j]| */
407          for (j=0;j<N;j++)
408          {
409             celt_word32 x2N; /* Q13 */
410
411             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
412             if (x2N < QCONST16(0.25f,13))
413                tcount[0]++;
414             if (x2N < QCONST16(0.0625f,13))
415                tcount[1]++;
416             if (x2N < QCONST16(0.015625f,13))
417                tcount[2]++;
418          }
419
420          /* Only include four last bands (8 kHz and up) */
421          if (i>m->nbEBands-4)
422             hf_sum += 32*(tcount[1]+tcount[0])/N;
423          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
424          sum += tmp*256;
425          nbBands++;
426       }
427    } while (++c<C);
428
429    if (update_hf)
430    {
431       if (hf_sum)
432          hf_sum /= C*(4-m->nbEBands+end);
433       *hf_average = (*hf_average+hf_sum)>>1;
434       hf_sum = *hf_average;
435       if (*tapset_decision==2)
436          hf_sum += 4;
437       else if (*tapset_decision==0)
438          hf_sum -= 4;
439       if (hf_sum > 22)
440          *tapset_decision=2;
441       else if (hf_sum > 18)
442          *tapset_decision=1;
443       else
444          *tapset_decision=0;
445    }
446    /*printf("%d %d %d\n", hf_sum, *hf_average, *tapset_decision);*/
447    sum /= nbBands;
448    /* Recursive averaging */
449    sum = (sum+*average)>>1;
450    *average = sum;
451    /* Hysteresis */
452    sum = (3*sum + (((3-last_decision)<<7) + 64) + 2)>>2;
453    if (sum < 80)
454    {
455       decision = SPREAD_AGGRESSIVE;
456    } else if (sum < 256)
457    {
458       decision = SPREAD_NORMAL;
459    } else if (sum < 384)
460    {
461       decision = SPREAD_LIGHT;
462    } else {
463       decision = SPREAD_NONE;
464    }
465    return decision;
466 }
467
468 #ifdef MEASURE_NORM_MSE
469
470 float MSE[30] = {0};
471 int nbMSEBands = 0;
472 int MSECount[30] = {0};
473
474 void dump_norm_mse(void)
475 {
476    int i;
477    for (i=0;i<nbMSEBands;i++)
478    {
479       printf ("%g ", MSE[i]/MSECount[i]);
480    }
481    printf ("\n");
482 }
483
484 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
485 {
486    static int init = 0;
487    int i;
488    if (!init)
489    {
490       atexit(dump_norm_mse);
491       init = 1;
492    }
493    for (i=0;i<m->nbEBands;i++)
494    {
495       int j;
496       int c;
497       float g;
498       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
499          continue;
500       c=0; do {
501          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
502          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
503             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
504       } while (++c<C);
505       MSECount[i]+=C;
506    }
507    nbMSEBands = m->nbEBands;
508 }
509
510 #endif
511
512 /* Indexing table for converting from natural Hadamard to ordery Hadamard
513    This is essentially a bit-reversed Gray, on top of which we've added
514    an inversion of the order because we want the DC at the end rather than
515    the beginning. The lines are for N=2, 4, 8, 16 */
516 static const int ordery_table[] = {
517        1,  0,
518        3,  0,  2,  1,
519        7,  0,  4,  3,  6,  1,  5,  2,
520       15,  0,  8,  7, 12,  3, 11,  4, 14,  1,  9,  6, 13,  2, 10,  5,
521 };
522
523 static void deinterleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
524 {
525    int i,j;
526    VARDECL(celt_norm, tmp);
527    int N;
528    SAVE_STACK;
529    N = N0*stride;
530    ALLOC(tmp, N, celt_norm);
531    if (hadamard)
532    {
533       const int *ordery = ordery_table+stride-2;
534       for (i=0;i<stride;i++)
535       {
536          for (j=0;j<N0;j++)
537             tmp[ordery[i]*N0+j] = X[j*stride+i];
538       }
539    } else {
540       for (i=0;i<stride;i++)
541          for (j=0;j<N0;j++)
542             tmp[i*N0+j] = X[j*stride+i];
543    }
544    for (j=0;j<N;j++)
545       X[j] = tmp[j];
546    RESTORE_STACK;
547 }
548
549 static void interleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
550 {
551    int i,j;
552    VARDECL(celt_norm, tmp);
553    int N;
554    SAVE_STACK;
555    N = N0*stride;
556    ALLOC(tmp, N, celt_norm);
557    if (hadamard)
558    {
559       const int *ordery = ordery_table+stride-2;
560       for (i=0;i<stride;i++)
561          for (j=0;j<N0;j++)
562             tmp[j*stride+i] = X[ordery[i]*N0+j];
563    } else {
564       for (i=0;i<stride;i++)
565          for (j=0;j<N0;j++)
566             tmp[j*stride+i] = X[i*N0+j];
567    }
568    for (j=0;j<N;j++)
569       X[j] = tmp[j];
570    RESTORE_STACK;
571 }
572
573 void haar1(celt_norm *X, int N0, int stride)
574 {
575    int i, j;
576    N0 >>= 1;
577    for (i=0;i<stride;i++)
578       for (j=0;j<N0;j++)
579       {
580          celt_norm tmp1, tmp2;
581          tmp1 = MULT16_16_Q15(QCONST16(.70710678f,15), X[stride*2*j+i]);
582          tmp2 = MULT16_16_Q15(QCONST16(.70710678f,15), X[stride*(2*j+1)+i]);
583          X[stride*2*j+i] = tmp1 + tmp2;
584          X[stride*(2*j+1)+i] = tmp1 - tmp2;
585       }
586 }
587
588 static int compute_qn(int N, int b, int offset, int stereo)
589 {
590    static const celt_int16 exp2_table8[8] =
591       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
592    int qn, qb;
593    int N2 = 2*N-1;
594    if (stereo && N==2)
595       N2--;
596    qb = IMIN((b>>1)-(1<<BITRES), (b+N2*offset)/N2);
597
598    qb = IMAX(0, IMIN(8<<BITRES, qb));
599
600    if (qb<(1<<BITRES>>1)) {
601       qn = 1;
602    } else {
603       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
604       qn = (qn+1)>>1<<1;
605    }
606    celt_assert(qn <= 256);
607    return qn;
608 }
609
610 /* This function is responsible for encoding and decoding a band for both
611    the mono and stereo case. Even in the mono case, it can split the band
612    in two and transmit the energy difference with the two half-bands. It
613    can be called recursively so bands can end up being split in 8 parts. */
614 static unsigned quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
615       int N, int b, int spread, int B, int intensity, int tf_change, celt_norm *lowband, int resynth, void *ec,
616       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
617       celt_uint32 *seed, celt_word16 gain, celt_norm *lowband_scratch, int fill)
618 {
619    int q;
620    int curr_bits;
621    int stereo, split;
622    int imid=0, iside=0;
623    int N0=N;
624    int N_B=N;
625    int N_B0;
626    int B0=B;
627    int time_divide=0;
628    int recombine=0;
629    int inv = 0;
630    celt_word16 mid=0, side=0;
631    int longBlocks;
632    unsigned cm=0;
633
634    longBlocks = B0==1;
635
636    N_B /= B;
637    N_B0 = N_B;
638
639    split = stereo = Y != NULL;
640
641    /* Special case for one sample */
642    if (N==1)
643    {
644       int c;
645       celt_norm *x = X;
646       c=0; do {
647          int sign=0;
648          if (*remaining_bits>=1<<BITRES)
649          {
650             if (encode)
651             {
652                sign = x[0]<0;
653                ec_enc_bits((ec_enc*)ec, sign, 1);
654             } else {
655                sign = ec_dec_bits((ec_dec*)ec, 1);
656             }
657             *remaining_bits -= 1<<BITRES;
658             b-=1<<BITRES;
659          }
660          if (resynth)
661             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
662          x = Y;
663       } while (++c<1+stereo);
664       if (lowband_out)
665          lowband_out[0] = SHR16(X[0],4);
666       return 1;
667    }
668
669    if (!stereo && level == 0)
670    {
671       int k;
672       if (tf_change>0)
673          recombine = tf_change;
674       /* Band recombining to increase frequency resolution */
675
676       if (lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
677       {
678          int j;
679          for (j=0;j<N;j++)
680             lowband_scratch[j] = lowband[j];
681          lowband = lowband_scratch;
682       }
683
684       for (k=0;k<recombine;k++)
685       {
686          if (encode)
687             haar1(X, N>>k, 1<<k);
688          if (lowband)
689             haar1(lowband, N>>k, 1<<k);
690          fill |= fill<<(1<<k);
691       }
692       B>>=recombine;
693       N_B<<=recombine;
694
695       /* Increasing the time resolution */
696       while ((N_B&1) == 0 && tf_change<0)
697       {
698          if (encode)
699             haar1(X, N_B, B);
700          if (lowband)
701             haar1(lowband, N_B, B);
702          fill |= fill<<B;
703          B <<= 1;
704          N_B >>= 1;
705          time_divide++;
706          tf_change++;
707       }
708       B0=B;
709       N_B0 = N_B;
710
711       /* Reorganize the samples in time order instead of frequency order */
712       if (B0>1)
713       {
714          if (encode)
715             deinterleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
716          if (lowband)
717             deinterleave_hadamard(lowband, N_B>>recombine, B0<<recombine, longBlocks);
718       }
719    }
720
721    /* If we need more than 32 bits, try splitting the band in two. */
722    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
723    {
724       if (LM>0 || (N&1)==0)
725       {
726          N >>= 1;
727          Y = X+N;
728          split = 1;
729          LM -= 1;
730          if (B==1)
731             fill |= fill<<1;
732          B = (B+1)>>1;
733       }
734    }
735
736    if (split)
737    {
738       int qn;
739       int itheta=0;
740       int mbits, sbits, delta;
741       int qalloc;
742       int offset;
743       celt_int32 tell;
744
745       /* Decide on the resolution to give to the split parameter theta */
746       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
747       qn = compute_qn(N, b, offset, stereo);
748       if (stereo && i>=intensity)
749          qn = 1;
750       if (encode)
751       {
752          /* theta is the atan() of the ratio between the (normalized)
753             side and mid. With just that parameter, we can re-scale both
754             mid and side because we know that 1) they have unit norm and
755             2) they are orthogonal. */
756          itheta = stereo_itheta(X, Y, stereo, N);
757       }
758       tell = encode ? ec_enc_tell(ec, BITRES) : ec_dec_tell(ec, BITRES);
759       if (qn!=1)
760       {
761          if (encode)
762             itheta = (itheta*qn+8192)>>14;
763
764          /* Entropy coding of the angle. We use a uniform pdf for the
765             time split, a step for stereo, and a triangular one for the rest. */
766          if (stereo && N>2)
767          {
768             int p0 = 3;
769             int x = itheta;
770             int x0 = qn/2;
771             int ft = p0*(x0+1) + x0;
772             /* Use a probability of p0 up to itheta=8192 and then use 1 after */
773             if (encode)
774             {
775                ec_encode((ec_enc*)ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
776             } else {
777                int fs;
778                fs=ec_decode(ec,ft);
779                if (fs<(x0+1)*p0)
780                   x=fs/p0;
781                else
782                   x=x0+1+(fs-(x0+1)*p0);
783                ec_dec_update(ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
784                itheta = x;
785             }
786          } else if (B0>1 || stereo) {
787             /* Uniform pdf */
788             if (encode)
789                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
790             else
791                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
792          } else {
793             int fs=1, ft;
794             ft = ((qn>>1)+1)*((qn>>1)+1);
795             if (encode)
796             {
797                int fl;
798
799                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
800                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
801                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
802
803                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
804             } else {
805                /* Triangular pdf */
806                int fl=0;
807                int fm;
808                fm = ec_decode((ec_dec*)ec, ft);
809
810                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
811                {
812                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
813                   fs = itheta + 1;
814                   fl = itheta*(itheta + 1)>>1;
815                }
816                else
817                {
818                   itheta = (2*(qn + 1)
819                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
820                   fs = qn + 1 - itheta;
821                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
822                }
823
824                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
825             }
826          }
827          itheta = (celt_int32)itheta*16384/qn;
828          if (encode && stereo)
829          {
830             if (itheta==0)
831                intensity_stereo(m, X, Y, bandE, i, N);
832             else
833                stereo_split(X, Y, N);
834          }
835          /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
836                   Let's do that at higher complexity */
837       } else if (stereo) {
838          if (encode)
839          {
840             inv = itheta > 8192;
841             if (inv)
842             {
843                int j;
844                for (j=0;j<N;j++)
845                   Y[j] = -Y[j];
846             }
847             intensity_stereo(m, X, Y, bandE, i, N);
848          }
849          if (b>2<<BITRES && *remaining_bits > 2<<BITRES)
850          {
851             if (encode)
852                ec_enc_bit_logp(ec, inv, 2);
853             else
854                inv = ec_dec_bit_logp(ec, 2);
855          } else
856             inv = 0;
857          itheta = 0;
858       }
859       qalloc = (encode ? ec_enc_tell(ec, BITRES) : ec_dec_tell(ec, BITRES))
860                - tell;
861       b -= qalloc;
862
863       if (itheta == 0)
864       {
865          imid = 32767;
866          iside = 0;
867          fill &= (1<<B)-1;
868          delta = -16384;
869       } else if (itheta == 16384)
870       {
871          imid = 0;
872          iside = 32767;
873          fill &= (1<<B)-1<<B;
874          delta = 16384;
875       } else {
876          imid = bitexact_cos(itheta);
877          iside = bitexact_cos(16384-itheta);
878          /* This is the mid vs side allocation that minimizes squared error
879             in that band. */
880          delta = FRAC_MUL16(N-1<<7,bitexact_log2tan(iside,imid));
881       }
882
883 #ifdef FIXED_POINT
884       mid = imid;
885       side = iside;
886 #else
887       mid = (1.f/32768)*imid;
888       side = (1.f/32768)*iside;
889 #endif
890
891       /* This is a special case for N=2 that only works for stereo and takes
892          advantage of the fact that mid and side are orthogonal to encode
893          the side with just one bit. */
894       if (N==2 && stereo)
895       {
896          int c;
897          int sign=0;
898          celt_norm *x2, *y2;
899          mbits = b;
900          sbits = 0;
901          /* Only need one bit for the side */
902          if (itheta != 0 && itheta != 16384)
903             sbits = 1<<BITRES;
904          mbits -= sbits;
905          c = itheta > 8192;
906          *remaining_bits -= qalloc+sbits;
907
908          x2 = c ? Y : X;
909          y2 = c ? X : Y;
910          if (sbits)
911          {
912             if (encode)
913             {
914                /* Here we only need to encode a sign for the side */
915                sign = x2[0]*y2[1] - x2[1]*y2[0] < 0;
916                ec_enc_bits((ec_enc*)ec, sign, 1);
917             } else {
918                sign = ec_dec_bits((ec_dec*)ec, 1);
919             }
920          }
921          sign = 1-2*sign;
922          cm = quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, intensity, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level, seed, gain, lowband_scratch, fill);
923          /* We don't split N=2 bands, so cm is either 1 or 0 (for a fold-collapse),
924              and there's no need to worry about mixing with the other channel. */
925          y2[0] = -sign*x2[1];
926          y2[1] = sign*x2[0];
927          if (resynth)
928          {
929             celt_norm tmp;
930             X[0] = MULT16_16_Q15(mid, X[0]);
931             X[1] = MULT16_16_Q15(mid, X[1]);
932             Y[0] = MULT16_16_Q15(side, Y[0]);
933             Y[1] = MULT16_16_Q15(side, Y[1]);
934             tmp = X[0];
935             X[0] = SUB16(tmp,Y[0]);
936             Y[0] = ADD16(tmp,Y[0]);
937             tmp = X[1];
938             X[1] = SUB16(tmp,Y[1]);
939             Y[1] = ADD16(tmp,Y[1]);
940          }
941       } else {
942          /* "Normal" split code */
943          celt_norm *next_lowband2=NULL;
944          celt_norm *next_lowband_out1=NULL;
945          int next_level=0;
946
947          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
948          if (B0>1 && !stereo)
949          {
950             if (itheta > 8192)
951                /* Rough approximation for pre-echo masking */
952                delta -= delta>>(4-LM);
953             else
954                /* Corresponds to a forward-masking slope of 1.5 dB per 10 ms */
955                delta = IMIN(0, delta + (N<<BITRES>>(5-LM)));
956          }
957          mbits = IMAX(0, IMIN(b, (b-delta)/2));
958          sbits = b-mbits;
959          *remaining_bits -= qalloc;
960
961          if (lowband && !stereo)
962             next_lowband2 = lowband+N; /* >32-bit split case */
963
964          /* Only stereo needs to pass on lowband_out. Otherwise, it's
965             handled at the end */
966          if (stereo)
967             next_lowband_out1 = lowband_out;
968          else
969             next_level = level+1;
970
971          /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
972             mid for folding later */
973          cm = quant_band(encode, m, i, X, NULL, N, mbits, spread, B, intensity, tf_change,
974                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
975                NULL, next_level, seed, stereo ? Q15ONE : MULT16_16_P15(gain,mid), lowband_scratch, fill);
976          /* For a stereo split, the high bits of fill are always zero, so no
977              folding will be done to the side. */
978          cm |= quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, intensity, tf_change,
979                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
980                NULL, next_level, seed, MULT16_16_P15(gain,side), NULL, fill>>B)<<B;
981       }
982
983    } else {
984       /* This is the basic no-split case */
985       q = bits2pulses(m, i, LM, b);
986       curr_bits = pulses2bits(m, i, LM, q);
987       *remaining_bits -= curr_bits;
988
989       /* Ensures we can never bust the budget */
990       while (*remaining_bits < 0 && q > 0)
991       {
992          *remaining_bits += curr_bits;
993          q--;
994          curr_bits = pulses2bits(m, i, LM, q);
995          *remaining_bits -= curr_bits;
996       }
997
998       if (q!=0)
999       {
1000          int K = get_pulses(q);
1001
1002          /* Finally do the actual quantization */
1003          if (encode)
1004             cm = alg_quant(X, N, K, spread, B, lowband, resynth, (ec_enc*)ec, gain);
1005          else
1006             cm = alg_unquant(X, N, K, spread, B, lowband, (ec_dec*)ec, gain);
1007       } else {
1008          /* If there's no pulse, fill the band anyway */
1009          int j;
1010          if (resynth)
1011          {
1012             if (!fill)
1013             {
1014                for (j=0;j<N;j++)
1015                   X[j] = 0;
1016             } else {
1017                if (lowband == NULL || (spread==SPREAD_AGGRESSIVE && B<=1))
1018                {
1019                   /* Noise */
1020                   for (j=0;j<N;j++)
1021                   {
1022                      *seed = lcg_rand(*seed);
1023                      X[j] = (celt_int32)(*seed)>>20;
1024                   }
1025                   cm = (1<<B)-1;
1026                } else {
1027                   /* Folded spectrum */
1028                   for (j=0;j<N;j++)
1029                      X[j] = lowband[j];
1030                   cm = fill;
1031                }
1032                renormalise_vector(X, N, gain);
1033             }
1034          }
1035       }
1036    }
1037
1038    /* This code is used by the decoder and by the resynthesis-enabled encoder */
1039    if (resynth)
1040    {
1041       if (stereo)
1042       {
1043          if (N!=2)
1044          {
1045             cm |= cm>>B;
1046             stereo_merge(X, Y, mid, N);
1047          }
1048          if (inv)
1049          {
1050             int j;
1051             for (j=0;j<N;j++)
1052                Y[j] = -Y[j];
1053          }
1054       } else if (level == 0)
1055       {
1056          int k;
1057
1058          /* Undo the sample reorganization going from time order to frequency order */
1059          if (B0>1)
1060             interleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
1061
1062          /* Undo time-freq changes that we did earlier */
1063          N_B = N_B0;
1064          B = B0;
1065          for (k=0;k<time_divide;k++)
1066          {
1067             B >>= 1;
1068             N_B <<= 1;
1069             cm |= cm>>B;
1070             haar1(X, N_B, B);
1071          }
1072
1073          for (k=0;k<recombine;k++)
1074          {
1075             cm |= cm<<(1<<k);
1076             haar1(X, N0>>k, 1<<k);
1077          }
1078          B<<=recombine;
1079          N_B>>=recombine;
1080
1081          /* Scale output for later folding */
1082          if (lowband_out)
1083          {
1084             int j;
1085             celt_word16 n;
1086             n = celt_sqrt(SHL32(EXTEND32(N0),22));
1087             for (j=0;j<N0;j++)
1088                lowband_out[j] = MULT16_16_Q15(n,X[j]);
1089          }
1090       }
1091    }
1092    return cm;
1093 }
1094
1095 void quant_all_bands(int encode, const CELTMode *m, int start, int end,
1096       celt_norm *_X, celt_norm *_Y, unsigned char *collapse_masks, const celt_ener *bandE, int *pulses,
1097       int shortBlocks, int spread, int dual_stereo, int intensity, int *tf_res, int resynth,
1098       int total_bits, void *ec, int LM, int codedBands, ec_uint32 *seed)
1099 {
1100    int i;
1101    celt_int32 balance;
1102    celt_int32 remaining_bits;
1103    const celt_int16 * restrict eBands = m->eBands;
1104    celt_norm * restrict norm, * restrict norm2;
1105    VARDECL(celt_norm, _norm);
1106    VARDECL(celt_norm, lowband_scratch);
1107    int B;
1108    int M;
1109    int lowband_offset;
1110    int update_lowband = 1;
1111    int C = _Y != NULL ? 2 : 1;
1112    SAVE_STACK;
1113
1114    M = 1<<LM;
1115    B = shortBlocks ? M : 1;
1116    ALLOC(_norm, C*M*eBands[m->nbEBands], celt_norm);
1117    ALLOC(lowband_scratch, M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]), celt_norm);
1118    norm = _norm;
1119    norm2 = norm + M*eBands[m->nbEBands];
1120
1121    balance = 0;
1122    lowband_offset = 0;
1123    for (i=start;i<end;i++)
1124    {
1125       celt_int32 tell;
1126       int b;
1127       int N;
1128       celt_int32 curr_balance;
1129       int effective_lowband=-1;
1130       celt_norm * restrict X, * restrict Y;
1131       int tf_change=0;
1132       unsigned x_cm;
1133       unsigned y_cm;
1134
1135       X = _X+M*eBands[i];
1136       if (_Y!=NULL)
1137          Y = _Y+M*eBands[i];
1138       else
1139          Y = NULL;
1140       N = M*eBands[i+1]-M*eBands[i];
1141       if (encode)
1142          tell = ec_enc_tell((ec_enc*)ec, BITRES);
1143       else
1144          tell = ec_dec_tell((ec_dec*)ec, BITRES);
1145
1146       /* Compute how many bits we want to allocate to this band */
1147       if (i != start)
1148          balance -= tell;
1149       remaining_bits = ((celt_int32)total_bits<<BITRES)-tell-1- (shortBlocks&&LM>=2 ? (1<<BITRES) : 0);
1150       if (i <= codedBands-1)
1151       {
1152          curr_balance = balance / IMIN(3, codedBands-i);
1153          b = IMAX(0, IMIN(16384, IMIN(remaining_bits+1,pulses[i]+curr_balance)));
1154       } else {
1155          b = 0;
1156       }
1157
1158       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband_offset==0))
1159             lowband_offset = i;
1160
1161       tf_change = tf_res[i];
1162       if (i>=m->effEBands)
1163       {
1164          X=norm;
1165          if (_Y!=NULL)
1166             Y = norm;
1167       }
1168
1169       /* This ensures we never repeat spectral content within one band */
1170       if (lowband_offset != 0)
1171          effective_lowband = IMAX(M*eBands[start], M*eBands[lowband_offset]-N);
1172
1173       /* Get a conservative estimate of the collapse_mask's for the bands we're
1174           going to be folding from. */
1175       if (lowband_offset != 0 && (spread!=SPREAD_AGGRESSIVE || B>1))
1176       {
1177          int fold_start;
1178          int fold_end;
1179          int fold_i;
1180          fold_start = lowband_offset;
1181          while(M*eBands[--fold_start] > effective_lowband);
1182          fold_end = lowband_offset-1;
1183          while(M*eBands[++fold_end] < effective_lowband+N);
1184          x_cm = y_cm = 0;
1185          fold_i = fold_start; do {
1186            x_cm |= collapse_masks[fold_i*C+0];
1187            y_cm |= collapse_masks[fold_i*C+1];
1188          } while (++fold_i<fold_end);
1189       }
1190       /* Otherwise, we'll be using the LCG to fold, so all blocks will (almost
1191           always) be non-zero.*/
1192       else
1193          x_cm = y_cm = (1<<B)-1;
1194
1195       if (dual_stereo && i==intensity)
1196       {
1197          int j;
1198
1199          /* Switch off dual stereo to do intensity */
1200          dual_stereo = 0;
1201          for (j=M*eBands[start];j<M*eBands[i];j++)
1202             norm[j] = HALF32(norm[j]+norm2[j]);
1203       }
1204       if (dual_stereo)
1205       {
1206          x_cm = quant_band(encode, m, i, X, NULL, N, b/2, spread, B, intensity, tf_change,
1207                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1208                norm+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, x_cm);
1209          y_cm = quant_band(encode, m, i, Y, NULL, N, b/2, spread, B, intensity, tf_change,
1210                effective_lowband != -1 ? norm2+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1211                norm2+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, y_cm);
1212       } else {
1213          x_cm = quant_band(encode, m, i, X, Y, N, b, spread, B, intensity, tf_change,
1214                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1215                norm+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, x_cm|y_cm);
1216          y_cm = x_cm;
1217       }
1218       collapse_masks[i*C+0] = (unsigned char)(x_cm&(1<<B)-1);
1219       collapse_masks[i*C+C-1] = (unsigned char)(y_cm&(1<<B)-1);
1220       balance += pulses[i] + tell;
1221
1222       /* Update the folding position only as long as we have 1 bit/sample depth */
1223       update_lowband = (b>>BITRES)>N;
1224    }
1225    RESTORE_STACK;
1226 }
1227