Avoids pre-echo in hybrid mode caused by noise being injected in the first band
[opus.git] / celt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16
17    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
21    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
26    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29
30 #ifdef HAVE_CONFIG_H
31 #include "config.h"
32 #endif
33
34 #include <math.h>
35 #include "bands.h"
36 #include "modes.h"
37 #include "vq.h"
38 #include "cwrs.h"
39 #include "stack_alloc.h"
40 #include "os_support.h"
41 #include "mathops.h"
42 #include "rate.h"
43 #include "quant_bands.h"
44 #include "pitch.h"
45
46 int hysteresis_decision(opus_val16 val, const opus_val16 *thresholds, const opus_val16 *hysteresis, int N, int prev)
47 {
48    int i;
49    for (i=0;i<N;i++)
50    {
51       if (val < thresholds[i])
52          break;
53    }
54    if (i>prev && val < thresholds[prev]+hysteresis[prev])
55       i=prev;
56    if (i<prev && val > thresholds[prev-1]-hysteresis[prev-1])
57       i=prev;
58    return i;
59 }
60
61 opus_uint32 celt_lcg_rand(opus_uint32 seed)
62 {
63    return 1664525 * seed + 1013904223;
64 }
65
66 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
67    with this approximation is important because it has an impact on the bit allocation */
68 static opus_int16 bitexact_cos(opus_int16 x)
69 {
70    opus_int32 tmp;
71    opus_int16 x2;
72    tmp = (4096+((opus_int32)(x)*(x)))>>13;
73    celt_assert(tmp<=32767);
74    x2 = tmp;
75    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
76    celt_assert(x2<=32766);
77    return 1+x2;
78 }
79
80 static int bitexact_log2tan(int isin,int icos)
81 {
82    int lc;
83    int ls;
84    lc=EC_ILOG(icos);
85    ls=EC_ILOG(isin);
86    icos<<=15-lc;
87    isin<<=15-ls;
88    return (ls-lc)*(1<<11)
89          +FRAC_MUL16(isin, FRAC_MUL16(isin, -2597) + 7932)
90          -FRAC_MUL16(icos, FRAC_MUL16(icos, -2597) + 7932);
91 }
92
93 #ifdef FIXED_POINT
94 /* Compute the amplitude (sqrt energy) in each of the bands */
95 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bandE, int end, int C, int LM)
96 {
97    int i, c, N;
98    const opus_int16 *eBands = m->eBands;
99    N = m->shortMdctSize<<LM;
100    c=0; do {
101       for (i=0;i<end;i++)
102       {
103          int j;
104          opus_val32 maxval=0;
105          opus_val32 sum = 0;
106
107          maxval = celt_maxabs32(&X[c*N+(eBands[i]<<LM)], (eBands[i+1]-eBands[i])<<LM);
108          if (maxval > 0)
109          {
110             int shift = celt_ilog2(maxval) - 14 + (((m->logN[i]>>BITRES)+LM+1)>>1);
111             j=eBands[i]<<LM;
112             if (shift>0)
113             {
114                do {
115                   sum = MAC16_16(sum, EXTRACT16(SHR32(X[j+c*N],shift)),
116                         EXTRACT16(SHR32(X[j+c*N],shift)));
117                } while (++j<eBands[i+1]<<LM);
118             } else {
119                do {
120                   sum = MAC16_16(sum, EXTRACT16(SHL32(X[j+c*N],-shift)),
121                         EXTRACT16(SHL32(X[j+c*N],-shift)));
122                } while (++j<eBands[i+1]<<LM);
123             }
124             /* We're adding one here to ensure the normalized band isn't larger than unity norm */
125             bandE[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
126          } else {
127             bandE[i+c*m->nbEBands] = EPSILON;
128          }
129          /*printf ("%f ", bandE[i+c*m->nbEBands]);*/
130       }
131    } while (++c<C);
132    /*printf ("\n");*/
133 }
134
135 /* Normalise each band such that the energy is one. */
136 void normalise_bands(const CELTMode *m, const celt_sig * OPUS_RESTRICT freq, celt_norm * OPUS_RESTRICT X, const celt_ener *bandE, int end, int C, int M)
137 {
138    int i, c, N;
139    const opus_int16 *eBands = m->eBands;
140    N = M*m->shortMdctSize;
141    c=0; do {
142       i=0; do {
143          opus_val16 g;
144          int j,shift;
145          opus_val16 E;
146          shift = celt_zlog2(bandE[i+c*m->nbEBands])-13;
147          E = VSHR32(bandE[i+c*m->nbEBands], shift);
148          g = EXTRACT16(celt_rcp(SHL32(E,3)));
149          j=M*eBands[i]; do {
150             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
151          } while (++j<M*eBands[i+1]);
152       } while (++i<end);
153    } while (++c<C);
154 }
155
156 #else /* FIXED_POINT */
157 /* Compute the amplitude (sqrt energy) in each of the bands */
158 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bandE, int end, int C, int LM)
159 {
160    int i, c, N;
161    const opus_int16 *eBands = m->eBands;
162    N = m->shortMdctSize<<LM;
163    c=0; do {
164       for (i=0;i<end;i++)
165       {
166          opus_val32 sum;
167          sum = 1e-27f + celt_inner_prod_c(&X[c*N+(eBands[i]<<LM)], &X[c*N+(eBands[i]<<LM)], (eBands[i+1]-eBands[i])<<LM);
168          bandE[i+c*m->nbEBands] = celt_sqrt(sum);
169          /*printf ("%f ", bandE[i+c*m->nbEBands]);*/
170       }
171    } while (++c<C);
172    /*printf ("\n");*/
173 }
174
175 /* Normalise each band such that the energy is one. */
176 void normalise_bands(const CELTMode *m, const celt_sig * OPUS_RESTRICT freq, celt_norm * OPUS_RESTRICT X, const celt_ener *bandE, int end, int C, int M)
177 {
178    int i, c, N;
179    const opus_int16 *eBands = m->eBands;
180    N = M*m->shortMdctSize;
181    c=0; do {
182       for (i=0;i<end;i++)
183       {
184          int j;
185          opus_val16 g = 1.f/(1e-27f+bandE[i+c*m->nbEBands]);
186          for (j=M*eBands[i];j<M*eBands[i+1];j++)
187             X[j+c*N] = freq[j+c*N]*g;
188       }
189    } while (++c<C);
190 }
191
192 #endif /* FIXED_POINT */
193
194 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
195 void denormalise_bands(const CELTMode *m, const celt_norm * OPUS_RESTRICT X,
196       celt_sig * OPUS_RESTRICT freq, const opus_val16 *bandLogE, int start,
197       int end, int M, int downsample, int silence)
198 {
199    int i, N;
200    int bound;
201    celt_sig * OPUS_RESTRICT f;
202    const celt_norm * OPUS_RESTRICT x;
203    const opus_int16 *eBands = m->eBands;
204    N = M*m->shortMdctSize;
205    bound = M*eBands[end];
206    if (downsample!=1)
207       bound = IMIN(bound, N/downsample);
208    if (silence)
209    {
210       bound = 0;
211       start = end = 0;
212    }
213    f = freq;
214    x = X+M*eBands[start];
215    for (i=0;i<M*eBands[start];i++)
216       *f++ = 0;
217    for (i=start;i<end;i++)
218    {
219       int j, band_end;
220       opus_val16 g;
221       opus_val16 lg;
222 #ifdef FIXED_POINT
223       int shift;
224 #endif
225       j=M*eBands[i];
226       band_end = M*eBands[i+1];
227       lg = SATURATE16(ADD32(bandLogE[i], SHL32((opus_val32)eMeans[i],6)));
228 #ifndef FIXED_POINT
229       g = celt_exp2(MIN32(32.f, lg));
230 #else
231       /* Handle the integer part of the log energy */
232       shift = 16-(lg>>DB_SHIFT);
233       if (shift>31)
234       {
235          shift=0;
236          g=0;
237       } else {
238          /* Handle the fractional part. */
239          g = celt_exp2_frac(lg&((1<<DB_SHIFT)-1));
240       }
241       /* Handle extreme gains with negative shift. */
242       if (shift<0)
243       {
244          /* For shift <= -2 and g > 16384 we'd be likely to overflow, so we're
245             capping the gain here, which is equivalent to a cap of 18 on lg.
246             This shouldn't trigger unless the bitstream is already corrupted. */
247          if (shift <= -2)
248          {
249             g = 16384;
250             shift = -2;
251          }
252          do {
253             *f++ = SHL32(MULT16_16(*x++, g), -shift);
254          } while (++j<band_end);
255       } else
256 #endif
257          /* Be careful of the fixed-point "else" just above when changing this code */
258          do {
259             *f++ = SHR32(MULT16_16(*x++, g), shift);
260          } while (++j<band_end);
261    }
262    celt_assert(start <= end);
263    OPUS_CLEAR(&freq[bound], N-bound);
264 }
265
266 /* This prevents energy collapse for transients with multiple short MDCTs */
267 void anti_collapse(const CELTMode *m, celt_norm *X_, unsigned char *collapse_masks, int LM, int C, int size,
268       int start, int end, const opus_val16 *logE, const opus_val16 *prev1logE,
269       const opus_val16 *prev2logE, const int *pulses, opus_uint32 seed, int arch)
270 {
271    int c, i, j, k;
272    for (i=start;i<end;i++)
273    {
274       int N0;
275       opus_val16 thresh, sqrt_1;
276       int depth;
277 #ifdef FIXED_POINT
278       int shift;
279       opus_val32 thresh32;
280 #endif
281
282       N0 = m->eBands[i+1]-m->eBands[i];
283       /* depth in 1/8 bits */
284       celt_assert(pulses[i]>=0);
285       depth = celt_udiv(1+pulses[i], (m->eBands[i+1]-m->eBands[i]))>>LM;
286
287 #ifdef FIXED_POINT
288       thresh32 = SHR32(celt_exp2(-SHL16(depth, 10-BITRES)),1);
289       thresh = MULT16_32_Q15(QCONST16(0.5f, 15), MIN32(32767,thresh32));
290       {
291          opus_val32 t;
292          t = N0<<LM;
293          shift = celt_ilog2(t)>>1;
294          t = SHL32(t, (7-shift)<<1);
295          sqrt_1 = celt_rsqrt_norm(t);
296       }
297 #else
298       thresh = .5f*celt_exp2(-.125f*depth);
299       sqrt_1 = celt_rsqrt(N0<<LM);
300 #endif
301
302       c=0; do
303       {
304          celt_norm *X;
305          opus_val16 prev1;
306          opus_val16 prev2;
307          opus_val32 Ediff;
308          opus_val16 r;
309          int renormalize=0;
310          prev1 = prev1logE[c*m->nbEBands+i];
311          prev2 = prev2logE[c*m->nbEBands+i];
312          if (C==1)
313          {
314             prev1 = MAX16(prev1,prev1logE[m->nbEBands+i]);
315             prev2 = MAX16(prev2,prev2logE[m->nbEBands+i]);
316          }
317          Ediff = EXTEND32(logE[c*m->nbEBands+i])-EXTEND32(MIN16(prev1,prev2));
318          Ediff = MAX32(0, Ediff);
319
320 #ifdef FIXED_POINT
321          if (Ediff < 16384)
322          {
323             opus_val32 r32 = SHR32(celt_exp2(-EXTRACT16(Ediff)),1);
324             r = 2*MIN16(16383,r32);
325          } else {
326             r = 0;
327          }
328          if (LM==3)
329             r = MULT16_16_Q14(23170, MIN32(23169, r));
330          r = SHR16(MIN16(thresh, r),1);
331          r = SHR32(MULT16_16_Q15(sqrt_1, r),shift);
332 #else
333          /* r needs to be multiplied by 2 or 2*sqrt(2) depending on LM because
334             short blocks don't have the same energy as long */
335          r = 2.f*celt_exp2(-Ediff);
336          if (LM==3)
337             r *= 1.41421356f;
338          r = MIN16(thresh, r);
339          r = r*sqrt_1;
340 #endif
341          X = X_+c*size+(m->eBands[i]<<LM);
342          for (k=0;k<1<<LM;k++)
343          {
344             /* Detect collapse */
345             if (!(collapse_masks[i*C+c]&1<<k))
346             {
347                /* Fill with noise */
348                for (j=0;j<N0;j++)
349                {
350                   seed = celt_lcg_rand(seed);
351                   X[(j<<LM)+k] = (seed&0x8000 ? r : -r);
352                }
353                renormalize = 1;
354             }
355          }
356          /* We just added some energy, so we need to renormalise */
357          if (renormalize)
358             renormalise_vector(X, N0<<LM, Q15ONE, arch);
359       } while (++c<C);
360    }
361 }
362
363 /* Compute the weights to use for optimizing normalized distortion across
364    channels. We use the amplitude to weight square distortion, which means
365    that we use the square root of the value we would have been using if we
366    wanted to minimize the MSE in the non-normalized domain. This roughly
367    corresponds to some quick-and-dirty perceptual experiments I ran to
368    measure inter-aural masking (there doesn't seem to be any published data
369    on the topic). */
370 static void compute_channel_weights(celt_ener Ex, celt_ener Ey, opus_val16 w[2])
371 {
372    celt_ener minE;
373 #if FIXED_POINT
374    int shift;
375 #endif
376    minE = MIN32(Ex, Ey);
377    /* Adjustment to make the weights a bit more conservative. */
378    Ex = ADD32(Ex, minE/3);
379    Ey = ADD32(Ey, minE/3);
380 #if FIXED_POINT
381    shift = celt_ilog2(EPSILON+MAX32(Ex, Ey))-14;
382 #endif
383    w[0] = VSHR32(Ex, shift);
384    w[1] = VSHR32(Ey, shift);
385 }
386
387 static void intensity_stereo(const CELTMode *m, celt_norm * OPUS_RESTRICT X, const celt_norm * OPUS_RESTRICT Y, const celt_ener *bandE, int bandID, int N)
388 {
389    int i = bandID;
390    int j;
391    opus_val16 a1, a2;
392    opus_val16 left, right;
393    opus_val16 norm;
394 #ifdef FIXED_POINT
395    int shift = celt_zlog2(MAX32(bandE[i], bandE[i+m->nbEBands]))-13;
396 #endif
397    left = VSHR32(bandE[i],shift);
398    right = VSHR32(bandE[i+m->nbEBands],shift);
399    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
400    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
401    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
402    for (j=0;j<N;j++)
403    {
404       celt_norm r, l;
405       l = X[j];
406       r = Y[j];
407       X[j] = EXTRACT16(SHR32(MAC16_16(MULT16_16(a1, l), a2, r), 14));
408       /* Side is not encoded, no need to calculate */
409    }
410 }
411
412 static void stereo_split(celt_norm * OPUS_RESTRICT X, celt_norm * OPUS_RESTRICT Y, int N)
413 {
414    int j;
415    for (j=0;j<N;j++)
416    {
417       opus_val32 r, l;
418       l = MULT16_16(QCONST16(.70710678f, 15), X[j]);
419       r = MULT16_16(QCONST16(.70710678f, 15), Y[j]);
420       X[j] = EXTRACT16(SHR32(ADD32(l, r), 15));
421       Y[j] = EXTRACT16(SHR32(SUB32(r, l), 15));
422    }
423 }
424
425 static void stereo_merge(celt_norm * OPUS_RESTRICT X, celt_norm * OPUS_RESTRICT Y, opus_val16 mid, int N, int arch)
426 {
427    int j;
428    opus_val32 xp=0, side=0;
429    opus_val32 El, Er;
430    opus_val16 mid2;
431 #ifdef FIXED_POINT
432    int kl, kr;
433 #endif
434    opus_val32 t, lgain, rgain;
435
436    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
437    dual_inner_prod(Y, X, Y, N, &xp, &side, arch);
438    /* Compensating for the mid normalization */
439    xp = MULT16_32_Q15(mid, xp);
440    /* mid and side are in Q15, not Q14 like X and Y */
441    mid2 = SHR16(mid, 1);
442    El = MULT16_16(mid2, mid2) + side - 2*xp;
443    Er = MULT16_16(mid2, mid2) + side + 2*xp;
444    if (Er < QCONST32(6e-4f, 28) || El < QCONST32(6e-4f, 28))
445    {
446       OPUS_COPY(Y, X, N);
447       return;
448    }
449
450 #ifdef FIXED_POINT
451    kl = celt_ilog2(El)>>1;
452    kr = celt_ilog2(Er)>>1;
453 #endif
454    t = VSHR32(El, (kl-7)<<1);
455    lgain = celt_rsqrt_norm(t);
456    t = VSHR32(Er, (kr-7)<<1);
457    rgain = celt_rsqrt_norm(t);
458
459 #ifdef FIXED_POINT
460    if (kl < 7)
461       kl = 7;
462    if (kr < 7)
463       kr = 7;
464 #endif
465
466    for (j=0;j<N;j++)
467    {
468       celt_norm r, l;
469       /* Apply mid scaling (side is already scaled) */
470       l = MULT16_16_P15(mid, X[j]);
471       r = Y[j];
472       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, SUB16(l,r)), kl+1));
473       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, ADD16(l,r)), kr+1));
474    }
475 }
476
477 /* Decide whether we should spread the pulses in the current frame */
478 int spreading_decision(const CELTMode *m, const celt_norm *X, int *average,
479       int last_decision, int *hf_average, int *tapset_decision, int update_hf,
480       int end, int C, int M)
481 {
482    int i, c, N0;
483    int sum = 0, nbBands=0;
484    const opus_int16 * OPUS_RESTRICT eBands = m->eBands;
485    int decision;
486    int hf_sum=0;
487
488    celt_assert(end>0);
489
490    N0 = M*m->shortMdctSize;
491
492    if (M*(eBands[end]-eBands[end-1]) <= 8)
493       return SPREAD_NONE;
494    c=0; do {
495       for (i=0;i<end;i++)
496       {
497          int j, N, tmp=0;
498          int tcount[3] = {0,0,0};
499          const celt_norm * OPUS_RESTRICT x = X+M*eBands[i]+c*N0;
500          N = M*(eBands[i+1]-eBands[i]);
501          if (N<=8)
502             continue;
503          /* Compute rough CDF of |x[j]| */
504          for (j=0;j<N;j++)
505          {
506             opus_val32 x2N; /* Q13 */
507
508             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
509             if (x2N < QCONST16(0.25f,13))
510                tcount[0]++;
511             if (x2N < QCONST16(0.0625f,13))
512                tcount[1]++;
513             if (x2N < QCONST16(0.015625f,13))
514                tcount[2]++;
515          }
516
517          /* Only include four last bands (8 kHz and up) */
518          if (i>m->nbEBands-4)
519             hf_sum += celt_udiv(32*(tcount[1]+tcount[0]), N);
520          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
521          sum += tmp*256;
522          nbBands++;
523       }
524    } while (++c<C);
525
526    if (update_hf)
527    {
528       if (hf_sum)
529          hf_sum = celt_udiv(hf_sum, C*(4-m->nbEBands+end));
530       *hf_average = (*hf_average+hf_sum)>>1;
531       hf_sum = *hf_average;
532       if (*tapset_decision==2)
533          hf_sum += 4;
534       else if (*tapset_decision==0)
535          hf_sum -= 4;
536       if (hf_sum > 22)
537          *tapset_decision=2;
538       else if (hf_sum > 18)
539          *tapset_decision=1;
540       else
541          *tapset_decision=0;
542    }
543    /*printf("%d %d %d\n", hf_sum, *hf_average, *tapset_decision);*/
544    celt_assert(nbBands>0); /* end has to be non-zero */
545    celt_assert(sum>=0);
546    sum = celt_udiv(sum, nbBands);
547    /* Recursive averaging */
548    sum = (sum+*average)>>1;
549    *average = sum;
550    /* Hysteresis */
551    sum = (3*sum + (((3-last_decision)<<7) + 64) + 2)>>2;
552    if (sum < 80)
553    {
554       decision = SPREAD_AGGRESSIVE;
555    } else if (sum < 256)
556    {
557       decision = SPREAD_NORMAL;
558    } else if (sum < 384)
559    {
560       decision = SPREAD_LIGHT;
561    } else {
562       decision = SPREAD_NONE;
563    }
564 #ifdef FUZZING
565    decision = rand()&0x3;
566    *tapset_decision=rand()%3;
567 #endif
568    return decision;
569 }
570
571 /* Indexing table for converting from natural Hadamard to ordery Hadamard
572    This is essentially a bit-reversed Gray, on top of which we've added
573    an inversion of the order because we want the DC at the end rather than
574    the beginning. The lines are for N=2, 4, 8, 16 */
575 static const int ordery_table[] = {
576        1,  0,
577        3,  0,  2,  1,
578        7,  0,  4,  3,  6,  1,  5,  2,
579       15,  0,  8,  7, 12,  3, 11,  4, 14,  1,  9,  6, 13,  2, 10,  5,
580 };
581
582 static void deinterleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
583 {
584    int i,j;
585    VARDECL(celt_norm, tmp);
586    int N;
587    SAVE_STACK;
588    N = N0*stride;
589    ALLOC(tmp, N, celt_norm);
590    celt_assert(stride>0);
591    if (hadamard)
592    {
593       const int *ordery = ordery_table+stride-2;
594       for (i=0;i<stride;i++)
595       {
596          for (j=0;j<N0;j++)
597             tmp[ordery[i]*N0+j] = X[j*stride+i];
598       }
599    } else {
600       for (i=0;i<stride;i++)
601          for (j=0;j<N0;j++)
602             tmp[i*N0+j] = X[j*stride+i];
603    }
604    OPUS_COPY(X, tmp, N);
605    RESTORE_STACK;
606 }
607
608 static void interleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
609 {
610    int i,j;
611    VARDECL(celt_norm, tmp);
612    int N;
613    SAVE_STACK;
614    N = N0*stride;
615    ALLOC(tmp, N, celt_norm);
616    if (hadamard)
617    {
618       const int *ordery = ordery_table+stride-2;
619       for (i=0;i<stride;i++)
620          for (j=0;j<N0;j++)
621             tmp[j*stride+i] = X[ordery[i]*N0+j];
622    } else {
623       for (i=0;i<stride;i++)
624          for (j=0;j<N0;j++)
625             tmp[j*stride+i] = X[i*N0+j];
626    }
627    OPUS_COPY(X, tmp, N);
628    RESTORE_STACK;
629 }
630
631 void haar1(celt_norm *X, int N0, int stride)
632 {
633    int i, j;
634    N0 >>= 1;
635    for (i=0;i<stride;i++)
636       for (j=0;j<N0;j++)
637       {
638          opus_val32 tmp1, tmp2;
639          tmp1 = MULT16_16(QCONST16(.70710678f,15), X[stride*2*j+i]);
640          tmp2 = MULT16_16(QCONST16(.70710678f,15), X[stride*(2*j+1)+i]);
641          X[stride*2*j+i] = EXTRACT16(PSHR32(ADD32(tmp1, tmp2), 15));
642          X[stride*(2*j+1)+i] = EXTRACT16(PSHR32(SUB32(tmp1, tmp2), 15));
643       }
644 }
645
646 static int compute_qn(int N, int b, int offset, int pulse_cap, int stereo)
647 {
648    static const opus_int16 exp2_table8[8] =
649       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
650    int qn, qb;
651    int N2 = 2*N-1;
652    if (stereo && N==2)
653       N2--;
654    /* The upper limit ensures that in a stereo split with itheta==16384, we'll
655        always have enough bits left over to code at least one pulse in the
656        side; otherwise it would collapse, since it doesn't get folded. */
657    qb = celt_sudiv(b+N2*offset, N2);
658    qb = IMIN(b-pulse_cap-(4<<BITRES), qb);
659
660    qb = IMIN(8<<BITRES, qb);
661
662    if (qb<(1<<BITRES>>1)) {
663       qn = 1;
664    } else {
665       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
666       qn = (qn+1)>>1<<1;
667    }
668    celt_assert(qn <= 256);
669    return qn;
670 }
671
672 struct band_ctx {
673    int encode;
674    int resynth;
675    const CELTMode *m;
676    int i;
677    int intensity;
678    int spread;
679    int tf_change;
680    ec_ctx *ec;
681    opus_int32 remaining_bits;
682    const celt_ener *bandE;
683    opus_uint32 seed;
684    int arch;
685    int theta_round;
686    int disable_inv;
687    int avoid_split_noise;
688 };
689
690 struct split_ctx {
691    int inv;
692    int imid;
693    int iside;
694    int delta;
695    int itheta;
696    int qalloc;
697 };
698
699 static void compute_theta(struct band_ctx *ctx, struct split_ctx *sctx,
700       celt_norm *X, celt_norm *Y, int N, int *b, int B, int B0,
701       int LM,
702       int stereo, int *fill)
703 {
704    int qn;
705    int itheta=0;
706    int delta;
707    int imid, iside;
708    int qalloc;
709    int pulse_cap;
710    int offset;
711    opus_int32 tell;
712    int inv=0;
713    int encode;
714    const CELTMode *m;
715    int i;
716    int intensity;
717    ec_ctx *ec;
718    const celt_ener *bandE;
719
720    encode = ctx->encode;
721    m = ctx->m;
722    i = ctx->i;
723    intensity = ctx->intensity;
724    ec = ctx->ec;
725    bandE = ctx->bandE;
726
727    /* Decide on the resolution to give to the split parameter theta */
728    pulse_cap = m->logN[i]+LM*(1<<BITRES);
729    offset = (pulse_cap>>1) - (stereo&&N==2 ? QTHETA_OFFSET_TWOPHASE : QTHETA_OFFSET);
730    qn = compute_qn(N, *b, offset, pulse_cap, stereo);
731    if (stereo && i>=intensity)
732       qn = 1;
733    if (encode)
734    {
735       /* theta is the atan() of the ratio between the (normalized)
736          side and mid. With just that parameter, we can re-scale both
737          mid and side because we know that 1) they have unit norm and
738          2) they are orthogonal. */
739       itheta = stereo_itheta(X, Y, stereo, N, ctx->arch);
740    }
741    tell = ec_tell_frac(ec);
742    if (qn!=1)
743    {
744       if (encode)
745       {
746          if (!stereo || ctx->theta_round == 0)
747          {
748             itheta = (itheta*(opus_int32)qn+8192)>>14;
749             if (!stereo && ctx->avoid_split_noise && itheta > 0 && itheta < qn)
750             {
751                /* Check if the selected value of theta will cause the bit allocation
752                   to inject noise on one side. If so, make sure the energy of that side
753                   is zero. */
754                int unquantized = celt_udiv((opus_int32)itheta*16384, qn);
755                imid = bitexact_cos((opus_int16)unquantized);
756                iside = bitexact_cos((opus_int16)(16384-unquantized));
757                delta = FRAC_MUL16((N-1)<<7,bitexact_log2tan(iside,imid));
758                if (delta > *b)
759                   itheta = qn;
760                else if (delta < -*b)
761                   itheta = 0;
762             }
763          } else {
764             int down;
765             /* Bias quantization towards itheta=0 and itheta=16384. */
766             int bias = itheta > 8192 ? 32767/qn : -32767/qn;
767             down = IMIN(qn-1, IMAX(0, (itheta*(opus_int32)qn + bias)>>14));
768             if (ctx->theta_round < 0)
769                itheta = down;
770             else
771                itheta = down+1;
772          }
773       }
774       /* Entropy coding of the angle. We use a uniform pdf for the
775          time split, a step for stereo, and a triangular one for the rest. */
776       if (stereo && N>2)
777       {
778          int p0 = 3;
779          int x = itheta;
780          int x0 = qn/2;
781          int ft = p0*(x0+1) + x0;
782          /* Use a probability of p0 up to itheta=8192 and then use 1 after */
783          if (encode)
784          {
785             ec_encode(ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
786          } else {
787             int fs;
788             fs=ec_decode(ec,ft);
789             if (fs<(x0+1)*p0)
790                x=fs/p0;
791             else
792                x=x0+1+(fs-(x0+1)*p0);
793             ec_dec_update(ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
794             itheta = x;
795          }
796       } else if (B0>1 || stereo) {
797          /* Uniform pdf */
798          if (encode)
799             ec_enc_uint(ec, itheta, qn+1);
800          else
801             itheta = ec_dec_uint(ec, qn+1);
802       } else {
803          int fs=1, ft;
804          ft = ((qn>>1)+1)*((qn>>1)+1);
805          if (encode)
806          {
807             int fl;
808
809             fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
810             fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
811              ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
812
813             ec_encode(ec, fl, fl+fs, ft);
814          } else {
815             /* Triangular pdf */
816             int fl=0;
817             int fm;
818             fm = ec_decode(ec, ft);
819
820             if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
821             {
822                itheta = (isqrt32(8*(opus_uint32)fm + 1) - 1)>>1;
823                fs = itheta + 1;
824                fl = itheta*(itheta + 1)>>1;
825             }
826             else
827             {
828                itheta = (2*(qn + 1)
829                 - isqrt32(8*(opus_uint32)(ft - fm - 1) + 1))>>1;
830                fs = qn + 1 - itheta;
831                fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
832             }
833
834             ec_dec_update(ec, fl, fl+fs, ft);
835          }
836       }
837       celt_assert(itheta>=0);
838       itheta = celt_udiv((opus_int32)itheta*16384, qn);
839       if (encode && stereo)
840       {
841          if (itheta==0)
842             intensity_stereo(m, X, Y, bandE, i, N);
843          else
844             stereo_split(X, Y, N);
845       }
846       /* NOTE: Renormalising X and Y *may* help fixed-point a bit at very high rate.
847                Let's do that at higher complexity */
848    } else if (stereo) {
849       if (encode)
850       {
851          inv = itheta > 8192 && !ctx->disable_inv;
852          if (inv)
853          {
854             int j;
855             for (j=0;j<N;j++)
856                Y[j] = -Y[j];
857          }
858          intensity_stereo(m, X, Y, bandE, i, N);
859       }
860       if (*b>2<<BITRES && ctx->remaining_bits > 2<<BITRES)
861       {
862          if (encode)
863             ec_enc_bit_logp(ec, inv, 2);
864          else
865             inv = ec_dec_bit_logp(ec, 2);
866       } else
867          inv = 0;
868       /* inv flag override to avoid problems with downmixing. */
869       if (ctx->disable_inv)
870          inv = 0;
871       itheta = 0;
872    }
873    qalloc = ec_tell_frac(ec) - tell;
874    *b -= qalloc;
875
876    if (itheta == 0)
877    {
878       imid = 32767;
879       iside = 0;
880       *fill &= (1<<B)-1;
881       delta = -16384;
882    } else if (itheta == 16384)
883    {
884       imid = 0;
885       iside = 32767;
886       *fill &= ((1<<B)-1)<<B;
887       delta = 16384;
888    } else {
889       imid = bitexact_cos((opus_int16)itheta);
890       iside = bitexact_cos((opus_int16)(16384-itheta));
891       /* This is the mid vs side allocation that minimizes squared error
892          in that band. */
893       delta = FRAC_MUL16((N-1)<<7,bitexact_log2tan(iside,imid));
894    }
895
896    sctx->inv = inv;
897    sctx->imid = imid;
898    sctx->iside = iside;
899    sctx->delta = delta;
900    sctx->itheta = itheta;
901    sctx->qalloc = qalloc;
902 }
903 static unsigned quant_band_n1(struct band_ctx *ctx, celt_norm *X, celt_norm *Y, int b,
904       celt_norm *lowband_out)
905 {
906    int c;
907    int stereo;
908    celt_norm *x = X;
909    int encode;
910    ec_ctx *ec;
911
912    encode = ctx->encode;
913    ec = ctx->ec;
914
915    stereo = Y != NULL;
916    c=0; do {
917       int sign=0;
918       if (ctx->remaining_bits>=1<<BITRES)
919       {
920          if (encode)
921          {
922             sign = x[0]<0;
923             ec_enc_bits(ec, sign, 1);
924          } else {
925             sign = ec_dec_bits(ec, 1);
926          }
927          ctx->remaining_bits -= 1<<BITRES;
928          b-=1<<BITRES;
929       }
930       if (ctx->resynth)
931          x[0] = sign ? -NORM_SCALING : NORM_SCALING;
932       x = Y;
933    } while (++c<1+stereo);
934    if (lowband_out)
935       lowband_out[0] = SHR16(X[0],4);
936    return 1;
937 }
938
939 /* This function is responsible for encoding and decoding a mono partition.
940    It can split the band in two and transmit the energy difference with
941    the two half-bands. It can be called recursively so bands can end up being
942    split in 8 parts. */
943 static unsigned quant_partition(struct band_ctx *ctx, celt_norm *X,
944       int N, int b, int B, celt_norm *lowband,
945       int LM,
946       opus_val16 gain, int fill)
947 {
948    const unsigned char *cache;
949    int q;
950    int curr_bits;
951    int imid=0, iside=0;
952    int B0=B;
953    opus_val16 mid=0, side=0;
954    unsigned cm=0;
955    celt_norm *Y=NULL;
956    int encode;
957    const CELTMode *m;
958    int i;
959    int spread;
960    ec_ctx *ec;
961
962    encode = ctx->encode;
963    m = ctx->m;
964    i = ctx->i;
965    spread = ctx->spread;
966    ec = ctx->ec;
967
968    /* If we need 1.5 more bit than we can produce, split the band in two. */
969    cache = m->cache.bits + m->cache.index[(LM+1)*m->nbEBands+i];
970    if (LM != -1 && b > cache[cache[0]]+12 && N>2)
971    {
972       int mbits, sbits, delta;
973       int itheta;
974       int qalloc;
975       struct split_ctx sctx;
976       celt_norm *next_lowband2=NULL;
977       opus_int32 rebalance;
978
979       N >>= 1;
980       Y = X+N;
981       LM -= 1;
982       if (B==1)
983          fill = (fill&1)|(fill<<1);
984       B = (B+1)>>1;
985
986       compute_theta(ctx, &sctx, X, Y, N, &b, B, B0, LM, 0, &fill);
987       imid = sctx.imid;
988       iside = sctx.iside;
989       delta = sctx.delta;
990       itheta = sctx.itheta;
991       qalloc = sctx.qalloc;
992 #ifdef FIXED_POINT
993       mid = imid;
994       side = iside;
995 #else
996       mid = (1.f/32768)*imid;
997       side = (1.f/32768)*iside;
998 #endif
999
1000       /* Give more bits to low-energy MDCTs than they would otherwise deserve */
1001       if (B0>1 && (itheta&0x3fff))
1002       {
1003          if (itheta > 8192)
1004             /* Rough approximation for pre-echo masking */
1005             delta -= delta>>(4-LM);
1006          else
1007             /* Corresponds to a forward-masking slope of 1.5 dB per 10 ms */
1008             delta = IMIN(0, delta + (N<<BITRES>>(5-LM)));
1009       }
1010       mbits = IMAX(0, IMIN(b, (b-delta)/2));
1011       sbits = b-mbits;
1012       ctx->remaining_bits -= qalloc;
1013
1014       if (lowband)
1015          next_lowband2 = lowband+N; /* >32-bit split case */
1016
1017       rebalance = ctx->remaining_bits;
1018       if (mbits >= sbits)
1019       {
1020          cm = quant_partition(ctx, X, N, mbits, B, lowband, LM,
1021                MULT16_16_P15(gain,mid), fill);
1022          rebalance = mbits - (rebalance-ctx->remaining_bits);
1023          if (rebalance > 3<<BITRES && itheta!=0)
1024             sbits += rebalance - (3<<BITRES);
1025          cm |= quant_partition(ctx, Y, N, sbits, B, next_lowband2, LM,
1026                MULT16_16_P15(gain,side), fill>>B)<<(B0>>1);
1027       } else {
1028          cm = quant_partition(ctx, Y, N, sbits, B, next_lowband2, LM,
1029                MULT16_16_P15(gain,side), fill>>B)<<(B0>>1);
1030          rebalance = sbits - (rebalance-ctx->remaining_bits);
1031          if (rebalance > 3<<BITRES && itheta!=16384)
1032             mbits += rebalance - (3<<BITRES);
1033          cm |= quant_partition(ctx, X, N, mbits, B, lowband, LM,
1034                MULT16_16_P15(gain,mid), fill);
1035       }
1036    } else {
1037       /* This is the basic no-split case */
1038       q = bits2pulses(m, i, LM, b);
1039       curr_bits = pulses2bits(m, i, LM, q);
1040       ctx->remaining_bits -= curr_bits;
1041
1042       /* Ensures we can never bust the budget */
1043       while (ctx->remaining_bits < 0 && q > 0)
1044       {
1045          ctx->remaining_bits += curr_bits;
1046          q--;
1047          curr_bits = pulses2bits(m, i, LM, q);
1048          ctx->remaining_bits -= curr_bits;
1049       }
1050
1051       if (q!=0)
1052       {
1053          int K = get_pulses(q);
1054
1055          /* Finally do the actual quantization */
1056          if (encode)
1057          {
1058             cm = alg_quant(X, N, K, spread, B, ec, gain, ctx->resynth, ctx->arch);
1059          } else {
1060             cm = alg_unquant(X, N, K, spread, B, ec, gain);
1061          }
1062       } else {
1063          /* If there's no pulse, fill the band anyway */
1064          int j;
1065          if (ctx->resynth)
1066          {
1067             unsigned cm_mask;
1068             /* B can be as large as 16, so this shift might overflow an int on a
1069                16-bit platform; use a long to get defined behavior.*/
1070             cm_mask = (unsigned)(1UL<<B)-1;
1071             fill &= cm_mask;
1072             if (!fill)
1073             {
1074                OPUS_CLEAR(X, N);
1075             } else {
1076                if (lowband == NULL)
1077                {
1078                   /* Noise */
1079                   for (j=0;j<N;j++)
1080                   {
1081                      ctx->seed = celt_lcg_rand(ctx->seed);
1082                      X[j] = (celt_norm)((opus_int32)ctx->seed>>20);
1083                   }
1084                   cm = cm_mask;
1085                } else {
1086                   /* Folded spectrum */
1087                   for (j=0;j<N;j++)
1088                   {
1089                      opus_val16 tmp;
1090                      ctx->seed = celt_lcg_rand(ctx->seed);
1091                      /* About 48 dB below the "normal" folding level */
1092                      tmp = QCONST16(1.0f/256, 10);
1093                      tmp = (ctx->seed)&0x8000 ? tmp : -tmp;
1094                      X[j] = lowband[j]+tmp;
1095                   }
1096                   cm = fill;
1097                }
1098                renormalise_vector(X, N, gain, ctx->arch);
1099             }
1100          }
1101       }
1102    }
1103
1104    return cm;
1105 }
1106
1107
1108 /* This function is responsible for encoding and decoding a band for the mono case. */
1109 static unsigned quant_band(struct band_ctx *ctx, celt_norm *X,
1110       int N, int b, int B, celt_norm *lowband,
1111       int LM, celt_norm *lowband_out,
1112       opus_val16 gain, celt_norm *lowband_scratch, int fill)
1113 {
1114    int N0=N;
1115    int N_B=N;
1116    int N_B0;
1117    int B0=B;
1118    int time_divide=0;
1119    int recombine=0;
1120    int longBlocks;
1121    unsigned cm=0;
1122    int k;
1123    int encode;
1124    int tf_change;
1125
1126    encode = ctx->encode;
1127    tf_change = ctx->tf_change;
1128
1129    longBlocks = B0==1;
1130
1131    N_B = celt_udiv(N_B, B);
1132
1133    /* Special case for one sample */
1134    if (N==1)
1135    {
1136       return quant_band_n1(ctx, X, NULL, b, lowband_out);
1137    }
1138
1139    if (tf_change>0)
1140       recombine = tf_change;
1141    /* Band recombining to increase frequency resolution */
1142
1143    if (lowband_scratch && lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
1144    {
1145       OPUS_COPY(lowband_scratch, lowband, N);
1146       lowband = lowband_scratch;
1147    }
1148
1149    for (k=0;k<recombine;k++)
1150    {
1151       static const unsigned char bit_interleave_table[16]={
1152             0,1,1,1,2,3,3,3,2,3,3,3,2,3,3,3
1153       };
1154       if (encode)
1155          haar1(X, N>>k, 1<<k);
1156       if (lowband)
1157          haar1(lowband, N>>k, 1<<k);
1158       fill = bit_interleave_table[fill&0xF]|bit_interleave_table[fill>>4]<<2;
1159    }
1160    B>>=recombine;
1161    N_B<<=recombine;
1162
1163    /* Increasing the time resolution */
1164    while ((N_B&1) == 0 && tf_change<0)
1165    {
1166       if (encode)
1167          haar1(X, N_B, B);
1168       if (lowband)
1169          haar1(lowband, N_B, B);
1170       fill |= fill<<B;
1171       B <<= 1;
1172       N_B >>= 1;
1173       time_divide++;
1174       tf_change++;
1175    }
1176    B0=B;
1177    N_B0 = N_B;
1178
1179    /* Reorganize the samples in time order instead of frequency order */
1180    if (B0>1)
1181    {
1182       if (encode)
1183          deinterleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
1184       if (lowband)
1185          deinterleave_hadamard(lowband, N_B>>recombine, B0<<recombine, longBlocks);
1186    }
1187
1188    cm = quant_partition(ctx, X, N, b, B, lowband, LM, gain, fill);
1189
1190    /* This code is used by the decoder and by the resynthesis-enabled encoder */
1191    if (ctx->resynth)
1192    {
1193       /* Undo the sample reorganization going from time order to frequency order */
1194       if (B0>1)
1195          interleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
1196
1197       /* Undo time-freq changes that we did earlier */
1198       N_B = N_B0;
1199       B = B0;
1200       for (k=0;k<time_divide;k++)
1201       {
1202          B >>= 1;
1203          N_B <<= 1;
1204          cm |= cm>>B;
1205          haar1(X, N_B, B);
1206       }
1207
1208       for (k=0;k<recombine;k++)
1209       {
1210          static const unsigned char bit_deinterleave_table[16]={
1211                0x00,0x03,0x0C,0x0F,0x30,0x33,0x3C,0x3F,
1212                0xC0,0xC3,0xCC,0xCF,0xF0,0xF3,0xFC,0xFF
1213          };
1214          cm = bit_deinterleave_table[cm];
1215          haar1(X, N0>>k, 1<<k);
1216       }
1217       B<<=recombine;
1218
1219       /* Scale output for later folding */
1220       if (lowband_out)
1221       {
1222          int j;
1223          opus_val16 n;
1224          n = celt_sqrt(SHL32(EXTEND32(N0),22));
1225          for (j=0;j<N0;j++)
1226             lowband_out[j] = MULT16_16_Q15(n,X[j]);
1227       }
1228       cm &= (1<<B)-1;
1229    }
1230    return cm;
1231 }
1232
1233
1234 /* This function is responsible for encoding and decoding a band for the stereo case. */
1235 static unsigned quant_band_stereo(struct band_ctx *ctx, celt_norm *X, celt_norm *Y,
1236       int N, int b, int B, celt_norm *lowband,
1237       int LM, celt_norm *lowband_out,
1238       celt_norm *lowband_scratch, int fill)
1239 {
1240    int imid=0, iside=0;
1241    int inv = 0;
1242    opus_val16 mid=0, side=0;
1243    unsigned cm=0;
1244    int mbits, sbits, delta;
1245    int itheta;
1246    int qalloc;
1247    struct split_ctx sctx;
1248    int orig_fill;
1249    int encode;
1250    ec_ctx *ec;
1251
1252    encode = ctx->encode;
1253    ec = ctx->ec;
1254
1255    /* Special case for one sample */
1256    if (N==1)
1257    {
1258       return quant_band_n1(ctx, X, Y, b, lowband_out);
1259    }
1260
1261    orig_fill = fill;
1262
1263    compute_theta(ctx, &sctx, X, Y, N, &b, B, B, LM, 1, &fill);
1264    inv = sctx.inv;
1265    imid = sctx.imid;
1266    iside = sctx.iside;
1267    delta = sctx.delta;
1268    itheta = sctx.itheta;
1269    qalloc = sctx.qalloc;
1270 #ifdef FIXED_POINT
1271    mid = imid;
1272    side = iside;
1273 #else
1274    mid = (1.f/32768)*imid;
1275    side = (1.f/32768)*iside;
1276 #endif
1277
1278    /* This is a special case for N=2 that only works for stereo and takes
1279       advantage of the fact that mid and side are orthogonal to encode
1280       the side with just one bit. */
1281    if (N==2)
1282    {
1283       int c;
1284       int sign=0;
1285       celt_norm *x2, *y2;
1286       mbits = b;
1287       sbits = 0;
1288       /* Only need one bit for the side. */
1289       if (itheta != 0 && itheta != 16384)
1290          sbits = 1<<BITRES;
1291       mbits -= sbits;
1292       c = itheta > 8192;
1293       ctx->remaining_bits -= qalloc+sbits;
1294
1295       x2 = c ? Y : X;
1296       y2 = c ? X : Y;
1297       if (sbits)
1298       {
1299          if (encode)
1300          {
1301             /* Here we only need to encode a sign for the side. */
1302             sign = x2[0]*y2[1] - x2[1]*y2[0] < 0;
1303             ec_enc_bits(ec, sign, 1);
1304          } else {
1305             sign = ec_dec_bits(ec, 1);
1306          }
1307       }
1308       sign = 1-2*sign;
1309       /* We use orig_fill here because we want to fold the side, but if
1310          itheta==16384, we'll have cleared the low bits of fill. */
1311       cm = quant_band(ctx, x2, N, mbits, B, lowband, LM, lowband_out, Q15ONE,
1312             lowband_scratch, orig_fill);
1313       /* We don't split N=2 bands, so cm is either 1 or 0 (for a fold-collapse),
1314          and there's no need to worry about mixing with the other channel. */
1315       y2[0] = -sign*x2[1];
1316       y2[1] = sign*x2[0];
1317       if (ctx->resynth)
1318       {
1319          celt_norm tmp;
1320          X[0] = MULT16_16_Q15(mid, X[0]);
1321          X[1] = MULT16_16_Q15(mid, X[1]);
1322          Y[0] = MULT16_16_Q15(side, Y[0]);
1323          Y[1] = MULT16_16_Q15(side, Y[1]);
1324          tmp = X[0];
1325          X[0] = SUB16(tmp,Y[0]);
1326          Y[0] = ADD16(tmp,Y[0]);
1327          tmp = X[1];
1328          X[1] = SUB16(tmp,Y[1]);
1329          Y[1] = ADD16(tmp,Y[1]);
1330       }
1331    } else {
1332       /* "Normal" split code */
1333       opus_int32 rebalance;
1334
1335       mbits = IMAX(0, IMIN(b, (b-delta)/2));
1336       sbits = b-mbits;
1337       ctx->remaining_bits -= qalloc;
1338
1339       rebalance = ctx->remaining_bits;
1340       if (mbits >= sbits)
1341       {
1342          /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
1343             mid for folding later. */
1344          cm = quant_band(ctx, X, N, mbits, B, lowband, LM, lowband_out, Q15ONE,
1345                lowband_scratch, fill);
1346          rebalance = mbits - (rebalance-ctx->remaining_bits);
1347          if (rebalance > 3<<BITRES && itheta!=0)
1348             sbits += rebalance - (3<<BITRES);
1349
1350          /* For a stereo split, the high bits of fill are always zero, so no
1351             folding will be done to the side. */
1352          cm |= quant_band(ctx, Y, N, sbits, B, NULL, LM, NULL, side, NULL, fill>>B);
1353       } else {
1354          /* For a stereo split, the high bits of fill are always zero, so no
1355             folding will be done to the side. */
1356          cm = quant_band(ctx, Y, N, sbits, B, NULL, LM, NULL, side, NULL, fill>>B);
1357          rebalance = sbits - (rebalance-ctx->remaining_bits);
1358          if (rebalance > 3<<BITRES && itheta!=16384)
1359             mbits += rebalance - (3<<BITRES);
1360          /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
1361             mid for folding later. */
1362          cm |= quant_band(ctx, X, N, mbits, B, lowband, LM, lowband_out, Q15ONE,
1363                lowband_scratch, fill);
1364       }
1365    }
1366
1367
1368    /* This code is used by the decoder and by the resynthesis-enabled encoder */
1369    if (ctx->resynth)
1370    {
1371       if (N!=2)
1372          stereo_merge(X, Y, mid, N, ctx->arch);
1373       if (inv)
1374       {
1375          int j;
1376          for (j=0;j<N;j++)
1377             Y[j] = -Y[j];
1378       }
1379    }
1380    return cm;
1381 }
1382
1383 static void special_hybrid_folding(const CELTMode *m, celt_norm *norm, celt_norm *norm2, int start, int M, int dual_stereo)
1384 {
1385    int n1, n2;
1386    const opus_int16 * OPUS_RESTRICT eBands = m->eBands;
1387    n1 = M*(eBands[start+1]-eBands[start]);
1388    n2 = M*(eBands[start+2]-eBands[start+1]);
1389    /* Duplicate enough of the first band folding data to be able to fold the second band.
1390       Copies no data for CELT-only mode. */
1391    OPUS_COPY(&norm[n1], &norm[2*n1 - n2], n2-n1);
1392    if (dual_stereo)
1393       OPUS_COPY(&norm2[n1], &norm2[2*n1 - n2], n2-n1);
1394 }
1395
1396 void quant_all_bands(int encode, const CELTMode *m, int start, int end,
1397       celt_norm *X_, celt_norm *Y_, unsigned char *collapse_masks,
1398       const celt_ener *bandE, int *pulses, int shortBlocks, int spread,
1399       int dual_stereo, int intensity, int *tf_res, opus_int32 total_bits,
1400       opus_int32 balance, ec_ctx *ec, int LM, int codedBands,
1401       opus_uint32 *seed, int complexity, int arch, int disable_inv)
1402 {
1403    int i;
1404    opus_int32 remaining_bits;
1405    const opus_int16 * OPUS_RESTRICT eBands = m->eBands;
1406    celt_norm * OPUS_RESTRICT norm, * OPUS_RESTRICT norm2;
1407    VARDECL(celt_norm, _norm);
1408    VARDECL(celt_norm, _lowband_scratch);
1409    VARDECL(celt_norm, X_save);
1410    VARDECL(celt_norm, Y_save);
1411    VARDECL(celt_norm, X_save2);
1412    VARDECL(celt_norm, Y_save2);
1413    VARDECL(celt_norm, norm_save2);
1414    int resynth_alloc;
1415    celt_norm *lowband_scratch;
1416    int B;
1417    int M;
1418    int lowband_offset;
1419    int update_lowband = 1;
1420    int C = Y_ != NULL ? 2 : 1;
1421    int norm_offset;
1422    int theta_rdo = encode && Y_!=NULL && !dual_stereo && complexity>=8;
1423 #ifdef RESYNTH
1424    int resynth = 1;
1425 #else
1426    int resynth = !encode || theta_rdo;
1427 #endif
1428    struct band_ctx ctx;
1429    SAVE_STACK;
1430
1431    M = 1<<LM;
1432    B = shortBlocks ? M : 1;
1433    norm_offset = M*eBands[start];
1434    /* No need to allocate norm for the last band because we don't need an
1435       output in that band. */
1436    ALLOC(_norm, C*(M*eBands[m->nbEBands-1]-norm_offset), celt_norm);
1437    norm = _norm;
1438    norm2 = norm + M*eBands[m->nbEBands-1]-norm_offset;
1439
1440    /* For decoding, we can use the last band as scratch space because we don't need that
1441       scratch space for the last band and we don't care about the data there until we're
1442       decoding the last band. */
1443    if (encode && resynth)
1444       resynth_alloc = M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]);
1445    else
1446       resynth_alloc = ALLOC_NONE;
1447    ALLOC(_lowband_scratch, resynth_alloc, celt_norm);
1448    if (encode && resynth)
1449       lowband_scratch = _lowband_scratch;
1450    else
1451       lowband_scratch = X_+M*eBands[m->nbEBands-1];
1452    ALLOC(X_save, resynth_alloc, celt_norm);
1453    ALLOC(Y_save, resynth_alloc, celt_norm);
1454    ALLOC(X_save2, resynth_alloc, celt_norm);
1455    ALLOC(Y_save2, resynth_alloc, celt_norm);
1456    ALLOC(norm_save2, resynth_alloc, celt_norm);
1457
1458    lowband_offset = 0;
1459    ctx.bandE = bandE;
1460    ctx.ec = ec;
1461    ctx.encode = encode;
1462    ctx.intensity = intensity;
1463    ctx.m = m;
1464    ctx.seed = *seed;
1465    ctx.spread = spread;
1466    ctx.arch = arch;
1467    ctx.disable_inv = disable_inv;
1468    ctx.resynth = resynth;
1469    ctx.theta_round = 0;
1470    /* Avoid injecting noise in the first band on transients. */
1471    ctx.avoid_split_noise = B > 1;
1472    for (i=start;i<end;i++)
1473    {
1474       opus_int32 tell;
1475       int b;
1476       int N;
1477       opus_int32 curr_balance;
1478       int effective_lowband=-1;
1479       celt_norm * OPUS_RESTRICT X, * OPUS_RESTRICT Y;
1480       int tf_change=0;
1481       unsigned x_cm;
1482       unsigned y_cm;
1483       int last;
1484
1485       ctx.i = i;
1486       last = (i==end-1);
1487
1488       X = X_+M*eBands[i];
1489       if (Y_!=NULL)
1490          Y = Y_+M*eBands[i];
1491       else
1492          Y = NULL;
1493       N = M*eBands[i+1]-M*eBands[i];
1494       tell = ec_tell_frac(ec);
1495
1496       /* Compute how many bits we want to allocate to this band */
1497       if (i != start)
1498          balance -= tell;
1499       remaining_bits = total_bits-tell-1;
1500       ctx.remaining_bits = remaining_bits;
1501       if (i <= codedBands-1)
1502       {
1503          curr_balance = celt_sudiv(balance, IMIN(3, codedBands-i));
1504          b = IMAX(0, IMIN(16383, IMIN(remaining_bits+1,pulses[i]+curr_balance)));
1505       } else {
1506          b = 0;
1507       }
1508
1509 #ifdef ENABLE_UPDATE_DRAFT
1510       if (resynth && (M*eBands[i]-N >= M*eBands[start] || i==start+1) && (update_lowband || lowband_offset==0))
1511             lowband_offset = i;
1512       if (i == start+1)
1513          special_hybrid_folding(m, norm, norm2, start, M, dual_stereo);
1514 #else
1515       if (resynth && M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband_offset==0))
1516             lowband_offset = i;
1517 #endif
1518
1519       tf_change = tf_res[i];
1520       ctx.tf_change = tf_change;
1521       if (i>=m->effEBands)
1522       {
1523          X=norm;
1524          if (Y_!=NULL)
1525             Y = norm;
1526          lowband_scratch = NULL;
1527       }
1528       if (last && !theta_rdo)
1529          lowband_scratch = NULL;
1530
1531       /* Get a conservative estimate of the collapse_mask's for the bands we're
1532          going to be folding from. */
1533       if (lowband_offset != 0 && (spread!=SPREAD_AGGRESSIVE || B>1 || tf_change<0))
1534       {
1535          int fold_start;
1536          int fold_end;
1537          int fold_i;
1538          /* This ensures we never repeat spectral content within one band */
1539          effective_lowband = IMAX(0, M*eBands[lowband_offset]-norm_offset-N);
1540          fold_start = lowband_offset;
1541          while(M*eBands[--fold_start] > effective_lowband+norm_offset);
1542          fold_end = lowband_offset-1;
1543 #ifdef ENABLE_UPDATE_DRAFT
1544          while(++fold_end < i && M*eBands[fold_end] < effective_lowband+norm_offset+N);
1545 #else
1546          while(M*eBands[++fold_end] < effective_lowband+norm_offset+N);
1547 #endif
1548          x_cm = y_cm = 0;
1549          fold_i = fold_start; do {
1550            x_cm |= collapse_masks[fold_i*C+0];
1551            y_cm |= collapse_masks[fold_i*C+C-1];
1552          } while (++fold_i<fold_end);
1553       }
1554       /* Otherwise, we'll be using the LCG to fold, so all blocks will (almost
1555          always) be non-zero. */
1556       else
1557          x_cm = y_cm = (1<<B)-1;
1558
1559       if (dual_stereo && i==intensity)
1560       {
1561          int j;
1562
1563          /* Switch off dual stereo to do intensity. */
1564          dual_stereo = 0;
1565          if (resynth)
1566             for (j=0;j<M*eBands[i]-norm_offset;j++)
1567                norm[j] = HALF32(norm[j]+norm2[j]);
1568       }
1569       if (dual_stereo)
1570       {
1571          x_cm = quant_band(&ctx, X, N, b/2, B,
1572                effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1573                last?NULL:norm+M*eBands[i]-norm_offset, Q15ONE, lowband_scratch, x_cm);
1574          y_cm = quant_band(&ctx, Y, N, b/2, B,
1575                effective_lowband != -1 ? norm2+effective_lowband : NULL, LM,
1576                last?NULL:norm2+M*eBands[i]-norm_offset, Q15ONE, lowband_scratch, y_cm);
1577       } else {
1578          if (Y!=NULL)
1579          {
1580             if (theta_rdo && i < intensity)
1581             {
1582                ec_ctx ec_save, ec_save2;
1583                struct band_ctx ctx_save, ctx_save2;
1584                opus_val32 dist0, dist1;
1585                unsigned cm, cm2;
1586                int nstart_bytes, nend_bytes, save_bytes;
1587                unsigned char *bytes_buf;
1588                unsigned char bytes_save[1275];
1589                opus_val16 w[2];
1590                compute_channel_weights(bandE[i], bandE[i+m->nbEBands], w);
1591                /* Make a copy. */
1592                cm = x_cm|y_cm;
1593                ec_save = *ec;
1594                ctx_save = ctx;
1595                OPUS_COPY(X_save, X, N);
1596                OPUS_COPY(Y_save, Y, N);
1597                /* Encode and round down. */
1598                ctx.theta_round = -1;
1599                x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
1600                      effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1601                      last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
1602                dist0 = MULT16_32_Q15(w[0], celt_inner_prod(X_save, X, N, arch)) + MULT16_32_Q15(w[1], celt_inner_prod(Y_save, Y, N, arch));
1603
1604                /* Save first result. */
1605                cm2 = x_cm;
1606                ec_save2 = *ec;
1607                ctx_save2 = ctx;
1608                OPUS_COPY(X_save2, X, N);
1609                OPUS_COPY(Y_save2, Y, N);
1610                if (!last)
1611                   OPUS_COPY(norm_save2, norm+M*eBands[i]-norm_offset, N);
1612                nstart_bytes = ec_save.offs;
1613                nend_bytes = ec_save.storage;
1614                bytes_buf = ec_save.buf+nstart_bytes;
1615                save_bytes = nend_bytes-nstart_bytes;
1616                OPUS_COPY(bytes_save, bytes_buf, save_bytes);
1617
1618                /* Restore */
1619                *ec = ec_save;
1620                ctx = ctx_save;
1621                OPUS_COPY(X, X_save, N);
1622                OPUS_COPY(Y, Y_save, N);
1623                if (i == start+1)
1624                   special_hybrid_folding(m, norm, norm2, start, M, dual_stereo);
1625                /* Encode and round up. */
1626                ctx.theta_round = 1;
1627                x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
1628                      effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1629                      last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
1630                dist1 = MULT16_32_Q15(w[0], celt_inner_prod(X_save, X, N, arch)) + MULT16_32_Q15(w[1], celt_inner_prod(Y_save, Y, N, arch));
1631                if (dist0 >= dist1) {
1632                   x_cm = cm2;
1633                   *ec = ec_save2;
1634                   ctx = ctx_save2;
1635                   OPUS_COPY(X, X_save2, N);
1636                   OPUS_COPY(Y, Y_save2, N);
1637                   if (!last)
1638                      OPUS_COPY(norm+M*eBands[i]-norm_offset, norm_save2, N);
1639                   OPUS_COPY(bytes_buf, bytes_save, save_bytes);
1640                }
1641             } else {
1642                ctx.theta_round = 0;
1643                x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
1644                      effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1645                      last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, x_cm|y_cm);
1646             }
1647          } else {
1648             x_cm = quant_band(&ctx, X, N, b, B,
1649                   effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1650                   last?NULL:norm+M*eBands[i]-norm_offset, Q15ONE, lowband_scratch, x_cm|y_cm);
1651          }
1652          y_cm = x_cm;
1653       }
1654       collapse_masks[i*C+0] = (unsigned char)x_cm;
1655       collapse_masks[i*C+C-1] = (unsigned char)y_cm;
1656       balance += pulses[i] + tell;
1657
1658       /* Update the folding position only as long as we have 1 bit/sample depth. */
1659       update_lowband = b>(N<<BITRES);
1660       /* We only need to avoid noise on a split for the first band. After that, we
1661          have folding. */
1662       ctx.avoid_split_noise = 0;
1663    }
1664    *seed = ctx.seed;
1665
1666    RESTORE_STACK;
1667 }
1668