05205554fe7658cde5b8491d3ddbffdc731c744c
[opus.git] / celt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16
17    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
21    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
26    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29
30 #ifdef HAVE_CONFIG_H
31 #include "config.h"
32 #endif
33
34 #include <math.h>
35 #include "bands.h"
36 #include "modes.h"
37 #include "vq.h"
38 #include "cwrs.h"
39 #include "stack_alloc.h"
40 #include "os_support.h"
41 #include "mathops.h"
42 #include "rate.h"
43 #include "quant_bands.h"
44 #include "pitch.h"
45
46 int hysteresis_decision(opus_val16 val, const opus_val16 *thresholds, const opus_val16 *hysteresis, int N, int prev)
47 {
48    int i;
49    for (i=0;i<N;i++)
50    {
51       if (val < thresholds[i])
52          break;
53    }
54    if (i>prev && val < thresholds[prev]+hysteresis[prev])
55       i=prev;
56    if (i<prev && val > thresholds[prev-1]-hysteresis[prev-1])
57       i=prev;
58    return i;
59 }
60
61 opus_uint32 celt_lcg_rand(opus_uint32 seed)
62 {
63    return 1664525 * seed + 1013904223;
64 }
65
66 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
67    with this approximation is important because it has an impact on the bit allocation */
68 static opus_int16 bitexact_cos(opus_int16 x)
69 {
70    opus_int32 tmp;
71    opus_int16 x2;
72    tmp = (4096+((opus_int32)(x)*(x)))>>13;
73    celt_assert(tmp<=32767);
74    x2 = tmp;
75    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
76    celt_assert(x2<=32766);
77    return 1+x2;
78 }
79
80 static int bitexact_log2tan(int isin,int icos)
81 {
82    int lc;
83    int ls;
84    lc=EC_ILOG(icos);
85    ls=EC_ILOG(isin);
86    icos<<=15-lc;
87    isin<<=15-ls;
88    return (ls-lc)*(1<<11)
89          +FRAC_MUL16(isin, FRAC_MUL16(isin, -2597) + 7932)
90          -FRAC_MUL16(icos, FRAC_MUL16(icos, -2597) + 7932);
91 }
92
93 #ifdef FIXED_POINT
94 /* Compute the amplitude (sqrt energy) in each of the bands */
95 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bandE, int end, int C, int LM)
96 {
97    int i, c, N;
98    const opus_int16 *eBands = m->eBands;
99    N = m->shortMdctSize<<LM;
100    c=0; do {
101       for (i=0;i<end;i++)
102       {
103          int j;
104          opus_val32 maxval=0;
105          opus_val32 sum = 0;
106
107          maxval = celt_maxabs32(&X[c*N+(eBands[i]<<LM)], (eBands[i+1]-eBands[i])<<LM);
108          if (maxval > 0)
109          {
110             int shift = celt_ilog2(maxval) - 14 + (((m->logN[i]>>BITRES)+LM+1)>>1);
111             j=eBands[i]<<LM;
112             if (shift>0)
113             {
114                do {
115                   sum = MAC16_16(sum, EXTRACT16(SHR32(X[j+c*N],shift)),
116                         EXTRACT16(SHR32(X[j+c*N],shift)));
117                } while (++j<eBands[i+1]<<LM);
118             } else {
119                do {
120                   sum = MAC16_16(sum, EXTRACT16(SHL32(X[j+c*N],-shift)),
121                         EXTRACT16(SHL32(X[j+c*N],-shift)));
122                } while (++j<eBands[i+1]<<LM);
123             }
124             /* We're adding one here to ensure the normalized band isn't larger than unity norm */
125             bandE[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
126          } else {
127             bandE[i+c*m->nbEBands] = EPSILON;
128          }
129          /*printf ("%f ", bandE[i+c*m->nbEBands]);*/
130       }
131    } while (++c<C);
132    /*printf ("\n");*/
133 }
134
135 /* Normalise each band such that the energy is one. */
136 void normalise_bands(const CELTMode *m, const celt_sig * OPUS_RESTRICT freq, celt_norm * OPUS_RESTRICT X, const celt_ener *bandE, int end, int C, int M)
137 {
138    int i, c, N;
139    const opus_int16 *eBands = m->eBands;
140    N = M*m->shortMdctSize;
141    c=0; do {
142       i=0; do {
143          opus_val16 g;
144          int j,shift;
145          opus_val16 E;
146          shift = celt_zlog2(bandE[i+c*m->nbEBands])-13;
147          E = VSHR32(bandE[i+c*m->nbEBands], shift);
148          g = EXTRACT16(celt_rcp(SHL32(E,3)));
149          j=M*eBands[i]; do {
150             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
151          } while (++j<M*eBands[i+1]);
152       } while (++i<end);
153    } while (++c<C);
154 }
155
156 #else /* FIXED_POINT */
157 /* Compute the amplitude (sqrt energy) in each of the bands */
158 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bandE, int end, int C, int LM)
159 {
160    int i, c, N;
161    const opus_int16 *eBands = m->eBands;
162    N = m->shortMdctSize<<LM;
163    c=0; do {
164       for (i=0;i<end;i++)
165       {
166          opus_val32 sum;
167          sum = 1e-27f + celt_inner_prod_c(&X[c*N+(eBands[i]<<LM)], &X[c*N+(eBands[i]<<LM)], (eBands[i+1]-eBands[i])<<LM);
168          bandE[i+c*m->nbEBands] = celt_sqrt(sum);
169          /*printf ("%f ", bandE[i+c*m->nbEBands]);*/
170       }
171    } while (++c<C);
172    /*printf ("\n");*/
173 }
174
175 /* Normalise each band such that the energy is one. */
176 void normalise_bands(const CELTMode *m, const celt_sig * OPUS_RESTRICT freq, celt_norm * OPUS_RESTRICT X, const celt_ener *bandE, int end, int C, int M)
177 {
178    int i, c, N;
179    const opus_int16 *eBands = m->eBands;
180    N = M*m->shortMdctSize;
181    c=0; do {
182       for (i=0;i<end;i++)
183       {
184          int j;
185          opus_val16 g = 1.f/(1e-27f+bandE[i+c*m->nbEBands]);
186          for (j=M*eBands[i];j<M*eBands[i+1];j++)
187             X[j+c*N] = freq[j+c*N]*g;
188       }
189    } while (++c<C);
190 }
191
192 #endif /* FIXED_POINT */
193
194 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
195 void denormalise_bands(const CELTMode *m, const celt_norm * OPUS_RESTRICT X,
196       celt_sig * OPUS_RESTRICT freq, const opus_val16 *bandLogE, int start,
197       int end, int M, int downsample, int silence)
198 {
199    int i, N;
200    int bound;
201    celt_sig * OPUS_RESTRICT f;
202    const celt_norm * OPUS_RESTRICT x;
203    const opus_int16 *eBands = m->eBands;
204    N = M*m->shortMdctSize;
205    bound = M*eBands[end];
206    if (downsample!=1)
207       bound = IMIN(bound, N/downsample);
208    if (silence)
209    {
210       bound = 0;
211       start = end = 0;
212    }
213    f = freq;
214    x = X+M*eBands[start];
215    for (i=0;i<M*eBands[start];i++)
216       *f++ = 0;
217    for (i=start;i<end;i++)
218    {
219       int j, band_end;
220       opus_val16 g;
221       opus_val16 lg;
222 #ifdef FIXED_POINT
223       int shift;
224 #endif
225       j=M*eBands[i];
226       band_end = M*eBands[i+1];
227       lg = SATURATE16(ADD32(bandLogE[i], SHL32((opus_val32)eMeans[i],6)));
228 #ifndef FIXED_POINT
229       g = celt_exp2(MIN32(32.f, lg));
230 #else
231       /* Handle the integer part of the log energy */
232       shift = 16-(lg>>DB_SHIFT);
233       if (shift>31)
234       {
235          shift=0;
236          g=0;
237       } else {
238          /* Handle the fractional part. */
239          g = celt_exp2_frac(lg&((1<<DB_SHIFT)-1));
240       }
241       /* Handle extreme gains with negative shift. */
242       if (shift<0)
243       {
244          /* For shift <= -2 and g > 16384 we'd be likely to overflow, so we're
245             capping the gain here, which is equivalent to a cap of 18 on lg.
246             This shouldn't trigger unless the bitstream is already corrupted. */
247          if (shift <= -2)
248          {
249             g = 16384;
250             shift = -2;
251          }
252          do {
253             *f++ = SHL32(MULT16_16(*x++, g), -shift);
254          } while (++j<band_end);
255       } else
256 #endif
257          /* Be careful of the fixed-point "else" just above when changing this code */
258          do {
259             *f++ = SHR32(MULT16_16(*x++, g), shift);
260          } while (++j<band_end);
261    }
262    celt_assert(start <= end);
263    OPUS_CLEAR(&freq[bound], N-bound);
264 }
265
266 /* This prevents energy collapse for transients with multiple short MDCTs */
267 void anti_collapse(const CELTMode *m, celt_norm *X_, unsigned char *collapse_masks, int LM, int C, int size,
268       int start, int end, const opus_val16 *logE, const opus_val16 *prev1logE,
269       const opus_val16 *prev2logE, const int *pulses, opus_uint32 seed, int arch)
270 {
271    int c, i, j, k;
272    for (i=start;i<end;i++)
273    {
274       int N0;
275       opus_val16 thresh, sqrt_1;
276       int depth;
277 #ifdef FIXED_POINT
278       int shift;
279       opus_val32 thresh32;
280 #endif
281
282       N0 = m->eBands[i+1]-m->eBands[i];
283       /* depth in 1/8 bits */
284       celt_assert(pulses[i]>=0);
285       depth = celt_udiv(1+pulses[i], (m->eBands[i+1]-m->eBands[i]))>>LM;
286
287 #ifdef FIXED_POINT
288       thresh32 = SHR32(celt_exp2(-SHL16(depth, 10-BITRES)),1);
289       thresh = MULT16_32_Q15(QCONST16(0.5f, 15), MIN32(32767,thresh32));
290       {
291          opus_val32 t;
292          t = N0<<LM;
293          shift = celt_ilog2(t)>>1;
294          t = SHL32(t, (7-shift)<<1);
295          sqrt_1 = celt_rsqrt_norm(t);
296       }
297 #else
298       thresh = .5f*celt_exp2(-.125f*depth);
299       sqrt_1 = celt_rsqrt(N0<<LM);
300 #endif
301
302       c=0; do
303       {
304          celt_norm *X;
305          opus_val16 prev1;
306          opus_val16 prev2;
307          opus_val32 Ediff;
308          opus_val16 r;
309          int renormalize=0;
310          prev1 = prev1logE[c*m->nbEBands+i];
311          prev2 = prev2logE[c*m->nbEBands+i];
312          if (C==1)
313          {
314             prev1 = MAX16(prev1,prev1logE[m->nbEBands+i]);
315             prev2 = MAX16(prev2,prev2logE[m->nbEBands+i]);
316          }
317          Ediff = EXTEND32(logE[c*m->nbEBands+i])-EXTEND32(MIN16(prev1,prev2));
318          Ediff = MAX32(0, Ediff);
319
320 #ifdef FIXED_POINT
321          if (Ediff < 16384)
322          {
323             opus_val32 r32 = SHR32(celt_exp2(-EXTRACT16(Ediff)),1);
324             r = 2*MIN16(16383,r32);
325          } else {
326             r = 0;
327          }
328          if (LM==3)
329             r = MULT16_16_Q14(23170, MIN32(23169, r));
330          r = SHR16(MIN16(thresh, r),1);
331          r = SHR32(MULT16_16_Q15(sqrt_1, r),shift);
332 #else
333          /* r needs to be multiplied by 2 or 2*sqrt(2) depending on LM because
334             short blocks don't have the same energy as long */
335          r = 2.f*celt_exp2(-Ediff);
336          if (LM==3)
337             r *= 1.41421356f;
338          r = MIN16(thresh, r);
339          r = r*sqrt_1;
340 #endif
341          X = X_+c*size+(m->eBands[i]<<LM);
342          for (k=0;k<1<<LM;k++)
343          {
344             /* Detect collapse */
345             if (!(collapse_masks[i*C+c]&1<<k))
346             {
347                /* Fill with noise */
348                for (j=0;j<N0;j++)
349                {
350                   seed = celt_lcg_rand(seed);
351                   X[(j<<LM)+k] = (seed&0x8000 ? r : -r);
352                }
353                renormalize = 1;
354             }
355          }
356          /* We just added some energy, so we need to renormalise */
357          if (renormalize)
358             renormalise_vector(X, N0<<LM, Q15ONE, arch);
359       } while (++c<C);
360    }
361 }
362
363 /* Compute the weights to use for optimizing normalized distortion across
364    channels. We use the amplitude to weight square distortion, which means
365    that we use the square root of the value we would have been using if we
366    wanted to minimize the MSE in the non-normalized domain. This roughly
367    corresponds to some quick-and-dirty perceptual experiments I ran to
368    measure inter-aural masking (there doesn't seem to be any published data
369    on the topic). */
370 static void compute_channel_weights(celt_ener Ex, celt_ener Ey, opus_val16 w[2])
371 {
372    celt_ener minE;
373 #if FIXED_POINT
374    int shift;
375 #endif
376    minE = MIN32(Ex, Ey);
377    /* Adjustment to make the weights a bit more conservative. */
378    Ex = ADD32(Ex, minE/3);
379    Ey = ADD32(Ey, minE/3);
380 #if FIXED_POINT
381    shift = celt_ilog2(EPSILON+MAX32(Ex, Ey))-14;
382 #endif
383    w[0] = VSHR32(Ex, shift);
384    w[1] = VSHR32(Ey, shift);
385 }
386
387 static void intensity_stereo(const CELTMode *m, celt_norm * OPUS_RESTRICT X, const celt_norm * OPUS_RESTRICT Y, const celt_ener *bandE, int bandID, int N)
388 {
389    int i = bandID;
390    int j;
391    opus_val16 a1, a2;
392    opus_val16 left, right;
393    opus_val16 norm;
394 #ifdef FIXED_POINT
395    int shift = celt_zlog2(MAX32(bandE[i], bandE[i+m->nbEBands]))-13;
396 #endif
397    left = VSHR32(bandE[i],shift);
398    right = VSHR32(bandE[i+m->nbEBands],shift);
399    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
400    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
401    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
402    for (j=0;j<N;j++)
403    {
404       celt_norm r, l;
405       l = X[j];
406       r = Y[j];
407       X[j] = EXTRACT16(SHR32(MAC16_16(MULT16_16(a1, l), a2, r), 14));
408       /* Side is not encoded, no need to calculate */
409    }
410 }
411
412 static void stereo_split(celt_norm * OPUS_RESTRICT X, celt_norm * OPUS_RESTRICT Y, int N)
413 {
414    int j;
415    for (j=0;j<N;j++)
416    {
417       opus_val32 r, l;
418       l = MULT16_16(QCONST16(.70710678f, 15), X[j]);
419       r = MULT16_16(QCONST16(.70710678f, 15), Y[j]);
420       X[j] = EXTRACT16(SHR32(ADD32(l, r), 15));
421       Y[j] = EXTRACT16(SHR32(SUB32(r, l), 15));
422    }
423 }
424
425 static void stereo_merge(celt_norm * OPUS_RESTRICT X, celt_norm * OPUS_RESTRICT Y, opus_val16 mid, int N, int arch)
426 {
427    int j;
428    opus_val32 xp=0, side=0;
429    opus_val32 El, Er;
430    opus_val16 mid2;
431 #ifdef FIXED_POINT
432    int kl, kr;
433 #endif
434    opus_val32 t, lgain, rgain;
435
436    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
437    dual_inner_prod(Y, X, Y, N, &xp, &side, arch);
438    /* Compensating for the mid normalization */
439    xp = MULT16_32_Q15(mid, xp);
440    /* mid and side are in Q15, not Q14 like X and Y */
441    mid2 = SHR16(mid, 1);
442    El = MULT16_16(mid2, mid2) + side - 2*xp;
443    Er = MULT16_16(mid2, mid2) + side + 2*xp;
444    if (Er < QCONST32(6e-4f, 28) || El < QCONST32(6e-4f, 28))
445    {
446       OPUS_COPY(Y, X, N);
447       return;
448    }
449
450 #ifdef FIXED_POINT
451    kl = celt_ilog2(El)>>1;
452    kr = celt_ilog2(Er)>>1;
453 #endif
454    t = VSHR32(El, (kl-7)<<1);
455    lgain = celt_rsqrt_norm(t);
456    t = VSHR32(Er, (kr-7)<<1);
457    rgain = celt_rsqrt_norm(t);
458
459 #ifdef FIXED_POINT
460    if (kl < 7)
461       kl = 7;
462    if (kr < 7)
463       kr = 7;
464 #endif
465
466    for (j=0;j<N;j++)
467    {
468       celt_norm r, l;
469       /* Apply mid scaling (side is already scaled) */
470       l = MULT16_16_P15(mid, X[j]);
471       r = Y[j];
472       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, SUB16(l,r)), kl+1));
473       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, ADD16(l,r)), kr+1));
474    }
475 }
476
477 /* Decide whether we should spread the pulses in the current frame */
478 int spreading_decision(const CELTMode *m, const celt_norm *X, int *average,
479       int last_decision, int *hf_average, int *tapset_decision, int update_hf,
480       int end, int C, int M)
481 {
482    int i, c, N0;
483    int sum = 0, nbBands=0;
484    const opus_int16 * OPUS_RESTRICT eBands = m->eBands;
485    int decision;
486    int hf_sum=0;
487
488    celt_assert(end>0);
489
490    N0 = M*m->shortMdctSize;
491
492    if (M*(eBands[end]-eBands[end-1]) <= 8)
493       return SPREAD_NONE;
494    c=0; do {
495       for (i=0;i<end;i++)
496       {
497          int j, N, tmp=0;
498          int tcount[3] = {0,0,0};
499          const celt_norm * OPUS_RESTRICT x = X+M*eBands[i]+c*N0;
500          N = M*(eBands[i+1]-eBands[i]);
501          if (N<=8)
502             continue;
503          /* Compute rough CDF of |x[j]| */
504          for (j=0;j<N;j++)
505          {
506             opus_val32 x2N; /* Q13 */
507
508             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
509             if (x2N < QCONST16(0.25f,13))
510                tcount[0]++;
511             if (x2N < QCONST16(0.0625f,13))
512                tcount[1]++;
513             if (x2N < QCONST16(0.015625f,13))
514                tcount[2]++;
515          }
516
517          /* Only include four last bands (8 kHz and up) */
518          if (i>m->nbEBands-4)
519             hf_sum += celt_udiv(32*(tcount[1]+tcount[0]), N);
520          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
521          sum += tmp*256;
522          nbBands++;
523       }
524    } while (++c<C);
525
526    if (update_hf)
527    {
528       if (hf_sum)
529          hf_sum = celt_udiv(hf_sum, C*(4-m->nbEBands+end));
530       *hf_average = (*hf_average+hf_sum)>>1;
531       hf_sum = *hf_average;
532       if (*tapset_decision==2)
533          hf_sum += 4;
534       else if (*tapset_decision==0)
535          hf_sum -= 4;
536       if (hf_sum > 22)
537          *tapset_decision=2;
538       else if (hf_sum > 18)
539          *tapset_decision=1;
540       else
541          *tapset_decision=0;
542    }
543    /*printf("%d %d %d\n", hf_sum, *hf_average, *tapset_decision);*/
544    celt_assert(nbBands>0); /* end has to be non-zero */
545    celt_assert(sum>=0);
546    sum = celt_udiv(sum, nbBands);
547    /* Recursive averaging */
548    sum = (sum+*average)>>1;
549    *average = sum;
550    /* Hysteresis */
551    sum = (3*sum + (((3-last_decision)<<7) + 64) + 2)>>2;
552    if (sum < 80)
553    {
554       decision = SPREAD_AGGRESSIVE;
555    } else if (sum < 256)
556    {
557       decision = SPREAD_NORMAL;
558    } else if (sum < 384)
559    {
560       decision = SPREAD_LIGHT;
561    } else {
562       decision = SPREAD_NONE;
563    }
564 #ifdef FUZZING
565    decision = rand()&0x3;
566    *tapset_decision=rand()%3;
567 #endif
568    return decision;
569 }
570
571 /* Indexing table for converting from natural Hadamard to ordery Hadamard
572    This is essentially a bit-reversed Gray, on top of which we've added
573    an inversion of the order because we want the DC at the end rather than
574    the beginning. The lines are for N=2, 4, 8, 16 */
575 static const int ordery_table[] = {
576        1,  0,
577        3,  0,  2,  1,
578        7,  0,  4,  3,  6,  1,  5,  2,
579       15,  0,  8,  7, 12,  3, 11,  4, 14,  1,  9,  6, 13,  2, 10,  5,
580 };
581
582 static void deinterleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
583 {
584    int i,j;
585    VARDECL(celt_norm, tmp);
586    int N;
587    SAVE_STACK;
588    N = N0*stride;
589    ALLOC(tmp, N, celt_norm);
590    celt_assert(stride>0);
591    if (hadamard)
592    {
593       const int *ordery = ordery_table+stride-2;
594       for (i=0;i<stride;i++)
595       {
596          for (j=0;j<N0;j++)
597             tmp[ordery[i]*N0+j] = X[j*stride+i];
598       }
599    } else {
600       for (i=0;i<stride;i++)
601          for (j=0;j<N0;j++)
602             tmp[i*N0+j] = X[j*stride+i];
603    }
604    OPUS_COPY(X, tmp, N);
605    RESTORE_STACK;
606 }
607
608 static void interleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
609 {
610    int i,j;
611    VARDECL(celt_norm, tmp);
612    int N;
613    SAVE_STACK;
614    N = N0*stride;
615    ALLOC(tmp, N, celt_norm);
616    if (hadamard)
617    {
618       const int *ordery = ordery_table+stride-2;
619       for (i=0;i<stride;i++)
620          for (j=0;j<N0;j++)
621             tmp[j*stride+i] = X[ordery[i]*N0+j];
622    } else {
623       for (i=0;i<stride;i++)
624          for (j=0;j<N0;j++)
625             tmp[j*stride+i] = X[i*N0+j];
626    }
627    OPUS_COPY(X, tmp, N);
628    RESTORE_STACK;
629 }
630
631 void haar1(celt_norm *X, int N0, int stride)
632 {
633    int i, j;
634    N0 >>= 1;
635    for (i=0;i<stride;i++)
636       for (j=0;j<N0;j++)
637       {
638          opus_val32 tmp1, tmp2;
639          tmp1 = MULT16_16(QCONST16(.70710678f,15), X[stride*2*j+i]);
640          tmp2 = MULT16_16(QCONST16(.70710678f,15), X[stride*(2*j+1)+i]);
641          X[stride*2*j+i] = EXTRACT16(PSHR32(ADD32(tmp1, tmp2), 15));
642          X[stride*(2*j+1)+i] = EXTRACT16(PSHR32(SUB32(tmp1, tmp2), 15));
643       }
644 }
645
646 static int compute_qn(int N, int b, int offset, int pulse_cap, int stereo)
647 {
648    static const opus_int16 exp2_table8[8] =
649       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
650    int qn, qb;
651    int N2 = 2*N-1;
652    if (stereo && N==2)
653       N2--;
654    /* The upper limit ensures that in a stereo split with itheta==16384, we'll
655        always have enough bits left over to code at least one pulse in the
656        side; otherwise it would collapse, since it doesn't get folded. */
657    qb = celt_sudiv(b+N2*offset, N2);
658    qb = IMIN(b-pulse_cap-(4<<BITRES), qb);
659
660    qb = IMIN(8<<BITRES, qb);
661
662    if (qb<(1<<BITRES>>1)) {
663       qn = 1;
664    } else {
665       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
666       qn = (qn+1)>>1<<1;
667    }
668    celt_assert(qn <= 256);
669    return qn;
670 }
671
672 struct band_ctx {
673    int encode;
674    int resynth;
675    const CELTMode *m;
676    int i;
677    int intensity;
678    int spread;
679    int tf_change;
680    ec_ctx *ec;
681    opus_int32 remaining_bits;
682    const celt_ener *bandE;
683    opus_uint32 seed;
684    int arch;
685    int theta_round;
686    int disable_inv;
687 };
688
689 struct split_ctx {
690    int inv;
691    int imid;
692    int iside;
693    int delta;
694    int itheta;
695    int qalloc;
696 };
697
698 static void compute_theta(struct band_ctx *ctx, struct split_ctx *sctx,
699       celt_norm *X, celt_norm *Y, int N, int *b, int B, int B0,
700       int LM,
701       int stereo, int *fill)
702 {
703    int qn;
704    int itheta=0;
705    int delta;
706    int imid, iside;
707    int qalloc;
708    int pulse_cap;
709    int offset;
710    opus_int32 tell;
711    int inv=0;
712    int encode;
713    const CELTMode *m;
714    int i;
715    int intensity;
716    ec_ctx *ec;
717    const celt_ener *bandE;
718
719    encode = ctx->encode;
720    m = ctx->m;
721    i = ctx->i;
722    intensity = ctx->intensity;
723    ec = ctx->ec;
724    bandE = ctx->bandE;
725
726    /* Decide on the resolution to give to the split parameter theta */
727    pulse_cap = m->logN[i]+LM*(1<<BITRES);
728    offset = (pulse_cap>>1) - (stereo&&N==2 ? QTHETA_OFFSET_TWOPHASE : QTHETA_OFFSET);
729    qn = compute_qn(N, *b, offset, pulse_cap, stereo);
730    if (stereo && i>=intensity)
731       qn = 1;
732    if (encode)
733    {
734       /* theta is the atan() of the ratio between the (normalized)
735          side and mid. With just that parameter, we can re-scale both
736          mid and side because we know that 1) they have unit norm and
737          2) they are orthogonal. */
738       itheta = stereo_itheta(X, Y, stereo, N, ctx->arch);
739    }
740    tell = ec_tell_frac(ec);
741    if (qn!=1)
742    {
743       if (encode)
744       {
745          if (!stereo || ctx->theta_round == 0)
746          {
747             itheta = (itheta*(opus_int32)qn+8192)>>14;
748          } else {
749             int down;
750             /* Bias quantization towards itheta=0 and itheta=16384. */
751             int bias = itheta > 8192 ? 32767/qn : -32767/qn;
752             down = IMIN(qn-1, IMAX(0, (itheta*(opus_int32)qn + bias)>>14));
753             if (ctx->theta_round < 0)
754                itheta = down;
755             else
756                itheta = down+1;
757          }
758       }
759       /* Entropy coding of the angle. We use a uniform pdf for the
760          time split, a step for stereo, and a triangular one for the rest. */
761       if (stereo && N>2)
762       {
763          int p0 = 3;
764          int x = itheta;
765          int x0 = qn/2;
766          int ft = p0*(x0+1) + x0;
767          /* Use a probability of p0 up to itheta=8192 and then use 1 after */
768          if (encode)
769          {
770             ec_encode(ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
771          } else {
772             int fs;
773             fs=ec_decode(ec,ft);
774             if (fs<(x0+1)*p0)
775                x=fs/p0;
776             else
777                x=x0+1+(fs-(x0+1)*p0);
778             ec_dec_update(ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
779             itheta = x;
780          }
781       } else if (B0>1 || stereo) {
782          /* Uniform pdf */
783          if (encode)
784             ec_enc_uint(ec, itheta, qn+1);
785          else
786             itheta = ec_dec_uint(ec, qn+1);
787       } else {
788          int fs=1, ft;
789          ft = ((qn>>1)+1)*((qn>>1)+1);
790          if (encode)
791          {
792             int fl;
793
794             fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
795             fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
796              ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
797
798             ec_encode(ec, fl, fl+fs, ft);
799          } else {
800             /* Triangular pdf */
801             int fl=0;
802             int fm;
803             fm = ec_decode(ec, ft);
804
805             if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
806             {
807                itheta = (isqrt32(8*(opus_uint32)fm + 1) - 1)>>1;
808                fs = itheta + 1;
809                fl = itheta*(itheta + 1)>>1;
810             }
811             else
812             {
813                itheta = (2*(qn + 1)
814                 - isqrt32(8*(opus_uint32)(ft - fm - 1) + 1))>>1;
815                fs = qn + 1 - itheta;
816                fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
817             }
818
819             ec_dec_update(ec, fl, fl+fs, ft);
820          }
821       }
822       celt_assert(itheta>=0);
823       itheta = celt_udiv((opus_int32)itheta*16384, qn);
824       if (encode && stereo)
825       {
826          if (itheta==0)
827             intensity_stereo(m, X, Y, bandE, i, N);
828          else
829             stereo_split(X, Y, N);
830       }
831       /* NOTE: Renormalising X and Y *may* help fixed-point a bit at very high rate.
832                Let's do that at higher complexity */
833    } else if (stereo) {
834       if (encode)
835       {
836          inv = itheta > 8192 && !ctx->disable_inv;
837          if (inv)
838          {
839             int j;
840             for (j=0;j<N;j++)
841                Y[j] = -Y[j];
842          }
843          intensity_stereo(m, X, Y, bandE, i, N);
844       }
845       if (*b>2<<BITRES && ctx->remaining_bits > 2<<BITRES)
846       {
847          if (encode)
848             ec_enc_bit_logp(ec, inv, 2);
849          else
850             inv = ec_dec_bit_logp(ec, 2);
851       } else
852          inv = 0;
853       /* inv flag override to avoid problems with downmixing. */
854       if (ctx->disable_inv)
855          inv = 0;
856       itheta = 0;
857    }
858    qalloc = ec_tell_frac(ec) - tell;
859    *b -= qalloc;
860
861    if (itheta == 0)
862    {
863       imid = 32767;
864       iside = 0;
865       *fill &= (1<<B)-1;
866       delta = -16384;
867    } else if (itheta == 16384)
868    {
869       imid = 0;
870       iside = 32767;
871       *fill &= ((1<<B)-1)<<B;
872       delta = 16384;
873    } else {
874       imid = bitexact_cos((opus_int16)itheta);
875       iside = bitexact_cos((opus_int16)(16384-itheta));
876       /* This is the mid vs side allocation that minimizes squared error
877          in that band. */
878       delta = FRAC_MUL16((N-1)<<7,bitexact_log2tan(iside,imid));
879    }
880
881    sctx->inv = inv;
882    sctx->imid = imid;
883    sctx->iside = iside;
884    sctx->delta = delta;
885    sctx->itheta = itheta;
886    sctx->qalloc = qalloc;
887 }
888 static unsigned quant_band_n1(struct band_ctx *ctx, celt_norm *X, celt_norm *Y, int b,
889       celt_norm *lowband_out)
890 {
891    int c;
892    int stereo;
893    celt_norm *x = X;
894    int encode;
895    ec_ctx *ec;
896
897    encode = ctx->encode;
898    ec = ctx->ec;
899
900    stereo = Y != NULL;
901    c=0; do {
902       int sign=0;
903       if (ctx->remaining_bits>=1<<BITRES)
904       {
905          if (encode)
906          {
907             sign = x[0]<0;
908             ec_enc_bits(ec, sign, 1);
909          } else {
910             sign = ec_dec_bits(ec, 1);
911          }
912          ctx->remaining_bits -= 1<<BITRES;
913          b-=1<<BITRES;
914       }
915       if (ctx->resynth)
916          x[0] = sign ? -NORM_SCALING : NORM_SCALING;
917       x = Y;
918    } while (++c<1+stereo);
919    if (lowband_out)
920       lowband_out[0] = SHR16(X[0],4);
921    return 1;
922 }
923
924 /* This function is responsible for encoding and decoding a mono partition.
925    It can split the band in two and transmit the energy difference with
926    the two half-bands. It can be called recursively so bands can end up being
927    split in 8 parts. */
928 static unsigned quant_partition(struct band_ctx *ctx, celt_norm *X,
929       int N, int b, int B, celt_norm *lowband,
930       int LM,
931       opus_val16 gain, int fill)
932 {
933    const unsigned char *cache;
934    int q;
935    int curr_bits;
936    int imid=0, iside=0;
937    int B0=B;
938    opus_val16 mid=0, side=0;
939    unsigned cm=0;
940    celt_norm *Y=NULL;
941    int encode;
942    const CELTMode *m;
943    int i;
944    int spread;
945    ec_ctx *ec;
946
947    encode = ctx->encode;
948    m = ctx->m;
949    i = ctx->i;
950    spread = ctx->spread;
951    ec = ctx->ec;
952
953    /* If we need 1.5 more bit than we can produce, split the band in two. */
954    cache = m->cache.bits + m->cache.index[(LM+1)*m->nbEBands+i];
955    if (LM != -1 && b > cache[cache[0]]+12 && N>2)
956    {
957       int mbits, sbits, delta;
958       int itheta;
959       int qalloc;
960       struct split_ctx sctx;
961       celt_norm *next_lowband2=NULL;
962       opus_int32 rebalance;
963
964       N >>= 1;
965       Y = X+N;
966       LM -= 1;
967       if (B==1)
968          fill = (fill&1)|(fill<<1);
969       B = (B+1)>>1;
970
971       compute_theta(ctx, &sctx, X, Y, N, &b, B, B0, LM, 0, &fill);
972       imid = sctx.imid;
973       iside = sctx.iside;
974       delta = sctx.delta;
975       itheta = sctx.itheta;
976       qalloc = sctx.qalloc;
977 #ifdef FIXED_POINT
978       mid = imid;
979       side = iside;
980 #else
981       mid = (1.f/32768)*imid;
982       side = (1.f/32768)*iside;
983 #endif
984
985       /* Give more bits to low-energy MDCTs than they would otherwise deserve */
986       if (B0>1 && (itheta&0x3fff))
987       {
988          if (itheta > 8192)
989             /* Rough approximation for pre-echo masking */
990             delta -= delta>>(4-LM);
991          else
992             /* Corresponds to a forward-masking slope of 1.5 dB per 10 ms */
993             delta = IMIN(0, delta + (N<<BITRES>>(5-LM)));
994       }
995       mbits = IMAX(0, IMIN(b, (b-delta)/2));
996       sbits = b-mbits;
997       ctx->remaining_bits -= qalloc;
998
999       if (lowband)
1000          next_lowband2 = lowband+N; /* >32-bit split case */
1001
1002       rebalance = ctx->remaining_bits;
1003       if (mbits >= sbits)
1004       {
1005          cm = quant_partition(ctx, X, N, mbits, B, lowband, LM,
1006                MULT16_16_P15(gain,mid), fill);
1007          rebalance = mbits - (rebalance-ctx->remaining_bits);
1008          if (rebalance > 3<<BITRES && itheta!=0)
1009             sbits += rebalance - (3<<BITRES);
1010          cm |= quant_partition(ctx, Y, N, sbits, B, next_lowband2, LM,
1011                MULT16_16_P15(gain,side), fill>>B)<<(B0>>1);
1012       } else {
1013          cm = quant_partition(ctx, Y, N, sbits, B, next_lowband2, LM,
1014                MULT16_16_P15(gain,side), fill>>B)<<(B0>>1);
1015          rebalance = sbits - (rebalance-ctx->remaining_bits);
1016          if (rebalance > 3<<BITRES && itheta!=16384)
1017             mbits += rebalance - (3<<BITRES);
1018          cm |= quant_partition(ctx, X, N, mbits, B, lowband, LM,
1019                MULT16_16_P15(gain,mid), fill);
1020       }
1021    } else {
1022       /* This is the basic no-split case */
1023       q = bits2pulses(m, i, LM, b);
1024       curr_bits = pulses2bits(m, i, LM, q);
1025       ctx->remaining_bits -= curr_bits;
1026
1027       /* Ensures we can never bust the budget */
1028       while (ctx->remaining_bits < 0 && q > 0)
1029       {
1030          ctx->remaining_bits += curr_bits;
1031          q--;
1032          curr_bits = pulses2bits(m, i, LM, q);
1033          ctx->remaining_bits -= curr_bits;
1034       }
1035
1036       if (q!=0)
1037       {
1038          int K = get_pulses(q);
1039
1040          /* Finally do the actual quantization */
1041          if (encode)
1042          {
1043             cm = alg_quant(X, N, K, spread, B, ec, gain, ctx->resynth, ctx->arch);
1044          } else {
1045             cm = alg_unquant(X, N, K, spread, B, ec, gain);
1046          }
1047       } else {
1048          /* If there's no pulse, fill the band anyway */
1049          int j;
1050          if (ctx->resynth)
1051          {
1052             unsigned cm_mask;
1053             /* B can be as large as 16, so this shift might overflow an int on a
1054                16-bit platform; use a long to get defined behavior.*/
1055             cm_mask = (unsigned)(1UL<<B)-1;
1056             fill &= cm_mask;
1057             if (!fill)
1058             {
1059                OPUS_CLEAR(X, N);
1060             } else {
1061                if (lowband == NULL)
1062                {
1063                   /* Noise */
1064                   for (j=0;j<N;j++)
1065                   {
1066                      ctx->seed = celt_lcg_rand(ctx->seed);
1067                      X[j] = (celt_norm)((opus_int32)ctx->seed>>20);
1068                   }
1069                   cm = cm_mask;
1070                } else {
1071                   /* Folded spectrum */
1072                   for (j=0;j<N;j++)
1073                   {
1074                      opus_val16 tmp;
1075                      ctx->seed = celt_lcg_rand(ctx->seed);
1076                      /* About 48 dB below the "normal" folding level */
1077                      tmp = QCONST16(1.0f/256, 10);
1078                      tmp = (ctx->seed)&0x8000 ? tmp : -tmp;
1079                      X[j] = lowband[j]+tmp;
1080                   }
1081                   cm = fill;
1082                }
1083                renormalise_vector(X, N, gain, ctx->arch);
1084             }
1085          }
1086       }
1087    }
1088
1089    return cm;
1090 }
1091
1092
1093 /* This function is responsible for encoding and decoding a band for the mono case. */
1094 static unsigned quant_band(struct band_ctx *ctx, celt_norm *X,
1095       int N, int b, int B, celt_norm *lowband,
1096       int LM, celt_norm *lowband_out,
1097       opus_val16 gain, celt_norm *lowband_scratch, int fill)
1098 {
1099    int N0=N;
1100    int N_B=N;
1101    int N_B0;
1102    int B0=B;
1103    int time_divide=0;
1104    int recombine=0;
1105    int longBlocks;
1106    unsigned cm=0;
1107    int k;
1108    int encode;
1109    int tf_change;
1110
1111    encode = ctx->encode;
1112    tf_change = ctx->tf_change;
1113
1114    longBlocks = B0==1;
1115
1116    N_B = celt_udiv(N_B, B);
1117
1118    /* Special case for one sample */
1119    if (N==1)
1120    {
1121       return quant_band_n1(ctx, X, NULL, b, lowband_out);
1122    }
1123
1124    if (tf_change>0)
1125       recombine = tf_change;
1126    /* Band recombining to increase frequency resolution */
1127
1128    if (lowband_scratch && lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
1129    {
1130       OPUS_COPY(lowband_scratch, lowband, N);
1131       lowband = lowband_scratch;
1132    }
1133
1134    for (k=0;k<recombine;k++)
1135    {
1136       static const unsigned char bit_interleave_table[16]={
1137             0,1,1,1,2,3,3,3,2,3,3,3,2,3,3,3
1138       };
1139       if (encode)
1140          haar1(X, N>>k, 1<<k);
1141       if (lowband)
1142          haar1(lowband, N>>k, 1<<k);
1143       fill = bit_interleave_table[fill&0xF]|bit_interleave_table[fill>>4]<<2;
1144    }
1145    B>>=recombine;
1146    N_B<<=recombine;
1147
1148    /* Increasing the time resolution */
1149    while ((N_B&1) == 0 && tf_change<0)
1150    {
1151       if (encode)
1152          haar1(X, N_B, B);
1153       if (lowband)
1154          haar1(lowband, N_B, B);
1155       fill |= fill<<B;
1156       B <<= 1;
1157       N_B >>= 1;
1158       time_divide++;
1159       tf_change++;
1160    }
1161    B0=B;
1162    N_B0 = N_B;
1163
1164    /* Reorganize the samples in time order instead of frequency order */
1165    if (B0>1)
1166    {
1167       if (encode)
1168          deinterleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
1169       if (lowband)
1170          deinterleave_hadamard(lowband, N_B>>recombine, B0<<recombine, longBlocks);
1171    }
1172
1173    cm = quant_partition(ctx, X, N, b, B, lowband, LM, gain, fill);
1174
1175    /* This code is used by the decoder and by the resynthesis-enabled encoder */
1176    if (ctx->resynth)
1177    {
1178       /* Undo the sample reorganization going from time order to frequency order */
1179       if (B0>1)
1180          interleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
1181
1182       /* Undo time-freq changes that we did earlier */
1183       N_B = N_B0;
1184       B = B0;
1185       for (k=0;k<time_divide;k++)
1186       {
1187          B >>= 1;
1188          N_B <<= 1;
1189          cm |= cm>>B;
1190          haar1(X, N_B, B);
1191       }
1192
1193       for (k=0;k<recombine;k++)
1194       {
1195          static const unsigned char bit_deinterleave_table[16]={
1196                0x00,0x03,0x0C,0x0F,0x30,0x33,0x3C,0x3F,
1197                0xC0,0xC3,0xCC,0xCF,0xF0,0xF3,0xFC,0xFF
1198          };
1199          cm = bit_deinterleave_table[cm];
1200          haar1(X, N0>>k, 1<<k);
1201       }
1202       B<<=recombine;
1203
1204       /* Scale output for later folding */
1205       if (lowband_out)
1206       {
1207          int j;
1208          opus_val16 n;
1209          n = celt_sqrt(SHL32(EXTEND32(N0),22));
1210          for (j=0;j<N0;j++)
1211             lowband_out[j] = MULT16_16_Q15(n,X[j]);
1212       }
1213       cm &= (1<<B)-1;
1214    }
1215    return cm;
1216 }
1217
1218
1219 /* This function is responsible for encoding and decoding a band for the stereo case. */
1220 static unsigned quant_band_stereo(struct band_ctx *ctx, celt_norm *X, celt_norm *Y,
1221       int N, int b, int B, celt_norm *lowband,
1222       int LM, celt_norm *lowband_out,
1223       celt_norm *lowband_scratch, int fill)
1224 {
1225    int imid=0, iside=0;
1226    int inv = 0;
1227    opus_val16 mid=0, side=0;
1228    unsigned cm=0;
1229    int mbits, sbits, delta;
1230    int itheta;
1231    int qalloc;
1232    struct split_ctx sctx;
1233    int orig_fill;
1234    int encode;
1235    ec_ctx *ec;
1236
1237    encode = ctx->encode;
1238    ec = ctx->ec;
1239
1240    /* Special case for one sample */
1241    if (N==1)
1242    {
1243       return quant_band_n1(ctx, X, Y, b, lowband_out);
1244    }
1245
1246    orig_fill = fill;
1247
1248    compute_theta(ctx, &sctx, X, Y, N, &b, B, B, LM, 1, &fill);
1249    inv = sctx.inv;
1250    imid = sctx.imid;
1251    iside = sctx.iside;
1252    delta = sctx.delta;
1253    itheta = sctx.itheta;
1254    qalloc = sctx.qalloc;
1255 #ifdef FIXED_POINT
1256    mid = imid;
1257    side = iside;
1258 #else
1259    mid = (1.f/32768)*imid;
1260    side = (1.f/32768)*iside;
1261 #endif
1262
1263    /* This is a special case for N=2 that only works for stereo and takes
1264       advantage of the fact that mid and side are orthogonal to encode
1265       the side with just one bit. */
1266    if (N==2)
1267    {
1268       int c;
1269       int sign=0;
1270       celt_norm *x2, *y2;
1271       mbits = b;
1272       sbits = 0;
1273       /* Only need one bit for the side. */
1274       if (itheta != 0 && itheta != 16384)
1275          sbits = 1<<BITRES;
1276       mbits -= sbits;
1277       c = itheta > 8192;
1278       ctx->remaining_bits -= qalloc+sbits;
1279
1280       x2 = c ? Y : X;
1281       y2 = c ? X : Y;
1282       if (sbits)
1283       {
1284          if (encode)
1285          {
1286             /* Here we only need to encode a sign for the side. */
1287             sign = x2[0]*y2[1] - x2[1]*y2[0] < 0;
1288             ec_enc_bits(ec, sign, 1);
1289          } else {
1290             sign = ec_dec_bits(ec, 1);
1291          }
1292       }
1293       sign = 1-2*sign;
1294       /* We use orig_fill here because we want to fold the side, but if
1295          itheta==16384, we'll have cleared the low bits of fill. */
1296       cm = quant_band(ctx, x2, N, mbits, B, lowband, LM, lowband_out, Q15ONE,
1297             lowband_scratch, orig_fill);
1298       /* We don't split N=2 bands, so cm is either 1 or 0 (for a fold-collapse),
1299          and there's no need to worry about mixing with the other channel. */
1300       y2[0] = -sign*x2[1];
1301       y2[1] = sign*x2[0];
1302       if (ctx->resynth)
1303       {
1304          celt_norm tmp;
1305          X[0] = MULT16_16_Q15(mid, X[0]);
1306          X[1] = MULT16_16_Q15(mid, X[1]);
1307          Y[0] = MULT16_16_Q15(side, Y[0]);
1308          Y[1] = MULT16_16_Q15(side, Y[1]);
1309          tmp = X[0];
1310          X[0] = SUB16(tmp,Y[0]);
1311          Y[0] = ADD16(tmp,Y[0]);
1312          tmp = X[1];
1313          X[1] = SUB16(tmp,Y[1]);
1314          Y[1] = ADD16(tmp,Y[1]);
1315       }
1316    } else {
1317       /* "Normal" split code */
1318       opus_int32 rebalance;
1319
1320       mbits = IMAX(0, IMIN(b, (b-delta)/2));
1321       sbits = b-mbits;
1322       ctx->remaining_bits -= qalloc;
1323
1324       rebalance = ctx->remaining_bits;
1325       if (mbits >= sbits)
1326       {
1327          /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
1328             mid for folding later. */
1329          cm = quant_band(ctx, X, N, mbits, B, lowband, LM, lowband_out, Q15ONE,
1330                lowband_scratch, fill);
1331          rebalance = mbits - (rebalance-ctx->remaining_bits);
1332          if (rebalance > 3<<BITRES && itheta!=0)
1333             sbits += rebalance - (3<<BITRES);
1334
1335          /* For a stereo split, the high bits of fill are always zero, so no
1336             folding will be done to the side. */
1337          cm |= quant_band(ctx, Y, N, sbits, B, NULL, LM, NULL, side, NULL, fill>>B);
1338       } else {
1339          /* For a stereo split, the high bits of fill are always zero, so no
1340             folding will be done to the side. */
1341          cm = quant_band(ctx, Y, N, sbits, B, NULL, LM, NULL, side, NULL, fill>>B);
1342          rebalance = sbits - (rebalance-ctx->remaining_bits);
1343          if (rebalance > 3<<BITRES && itheta!=16384)
1344             mbits += rebalance - (3<<BITRES);
1345          /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
1346             mid for folding later. */
1347          cm |= quant_band(ctx, X, N, mbits, B, lowband, LM, lowband_out, Q15ONE,
1348                lowband_scratch, fill);
1349       }
1350    }
1351
1352
1353    /* This code is used by the decoder and by the resynthesis-enabled encoder */
1354    if (ctx->resynth)
1355    {
1356       if (N!=2)
1357          stereo_merge(X, Y, mid, N, ctx->arch);
1358       if (inv)
1359       {
1360          int j;
1361          for (j=0;j<N;j++)
1362             Y[j] = -Y[j];
1363       }
1364    }
1365    return cm;
1366 }
1367
1368 static void special_hybrid_folding(const CELTMode *m, celt_norm *norm, celt_norm *norm2, int start, int M, int dual_stereo)
1369 {
1370    int n1, n2;
1371    const opus_int16 * OPUS_RESTRICT eBands = m->eBands;
1372    n1 = M*(eBands[start+1]-eBands[start]);
1373    n2 = M*(eBands[start+2]-eBands[start+1]);
1374    /* Duplicate enough of the first band folding data to be able to fold the second band.
1375       Copies no data for CELT-only mode. */
1376    OPUS_COPY(&norm[n1], &norm[2*n1 - n2], n2-n1);
1377    if (dual_stereo)
1378       OPUS_COPY(&norm2[n1], &norm2[2*n1 - n2], n2-n1);
1379 }
1380
1381 void quant_all_bands(int encode, const CELTMode *m, int start, int end,
1382       celt_norm *X_, celt_norm *Y_, unsigned char *collapse_masks,
1383       const celt_ener *bandE, int *pulses, int shortBlocks, int spread,
1384       int dual_stereo, int intensity, int *tf_res, opus_int32 total_bits,
1385       opus_int32 balance, ec_ctx *ec, int LM, int codedBands,
1386       opus_uint32 *seed, int complexity, int arch, int disable_inv)
1387 {
1388    int i;
1389    opus_int32 remaining_bits;
1390    const opus_int16 * OPUS_RESTRICT eBands = m->eBands;
1391    celt_norm * OPUS_RESTRICT norm, * OPUS_RESTRICT norm2;
1392    VARDECL(celt_norm, _norm);
1393    VARDECL(celt_norm, _lowband_scratch);
1394    VARDECL(celt_norm, X_save);
1395    VARDECL(celt_norm, Y_save);
1396    VARDECL(celt_norm, X_save2);
1397    VARDECL(celt_norm, Y_save2);
1398    VARDECL(celt_norm, norm_save2);
1399    int resynth_alloc;
1400    celt_norm *lowband_scratch;
1401    int B;
1402    int M;
1403    int lowband_offset;
1404    int update_lowband = 1;
1405    int C = Y_ != NULL ? 2 : 1;
1406    int norm_offset;
1407    int theta_rdo = encode && Y_!=NULL && !dual_stereo && complexity>=8;
1408 #ifdef RESYNTH
1409    int resynth = 1;
1410 #else
1411    int resynth = !encode || theta_rdo;
1412 #endif
1413    struct band_ctx ctx;
1414    SAVE_STACK;
1415
1416    M = 1<<LM;
1417    B = shortBlocks ? M : 1;
1418    norm_offset = M*eBands[start];
1419    /* No need to allocate norm for the last band because we don't need an
1420       output in that band. */
1421    ALLOC(_norm, C*(M*eBands[m->nbEBands-1]-norm_offset), celt_norm);
1422    norm = _norm;
1423    norm2 = norm + M*eBands[m->nbEBands-1]-norm_offset;
1424
1425    /* For decoding, we can use the last band as scratch space because we don't need that
1426       scratch space for the last band and we don't care about the data there until we're
1427       decoding the last band. */
1428    if (encode && resynth)
1429       resynth_alloc = M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]);
1430    else
1431       resynth_alloc = ALLOC_NONE;
1432    ALLOC(_lowband_scratch, resynth_alloc, celt_norm);
1433    if (encode && resynth)
1434       lowband_scratch = _lowband_scratch;
1435    else
1436       lowband_scratch = X_+M*eBands[m->nbEBands-1];
1437    ALLOC(X_save, resynth_alloc, celt_norm);
1438    ALLOC(Y_save, resynth_alloc, celt_norm);
1439    ALLOC(X_save2, resynth_alloc, celt_norm);
1440    ALLOC(Y_save2, resynth_alloc, celt_norm);
1441    ALLOC(norm_save2, resynth_alloc, celt_norm);
1442
1443    lowband_offset = 0;
1444    ctx.bandE = bandE;
1445    ctx.ec = ec;
1446    ctx.encode = encode;
1447    ctx.intensity = intensity;
1448    ctx.m = m;
1449    ctx.seed = *seed;
1450    ctx.spread = spread;
1451    ctx.arch = arch;
1452    ctx.disable_inv = disable_inv;
1453    ctx.resynth = resynth;
1454    ctx.theta_round = 0;
1455    for (i=start;i<end;i++)
1456    {
1457       opus_int32 tell;
1458       int b;
1459       int N;
1460       opus_int32 curr_balance;
1461       int effective_lowband=-1;
1462       celt_norm * OPUS_RESTRICT X, * OPUS_RESTRICT Y;
1463       int tf_change=0;
1464       unsigned x_cm;
1465       unsigned y_cm;
1466       int last;
1467
1468       ctx.i = i;
1469       last = (i==end-1);
1470
1471       X = X_+M*eBands[i];
1472       if (Y_!=NULL)
1473          Y = Y_+M*eBands[i];
1474       else
1475          Y = NULL;
1476       N = M*eBands[i+1]-M*eBands[i];
1477       tell = ec_tell_frac(ec);
1478
1479       /* Compute how many bits we want to allocate to this band */
1480       if (i != start)
1481          balance -= tell;
1482       remaining_bits = total_bits-tell-1;
1483       ctx.remaining_bits = remaining_bits;
1484       if (i <= codedBands-1)
1485       {
1486          curr_balance = celt_sudiv(balance, IMIN(3, codedBands-i));
1487          b = IMAX(0, IMIN(16383, IMIN(remaining_bits+1,pulses[i]+curr_balance)));
1488       } else {
1489          b = 0;
1490       }
1491
1492 #ifdef ENABLE_UPDATE_DRAFT
1493       if (resynth && (M*eBands[i]-N >= M*eBands[start] || i==start+1) && (update_lowband || lowband_offset==0))
1494             lowband_offset = i;
1495       if (i == start+1)
1496          special_hybrid_folding(m, norm, norm2, start, M, dual_stereo);
1497 #else
1498       if (resynth && M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband_offset==0))
1499             lowband_offset = i;
1500 #endif
1501
1502       tf_change = tf_res[i];
1503       ctx.tf_change = tf_change;
1504       if (i>=m->effEBands)
1505       {
1506          X=norm;
1507          if (Y_!=NULL)
1508             Y = norm;
1509          lowband_scratch = NULL;
1510       }
1511       if (last && !theta_rdo)
1512          lowband_scratch = NULL;
1513
1514       /* Get a conservative estimate of the collapse_mask's for the bands we're
1515          going to be folding from. */
1516       if (lowband_offset != 0 && (spread!=SPREAD_AGGRESSIVE || B>1 || tf_change<0))
1517       {
1518          int fold_start;
1519          int fold_end;
1520          int fold_i;
1521          /* This ensures we never repeat spectral content within one band */
1522          effective_lowband = IMAX(0, M*eBands[lowband_offset]-norm_offset-N);
1523          fold_start = lowband_offset;
1524          while(M*eBands[--fold_start] > effective_lowband+norm_offset);
1525          fold_end = lowband_offset-1;
1526 #ifdef ENABLE_UPDATE_DRAFT
1527          while(++fold_end < i && M*eBands[fold_end] < effective_lowband+norm_offset+N);
1528 #else
1529          while(M*eBands[++fold_end] < effective_lowband+norm_offset+N);
1530 #endif
1531          x_cm = y_cm = 0;
1532          fold_i = fold_start; do {
1533            x_cm |= collapse_masks[fold_i*C+0];
1534            y_cm |= collapse_masks[fold_i*C+C-1];
1535          } while (++fold_i<fold_end);
1536       }
1537       /* Otherwise, we'll be using the LCG to fold, so all blocks will (almost
1538          always) be non-zero. */
1539       else
1540          x_cm = y_cm = (1<<B)-1;
1541
1542       if (dual_stereo && i==intensity)
1543       {
1544          int j;
1545
1546          /* Switch off dual stereo to do intensity. */
1547          dual_stereo = 0;
1548          if (resynth)
1549             for (j=0;j<M*eBands[i]-norm_offset;j++)
1550                norm[j] = HALF32(norm[j]+norm2[j]);
1551       }
1552       if (dual_stereo)
1553       {
1554          x_cm = quant_band(&ctx, X, N, b/2, B,
1555                effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1556                last?NULL:norm+M*eBands[i]-norm_offset, Q15ONE, lowband_scratch, x_cm);
1557          y_cm = quant_band(&ctx, Y, N, b/2, B,
1558                effective_lowband != -1 ? norm2+effective_lowband : NULL, LM,
1559                last?NULL:norm2+M*eBands[i]-norm_offset, Q15ONE, lowband_scratch, y_cm);
1560       } else {
1561          if (Y!=NULL)
1562          {
1563             if (theta_rdo && i < intensity)
1564             {
1565                ec_ctx ec_save, ec_save2;
1566                struct band_ctx ctx_save, ctx_save2;
1567                opus_val32 dist0, dist1;
1568                unsigned cm, cm2;
1569                int nstart_bytes, nend_bytes, save_bytes;
1570                unsigned char *bytes_buf;
1571                unsigned char bytes_save[1275];
1572                opus_val16 w[2];
1573                compute_channel_weights(bandE[i], bandE[i+m->nbEBands], w);
1574                /* Make a copy. */
1575                cm = x_cm|y_cm;
1576                ec_save = *ec;
1577                ctx_save = ctx;
1578                OPUS_COPY(X_save, X, N);
1579                OPUS_COPY(Y_save, Y, N);
1580                /* Encode and round down. */
1581                ctx.theta_round = -1;
1582                x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
1583                      effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1584                      last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
1585                dist0 = MULT16_32_Q15(w[0], celt_inner_prod(X_save, X, N, arch)) + MULT16_32_Q15(w[1], celt_inner_prod(Y_save, Y, N, arch));
1586
1587                /* Save first result. */
1588                cm2 = x_cm;
1589                ec_save2 = *ec;
1590                ctx_save2 = ctx;
1591                OPUS_COPY(X_save2, X, N);
1592                OPUS_COPY(Y_save2, Y, N);
1593                if (!last)
1594                   OPUS_COPY(norm_save2, norm+M*eBands[i]-norm_offset, N);
1595                nstart_bytes = ec_save.offs;
1596                nend_bytes = ec_save.storage;
1597                bytes_buf = ec_save.buf+nstart_bytes;
1598                save_bytes = nend_bytes-nstart_bytes;
1599                OPUS_COPY(bytes_save, bytes_buf, save_bytes);
1600
1601                /* Restore */
1602                *ec = ec_save;
1603                ctx = ctx_save;
1604                OPUS_COPY(X, X_save, N);
1605                OPUS_COPY(Y, Y_save, N);
1606                if (i == start+1)
1607                   special_hybrid_folding(m, norm, norm2, start, M, dual_stereo);
1608                /* Encode and round up. */
1609                ctx.theta_round = 1;
1610                x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
1611                      effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1612                      last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
1613                dist1 = MULT16_32_Q15(w[0], celt_inner_prod(X_save, X, N, arch)) + MULT16_32_Q15(w[1], celt_inner_prod(Y_save, Y, N, arch));
1614                if (dist0 >= dist1) {
1615                   x_cm = cm2;
1616                   *ec = ec_save2;
1617                   ctx = ctx_save2;
1618                   OPUS_COPY(X, X_save2, N);
1619                   OPUS_COPY(Y, Y_save2, N);
1620                   if (!last)
1621                      OPUS_COPY(norm+M*eBands[i]-norm_offset, norm_save2, N);
1622                   OPUS_COPY(bytes_buf, bytes_save, save_bytes);
1623                }
1624             } else {
1625                ctx.theta_round = 0;
1626                x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
1627                      effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1628                      last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, x_cm|y_cm);
1629             }
1630          } else {
1631             x_cm = quant_band(&ctx, X, N, b, B,
1632                   effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1633                   last?NULL:norm+M*eBands[i]-norm_offset, Q15ONE, lowband_scratch, x_cm|y_cm);
1634          }
1635          y_cm = x_cm;
1636       }
1637       collapse_masks[i*C+0] = (unsigned char)x_cm;
1638       collapse_masks[i*C+C-1] = (unsigned char)y_cm;
1639       balance += pulses[i] + tell;
1640
1641       /* Update the folding position only as long as we have 1 bit/sample depth. */
1642       update_lowband = b>(N<<BITRES);
1643    }
1644    *seed = ctx.seed;
1645
1646    RESTORE_STACK;
1647 }
1648