Weighting theta rdo based on channel energy
[opus.git] / celt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16
17    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
21    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
26    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29
30 #ifdef HAVE_CONFIG_H
31 #include "config.h"
32 #endif
33
34 #include <math.h>
35 #include "bands.h"
36 #include "modes.h"
37 #include "vq.h"
38 #include "cwrs.h"
39 #include "stack_alloc.h"
40 #include "os_support.h"
41 #include "mathops.h"
42 #include "rate.h"
43 #include "quant_bands.h"
44 #include "pitch.h"
45
46 int hysteresis_decision(opus_val16 val, const opus_val16 *thresholds, const opus_val16 *hysteresis, int N, int prev)
47 {
48    int i;
49    for (i=0;i<N;i++)
50    {
51       if (val < thresholds[i])
52          break;
53    }
54    if (i>prev && val < thresholds[prev]+hysteresis[prev])
55       i=prev;
56    if (i<prev && val > thresholds[prev-1]-hysteresis[prev-1])
57       i=prev;
58    return i;
59 }
60
61 opus_uint32 celt_lcg_rand(opus_uint32 seed)
62 {
63    return 1664525 * seed + 1013904223;
64 }
65
66 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
67    with this approximation is important because it has an impact on the bit allocation */
68 static opus_int16 bitexact_cos(opus_int16 x)
69 {
70    opus_int32 tmp;
71    opus_int16 x2;
72    tmp = (4096+((opus_int32)(x)*(x)))>>13;
73    celt_assert(tmp<=32767);
74    x2 = tmp;
75    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
76    celt_assert(x2<=32766);
77    return 1+x2;
78 }
79
80 static int bitexact_log2tan(int isin,int icos)
81 {
82    int lc;
83    int ls;
84    lc=EC_ILOG(icos);
85    ls=EC_ILOG(isin);
86    icos<<=15-lc;
87    isin<<=15-ls;
88    return (ls-lc)*(1<<11)
89          +FRAC_MUL16(isin, FRAC_MUL16(isin, -2597) + 7932)
90          -FRAC_MUL16(icos, FRAC_MUL16(icos, -2597) + 7932);
91 }
92
93 #ifdef FIXED_POINT
94 /* Compute the amplitude (sqrt energy) in each of the bands */
95 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bandE, int end, int C, int LM)
96 {
97    int i, c, N;
98    const opus_int16 *eBands = m->eBands;
99    N = m->shortMdctSize<<LM;
100    c=0; do {
101       for (i=0;i<end;i++)
102       {
103          int j;
104          opus_val32 maxval=0;
105          opus_val32 sum = 0;
106
107          maxval = celt_maxabs32(&X[c*N+(eBands[i]<<LM)], (eBands[i+1]-eBands[i])<<LM);
108          if (maxval > 0)
109          {
110             int shift = celt_ilog2(maxval) - 14 + (((m->logN[i]>>BITRES)+LM+1)>>1);
111             j=eBands[i]<<LM;
112             if (shift>0)
113             {
114                do {
115                   sum = MAC16_16(sum, EXTRACT16(SHR32(X[j+c*N],shift)),
116                         EXTRACT16(SHR32(X[j+c*N],shift)));
117                } while (++j<eBands[i+1]<<LM);
118             } else {
119                do {
120                   sum = MAC16_16(sum, EXTRACT16(SHL32(X[j+c*N],-shift)),
121                         EXTRACT16(SHL32(X[j+c*N],-shift)));
122                } while (++j<eBands[i+1]<<LM);
123             }
124             /* We're adding one here to ensure the normalized band isn't larger than unity norm */
125             bandE[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
126          } else {
127             bandE[i+c*m->nbEBands] = EPSILON;
128          }
129          /*printf ("%f ", bandE[i+c*m->nbEBands]);*/
130       }
131    } while (++c<C);
132    /*printf ("\n");*/
133 }
134
135 /* Normalise each band such that the energy is one. */
136 void normalise_bands(const CELTMode *m, const celt_sig * OPUS_RESTRICT freq, celt_norm * OPUS_RESTRICT X, const celt_ener *bandE, int end, int C, int M)
137 {
138    int i, c, N;
139    const opus_int16 *eBands = m->eBands;
140    N = M*m->shortMdctSize;
141    c=0; do {
142       i=0; do {
143          opus_val16 g;
144          int j,shift;
145          opus_val16 E;
146          shift = celt_zlog2(bandE[i+c*m->nbEBands])-13;
147          E = VSHR32(bandE[i+c*m->nbEBands], shift);
148          g = EXTRACT16(celt_rcp(SHL32(E,3)));
149          j=M*eBands[i]; do {
150             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
151          } while (++j<M*eBands[i+1]);
152       } while (++i<end);
153    } while (++c<C);
154 }
155
156 #else /* FIXED_POINT */
157 /* Compute the amplitude (sqrt energy) in each of the bands */
158 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bandE, int end, int C, int LM)
159 {
160    int i, c, N;
161    const opus_int16 *eBands = m->eBands;
162    N = m->shortMdctSize<<LM;
163    c=0; do {
164       for (i=0;i<end;i++)
165       {
166          opus_val32 sum;
167          sum = 1e-27f + celt_inner_prod_c(&X[c*N+(eBands[i]<<LM)], &X[c*N+(eBands[i]<<LM)], (eBands[i+1]-eBands[i])<<LM);
168          bandE[i+c*m->nbEBands] = celt_sqrt(sum);
169          /*printf ("%f ", bandE[i+c*m->nbEBands]);*/
170       }
171    } while (++c<C);
172    /*printf ("\n");*/
173 }
174
175 /* Normalise each band such that the energy is one. */
176 void normalise_bands(const CELTMode *m, const celt_sig * OPUS_RESTRICT freq, celt_norm * OPUS_RESTRICT X, const celt_ener *bandE, int end, int C, int M)
177 {
178    int i, c, N;
179    const opus_int16 *eBands = m->eBands;
180    N = M*m->shortMdctSize;
181    c=0; do {
182       for (i=0;i<end;i++)
183       {
184          int j;
185          opus_val16 g = 1.f/(1e-27f+bandE[i+c*m->nbEBands]);
186          for (j=M*eBands[i];j<M*eBands[i+1];j++)
187             X[j+c*N] = freq[j+c*N]*g;
188       }
189    } while (++c<C);
190 }
191
192 #endif /* FIXED_POINT */
193
194 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
195 void denormalise_bands(const CELTMode *m, const celt_norm * OPUS_RESTRICT X,
196       celt_sig * OPUS_RESTRICT freq, const opus_val16 *bandLogE, int start,
197       int end, int M, int downsample, int silence)
198 {
199    int i, N;
200    int bound;
201    celt_sig * OPUS_RESTRICT f;
202    const celt_norm * OPUS_RESTRICT x;
203    const opus_int16 *eBands = m->eBands;
204    N = M*m->shortMdctSize;
205    bound = M*eBands[end];
206    if (downsample!=1)
207       bound = IMIN(bound, N/downsample);
208    if (silence)
209    {
210       bound = 0;
211       start = end = 0;
212    }
213    f = freq;
214    x = X+M*eBands[start];
215    for (i=0;i<M*eBands[start];i++)
216       *f++ = 0;
217    for (i=start;i<end;i++)
218    {
219       int j, band_end;
220       opus_val16 g;
221       opus_val16 lg;
222 #ifdef FIXED_POINT
223       int shift;
224 #endif
225       j=M*eBands[i];
226       band_end = M*eBands[i+1];
227       lg = SATURATE16(ADD32(bandLogE[i], SHL32((opus_val32)eMeans[i],6)));
228 #ifndef FIXED_POINT
229       g = celt_exp2(lg);
230 #else
231       /* Handle the integer part of the log energy */
232       shift = 16-(lg>>DB_SHIFT);
233       if (shift>31)
234       {
235          shift=0;
236          g=0;
237       } else {
238          /* Handle the fractional part. */
239          g = celt_exp2_frac(lg&((1<<DB_SHIFT)-1));
240       }
241       /* Handle extreme gains with negative shift. */
242       if (shift<0)
243       {
244          /* For shift <= -2 and g > 16384 we'd be likely to overflow, so we're
245             capping the gain here, which is equivalent to a cap of 18 on lg.
246             This shouldn't trigger unless the bitstream is already corrupted. */
247          if (shift <= -2)
248          {
249             g = 16384;
250             shift = -2;
251          }
252          do {
253             *f++ = SHL32(MULT16_16(*x++, g), -shift);
254          } while (++j<band_end);
255       } else
256 #endif
257          /* Be careful of the fixed-point "else" just above when changing this code */
258          do {
259             *f++ = SHR32(MULT16_16(*x++, g), shift);
260          } while (++j<band_end);
261    }
262    celt_assert(start <= end);
263    OPUS_CLEAR(&freq[bound], N-bound);
264 }
265
266 /* This prevents energy collapse for transients with multiple short MDCTs */
267 void anti_collapse(const CELTMode *m, celt_norm *X_, unsigned char *collapse_masks, int LM, int C, int size,
268       int start, int end, const opus_val16 *logE, const opus_val16 *prev1logE,
269       const opus_val16 *prev2logE, const int *pulses, opus_uint32 seed, int arch)
270 {
271    int c, i, j, k;
272    for (i=start;i<end;i++)
273    {
274       int N0;
275       opus_val16 thresh, sqrt_1;
276       int depth;
277 #ifdef FIXED_POINT
278       int shift;
279       opus_val32 thresh32;
280 #endif
281
282       N0 = m->eBands[i+1]-m->eBands[i];
283       /* depth in 1/8 bits */
284       celt_assert(pulses[i]>=0);
285       depth = celt_udiv(1+pulses[i], (m->eBands[i+1]-m->eBands[i]))>>LM;
286
287 #ifdef FIXED_POINT
288       thresh32 = SHR32(celt_exp2(-SHL16(depth, 10-BITRES)),1);
289       thresh = MULT16_32_Q15(QCONST16(0.5f, 15), MIN32(32767,thresh32));
290       {
291          opus_val32 t;
292          t = N0<<LM;
293          shift = celt_ilog2(t)>>1;
294          t = SHL32(t, (7-shift)<<1);
295          sqrt_1 = celt_rsqrt_norm(t);
296       }
297 #else
298       thresh = .5f*celt_exp2(-.125f*depth);
299       sqrt_1 = celt_rsqrt(N0<<LM);
300 #endif
301
302       c=0; do
303       {
304          celt_norm *X;
305          opus_val16 prev1;
306          opus_val16 prev2;
307          opus_val32 Ediff;
308          opus_val16 r;
309          int renormalize=0;
310          prev1 = prev1logE[c*m->nbEBands+i];
311          prev2 = prev2logE[c*m->nbEBands+i];
312          if (C==1)
313          {
314             prev1 = MAX16(prev1,prev1logE[m->nbEBands+i]);
315             prev2 = MAX16(prev2,prev2logE[m->nbEBands+i]);
316          }
317          Ediff = EXTEND32(logE[c*m->nbEBands+i])-EXTEND32(MIN16(prev1,prev2));
318          Ediff = MAX32(0, Ediff);
319
320 #ifdef FIXED_POINT
321          if (Ediff < 16384)
322          {
323             opus_val32 r32 = SHR32(celt_exp2(-EXTRACT16(Ediff)),1);
324             r = 2*MIN16(16383,r32);
325          } else {
326             r = 0;
327          }
328          if (LM==3)
329             r = MULT16_16_Q14(23170, MIN32(23169, r));
330          r = SHR16(MIN16(thresh, r),1);
331          r = SHR32(MULT16_16_Q15(sqrt_1, r),shift);
332 #else
333          /* r needs to be multiplied by 2 or 2*sqrt(2) depending on LM because
334             short blocks don't have the same energy as long */
335          r = 2.f*celt_exp2(-Ediff);
336          if (LM==3)
337             r *= 1.41421356f;
338          r = MIN16(thresh, r);
339          r = r*sqrt_1;
340 #endif
341          X = X_+c*size+(m->eBands[i]<<LM);
342          for (k=0;k<1<<LM;k++)
343          {
344             /* Detect collapse */
345             if (!(collapse_masks[i*C+c]&1<<k))
346             {
347                /* Fill with noise */
348                for (j=0;j<N0;j++)
349                {
350                   seed = celt_lcg_rand(seed);
351                   X[(j<<LM)+k] = (seed&0x8000 ? r : -r);
352                }
353                renormalize = 1;
354             }
355          }
356          /* We just added some energy, so we need to renormalise */
357          if (renormalize)
358             renormalise_vector(X, N0<<LM, Q15ONE, arch);
359       } while (++c<C);
360    }
361 }
362
363 /* Compute the weights to use for optimizing normalized distortion across
364    channels. We use the amplitude to weight square distortion, which means
365    that we use the square root of the value we would have been using if we
366    wanted to minimize the MSE in the non-normalized domain. This roughly
367    corresponds to some quick-and-dirty perceptual experiments I ran to
368    measure inter-aural masking (there doesn't seem to be any published data
369    on the topic). */
370 static void compute_channel_weights(celt_ener Ex, celt_ener Ey, opus_val16 w[2])
371 {
372    celt_ener minE;
373    minE = MIN32(Ex, Ey);
374 #if FIXED_POINT
375    int shift;
376 #endif
377    /* Adjustment to make the weights a bit more conservative. */
378    Ex = ADD32(Ex, minE/3);
379    Ey = ADD32(Ey, minE/3);
380 #if FIXED_POINT
381    shift = celt_ilog2(EPSILON+MAX32(Ex, Ey))-14;
382 #endif
383    w[0] = VSHR32(Ex, shift);
384    w[1] = VSHR32(Ey, shift);
385 }
386
387 static void intensity_stereo(const CELTMode *m, celt_norm * OPUS_RESTRICT X, const celt_norm * OPUS_RESTRICT Y, const celt_ener *bandE, int bandID, int N)
388 {
389    int i = bandID;
390    int j;
391    opus_val16 a1, a2;
392    opus_val16 left, right;
393    opus_val16 norm;
394 #ifdef FIXED_POINT
395    int shift = celt_zlog2(MAX32(bandE[i], bandE[i+m->nbEBands]))-13;
396 #endif
397    left = VSHR32(bandE[i],shift);
398    right = VSHR32(bandE[i+m->nbEBands],shift);
399    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
400    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
401    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
402    for (j=0;j<N;j++)
403    {
404       celt_norm r, l;
405       l = X[j];
406       r = Y[j];
407       X[j] = EXTRACT16(SHR32(MAC16_16(MULT16_16(a1, l), a2, r), 14));
408       /* Side is not encoded, no need to calculate */
409    }
410 }
411
412 static void stereo_split(celt_norm * OPUS_RESTRICT X, celt_norm * OPUS_RESTRICT Y, int N)
413 {
414    int j;
415    for (j=0;j<N;j++)
416    {
417       opus_val32 r, l;
418       l = MULT16_16(QCONST16(.70710678f, 15), X[j]);
419       r = MULT16_16(QCONST16(.70710678f, 15), Y[j]);
420       X[j] = EXTRACT16(SHR32(ADD32(l, r), 15));
421       Y[j] = EXTRACT16(SHR32(SUB32(r, l), 15));
422    }
423 }
424
425 static void stereo_merge(celt_norm * OPUS_RESTRICT X, celt_norm * OPUS_RESTRICT Y, opus_val16 mid, int N, int arch)
426 {
427    int j;
428    opus_val32 xp=0, side=0;
429    opus_val32 El, Er;
430    opus_val16 mid2;
431 #ifdef FIXED_POINT
432    int kl, kr;
433 #endif
434    opus_val32 t, lgain, rgain;
435
436    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
437    dual_inner_prod(Y, X, Y, N, &xp, &side, arch);
438    /* Compensating for the mid normalization */
439    xp = MULT16_32_Q15(mid, xp);
440    /* mid and side are in Q15, not Q14 like X and Y */
441    mid2 = SHR16(mid, 1);
442    El = MULT16_16(mid2, mid2) + side - 2*xp;
443    Er = MULT16_16(mid2, mid2) + side + 2*xp;
444    if (Er < QCONST32(6e-4f, 28) || El < QCONST32(6e-4f, 28))
445    {
446       OPUS_COPY(Y, X, N);
447       return;
448    }
449
450 #ifdef FIXED_POINT
451    kl = celt_ilog2(El)>>1;
452    kr = celt_ilog2(Er)>>1;
453 #endif
454    t = VSHR32(El, (kl-7)<<1);
455    lgain = celt_rsqrt_norm(t);
456    t = VSHR32(Er, (kr-7)<<1);
457    rgain = celt_rsqrt_norm(t);
458
459 #ifdef FIXED_POINT
460    if (kl < 7)
461       kl = 7;
462    if (kr < 7)
463       kr = 7;
464 #endif
465
466    for (j=0;j<N;j++)
467    {
468       celt_norm r, l;
469       /* Apply mid scaling (side is already scaled) */
470       l = MULT16_16_P15(mid, X[j]);
471       r = Y[j];
472       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, SUB16(l,r)), kl+1));
473       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, ADD16(l,r)), kr+1));
474    }
475 }
476
477 /* Decide whether we should spread the pulses in the current frame */
478 int spreading_decision(const CELTMode *m, const celt_norm *X, int *average,
479       int last_decision, int *hf_average, int *tapset_decision, int update_hf,
480       int end, int C, int M)
481 {
482    int i, c, N0;
483    int sum = 0, nbBands=0;
484    const opus_int16 * OPUS_RESTRICT eBands = m->eBands;
485    int decision;
486    int hf_sum=0;
487
488    celt_assert(end>0);
489
490    N0 = M*m->shortMdctSize;
491
492    if (M*(eBands[end]-eBands[end-1]) <= 8)
493       return SPREAD_NONE;
494    c=0; do {
495       for (i=0;i<end;i++)
496       {
497          int j, N, tmp=0;
498          int tcount[3] = {0,0,0};
499          const celt_norm * OPUS_RESTRICT x = X+M*eBands[i]+c*N0;
500          N = M*(eBands[i+1]-eBands[i]);
501          if (N<=8)
502             continue;
503          /* Compute rough CDF of |x[j]| */
504          for (j=0;j<N;j++)
505          {
506             opus_val32 x2N; /* Q13 */
507
508             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
509             if (x2N < QCONST16(0.25f,13))
510                tcount[0]++;
511             if (x2N < QCONST16(0.0625f,13))
512                tcount[1]++;
513             if (x2N < QCONST16(0.015625f,13))
514                tcount[2]++;
515          }
516
517          /* Only include four last bands (8 kHz and up) */
518          if (i>m->nbEBands-4)
519             hf_sum += celt_udiv(32*(tcount[1]+tcount[0]), N);
520          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
521          sum += tmp*256;
522          nbBands++;
523       }
524    } while (++c<C);
525
526    if (update_hf)
527    {
528       if (hf_sum)
529          hf_sum = celt_udiv(hf_sum, C*(4-m->nbEBands+end));
530       *hf_average = (*hf_average+hf_sum)>>1;
531       hf_sum = *hf_average;
532       if (*tapset_decision==2)
533          hf_sum += 4;
534       else if (*tapset_decision==0)
535          hf_sum -= 4;
536       if (hf_sum > 22)
537          *tapset_decision=2;
538       else if (hf_sum > 18)
539          *tapset_decision=1;
540       else
541          *tapset_decision=0;
542    }
543    /*printf("%d %d %d\n", hf_sum, *hf_average, *tapset_decision);*/
544    celt_assert(nbBands>0); /* end has to be non-zero */
545    celt_assert(sum>=0);
546    sum = celt_udiv(sum, nbBands);
547    /* Recursive averaging */
548    sum = (sum+*average)>>1;
549    *average = sum;
550    /* Hysteresis */
551    sum = (3*sum + (((3-last_decision)<<7) + 64) + 2)>>2;
552    if (sum < 80)
553    {
554       decision = SPREAD_AGGRESSIVE;
555    } else if (sum < 256)
556    {
557       decision = SPREAD_NORMAL;
558    } else if (sum < 384)
559    {
560       decision = SPREAD_LIGHT;
561    } else {
562       decision = SPREAD_NONE;
563    }
564 #ifdef FUZZING
565    decision = rand()&0x3;
566    *tapset_decision=rand()%3;
567 #endif
568    return decision;
569 }
570
571 /* Indexing table for converting from natural Hadamard to ordery Hadamard
572    This is essentially a bit-reversed Gray, on top of which we've added
573    an inversion of the order because we want the DC at the end rather than
574    the beginning. The lines are for N=2, 4, 8, 16 */
575 static const int ordery_table[] = {
576        1,  0,
577        3,  0,  2,  1,
578        7,  0,  4,  3,  6,  1,  5,  2,
579       15,  0,  8,  7, 12,  3, 11,  4, 14,  1,  9,  6, 13,  2, 10,  5,
580 };
581
582 static void deinterleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
583 {
584    int i,j;
585    VARDECL(celt_norm, tmp);
586    int N;
587    SAVE_STACK;
588    N = N0*stride;
589    ALLOC(tmp, N, celt_norm);
590    celt_assert(stride>0);
591    if (hadamard)
592    {
593       const int *ordery = ordery_table+stride-2;
594       for (i=0;i<stride;i++)
595       {
596          for (j=0;j<N0;j++)
597             tmp[ordery[i]*N0+j] = X[j*stride+i];
598       }
599    } else {
600       for (i=0;i<stride;i++)
601          for (j=0;j<N0;j++)
602             tmp[i*N0+j] = X[j*stride+i];
603    }
604    OPUS_COPY(X, tmp, N);
605    RESTORE_STACK;
606 }
607
608 static void interleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
609 {
610    int i,j;
611    VARDECL(celt_norm, tmp);
612    int N;
613    SAVE_STACK;
614    N = N0*stride;
615    ALLOC(tmp, N, celt_norm);
616    if (hadamard)
617    {
618       const int *ordery = ordery_table+stride-2;
619       for (i=0;i<stride;i++)
620          for (j=0;j<N0;j++)
621             tmp[j*stride+i] = X[ordery[i]*N0+j];
622    } else {
623       for (i=0;i<stride;i++)
624          for (j=0;j<N0;j++)
625             tmp[j*stride+i] = X[i*N0+j];
626    }
627    OPUS_COPY(X, tmp, N);
628    RESTORE_STACK;
629 }
630
631 void haar1(celt_norm *X, int N0, int stride)
632 {
633    int i, j;
634    N0 >>= 1;
635    for (i=0;i<stride;i++)
636       for (j=0;j<N0;j++)
637       {
638          opus_val32 tmp1, tmp2;
639          tmp1 = MULT16_16(QCONST16(.70710678f,15), X[stride*2*j+i]);
640          tmp2 = MULT16_16(QCONST16(.70710678f,15), X[stride*(2*j+1)+i]);
641          X[stride*2*j+i] = EXTRACT16(PSHR32(ADD32(tmp1, tmp2), 15));
642          X[stride*(2*j+1)+i] = EXTRACT16(PSHR32(SUB32(tmp1, tmp2), 15));
643       }
644 }
645
646 static int compute_qn(int N, int b, int offset, int pulse_cap, int stereo)
647 {
648    static const opus_int16 exp2_table8[8] =
649       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
650    int qn, qb;
651    int N2 = 2*N-1;
652    if (stereo && N==2)
653       N2--;
654    /* The upper limit ensures that in a stereo split with itheta==16384, we'll
655        always have enough bits left over to code at least one pulse in the
656        side; otherwise it would collapse, since it doesn't get folded. */
657    qb = celt_sudiv(b+N2*offset, N2);
658    qb = IMIN(b-pulse_cap-(4<<BITRES), qb);
659
660    qb = IMIN(8<<BITRES, qb);
661
662    if (qb<(1<<BITRES>>1)) {
663       qn = 1;
664    } else {
665       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
666       qn = (qn+1)>>1<<1;
667    }
668    celt_assert(qn <= 256);
669    return qn;
670 }
671
672 struct band_ctx {
673    int encode;
674    int resynth;
675    const CELTMode *m;
676    int i;
677    int intensity;
678    int spread;
679    int tf_change;
680    ec_ctx *ec;
681    opus_int32 remaining_bits;
682    const celt_ener *bandE;
683    opus_uint32 seed;
684    int arch;
685    int theta_round;
686 };
687
688 struct split_ctx {
689    int inv;
690    int imid;
691    int iside;
692    int delta;
693    int itheta;
694    int qalloc;
695 };
696
697 static void compute_theta(struct band_ctx *ctx, struct split_ctx *sctx,
698       celt_norm *X, celt_norm *Y, int N, int *b, int B, int B0,
699       int LM,
700       int stereo, int *fill)
701 {
702    int qn;
703    int itheta=0;
704    int delta;
705    int imid, iside;
706    int qalloc;
707    int pulse_cap;
708    int offset;
709    opus_int32 tell;
710    int inv=0;
711    int encode;
712    const CELTMode *m;
713    int i;
714    int intensity;
715    ec_ctx *ec;
716    const celt_ener *bandE;
717
718    encode = ctx->encode;
719    m = ctx->m;
720    i = ctx->i;
721    intensity = ctx->intensity;
722    ec = ctx->ec;
723    bandE = ctx->bandE;
724
725    /* Decide on the resolution to give to the split parameter theta */
726    pulse_cap = m->logN[i]+LM*(1<<BITRES);
727    offset = (pulse_cap>>1) - (stereo&&N==2 ? QTHETA_OFFSET_TWOPHASE : QTHETA_OFFSET);
728    qn = compute_qn(N, *b, offset, pulse_cap, stereo);
729    if (stereo && i>=intensity)
730       qn = 1;
731    if (encode)
732    {
733       /* theta is the atan() of the ratio between the (normalized)
734          side and mid. With just that parameter, we can re-scale both
735          mid and side because we know that 1) they have unit norm and
736          2) they are orthogonal. */
737       itheta = stereo_itheta(X, Y, stereo, N, ctx->arch);
738    }
739    tell = ec_tell_frac(ec);
740    if (qn!=1)
741    {
742       if (encode)
743       {
744          if (!stereo || ctx->theta_round == 0)
745             itheta = (itheta*(opus_int32)qn+8192)>>14;
746          else if (ctx->theta_round < 0)
747             itheta = (itheta*(opus_int32)qn)>>14;
748          else
749             itheta = (itheta*(opus_int32)qn+16383)>>14;
750       }
751       /* Entropy coding of the angle. We use a uniform pdf for the
752          time split, a step for stereo, and a triangular one for the rest. */
753       if (stereo && N>2)
754       {
755          int p0 = 3;
756          int x = itheta;
757          int x0 = qn/2;
758          int ft = p0*(x0+1) + x0;
759          /* Use a probability of p0 up to itheta=8192 and then use 1 after */
760          if (encode)
761          {
762             ec_encode(ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
763          } else {
764             int fs;
765             fs=ec_decode(ec,ft);
766             if (fs<(x0+1)*p0)
767                x=fs/p0;
768             else
769                x=x0+1+(fs-(x0+1)*p0);
770             ec_dec_update(ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
771             itheta = x;
772          }
773       } else if (B0>1 || stereo) {
774          /* Uniform pdf */
775          if (encode)
776             ec_enc_uint(ec, itheta, qn+1);
777          else
778             itheta = ec_dec_uint(ec, qn+1);
779       } else {
780          int fs=1, ft;
781          ft = ((qn>>1)+1)*((qn>>1)+1);
782          if (encode)
783          {
784             int fl;
785
786             fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
787             fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
788              ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
789
790             ec_encode(ec, fl, fl+fs, ft);
791          } else {
792             /* Triangular pdf */
793             int fl=0;
794             int fm;
795             fm = ec_decode(ec, ft);
796
797             if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
798             {
799                itheta = (isqrt32(8*(opus_uint32)fm + 1) - 1)>>1;
800                fs = itheta + 1;
801                fl = itheta*(itheta + 1)>>1;
802             }
803             else
804             {
805                itheta = (2*(qn + 1)
806                 - isqrt32(8*(opus_uint32)(ft - fm - 1) + 1))>>1;
807                fs = qn + 1 - itheta;
808                fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
809             }
810
811             ec_dec_update(ec, fl, fl+fs, ft);
812          }
813       }
814       celt_assert(itheta>=0);
815       itheta = celt_udiv((opus_int32)itheta*16384, qn);
816       if (encode && stereo)
817       {
818          if (itheta==0)
819             intensity_stereo(m, X, Y, bandE, i, N);
820          else
821             stereo_split(X, Y, N);
822       }
823       /* NOTE: Renormalising X and Y *may* help fixed-point a bit at very high rate.
824                Let's do that at higher complexity */
825    } else if (stereo) {
826       if (encode)
827       {
828          inv = itheta > 8192;
829          if (inv)
830          {
831             int j;
832             for (j=0;j<N;j++)
833                Y[j] = -Y[j];
834          }
835          intensity_stereo(m, X, Y, bandE, i, N);
836       }
837       if (*b>2<<BITRES && ctx->remaining_bits > 2<<BITRES)
838       {
839          if (encode)
840             ec_enc_bit_logp(ec, inv, 2);
841          else
842             inv = ec_dec_bit_logp(ec, 2);
843       } else
844          inv = 0;
845       itheta = 0;
846    }
847    qalloc = ec_tell_frac(ec) - tell;
848    *b -= qalloc;
849
850    if (itheta == 0)
851    {
852       imid = 32767;
853       iside = 0;
854       *fill &= (1<<B)-1;
855       delta = -16384;
856    } else if (itheta == 16384)
857    {
858       imid = 0;
859       iside = 32767;
860       *fill &= ((1<<B)-1)<<B;
861       delta = 16384;
862    } else {
863       imid = bitexact_cos((opus_int16)itheta);
864       iside = bitexact_cos((opus_int16)(16384-itheta));
865       /* This is the mid vs side allocation that minimizes squared error
866          in that band. */
867       delta = FRAC_MUL16((N-1)<<7,bitexact_log2tan(iside,imid));
868    }
869
870    sctx->inv = inv;
871    sctx->imid = imid;
872    sctx->iside = iside;
873    sctx->delta = delta;
874    sctx->itheta = itheta;
875    sctx->qalloc = qalloc;
876 }
877 static unsigned quant_band_n1(struct band_ctx *ctx, celt_norm *X, celt_norm *Y, int b,
878       celt_norm *lowband_out)
879 {
880    int c;
881    int stereo;
882    celt_norm *x = X;
883    int encode;
884    ec_ctx *ec;
885
886    encode = ctx->encode;
887    ec = ctx->ec;
888
889    stereo = Y != NULL;
890    c=0; do {
891       int sign=0;
892       if (ctx->remaining_bits>=1<<BITRES)
893       {
894          if (encode)
895          {
896             sign = x[0]<0;
897             ec_enc_bits(ec, sign, 1);
898          } else {
899             sign = ec_dec_bits(ec, 1);
900          }
901          ctx->remaining_bits -= 1<<BITRES;
902          b-=1<<BITRES;
903       }
904       if (ctx->resynth)
905          x[0] = sign ? -NORM_SCALING : NORM_SCALING;
906       x = Y;
907    } while (++c<1+stereo);
908    if (lowband_out)
909       lowband_out[0] = SHR16(X[0],4);
910    return 1;
911 }
912
913 /* This function is responsible for encoding and decoding a mono partition.
914    It can split the band in two and transmit the energy difference with
915    the two half-bands. It can be called recursively so bands can end up being
916    split in 8 parts. */
917 static unsigned quant_partition(struct band_ctx *ctx, celt_norm *X,
918       int N, int b, int B, celt_norm *lowband,
919       int LM,
920       opus_val16 gain, int fill)
921 {
922    const unsigned char *cache;
923    int q;
924    int curr_bits;
925    int imid=0, iside=0;
926    int B0=B;
927    opus_val16 mid=0, side=0;
928    unsigned cm=0;
929    celt_norm *Y=NULL;
930    int encode;
931    const CELTMode *m;
932    int i;
933    int spread;
934    ec_ctx *ec;
935
936    encode = ctx->encode;
937    m = ctx->m;
938    i = ctx->i;
939    spread = ctx->spread;
940    ec = ctx->ec;
941
942    /* If we need 1.5 more bit than we can produce, split the band in two. */
943    cache = m->cache.bits + m->cache.index[(LM+1)*m->nbEBands+i];
944    if (LM != -1 && b > cache[cache[0]]+12 && N>2)
945    {
946       int mbits, sbits, delta;
947       int itheta;
948       int qalloc;
949       struct split_ctx sctx;
950       celt_norm *next_lowband2=NULL;
951       opus_int32 rebalance;
952
953       N >>= 1;
954       Y = X+N;
955       LM -= 1;
956       if (B==1)
957          fill = (fill&1)|(fill<<1);
958       B = (B+1)>>1;
959
960       compute_theta(ctx, &sctx, X, Y, N, &b, B, B0, LM, 0, &fill);
961       imid = sctx.imid;
962       iside = sctx.iside;
963       delta = sctx.delta;
964       itheta = sctx.itheta;
965       qalloc = sctx.qalloc;
966 #ifdef FIXED_POINT
967       mid = imid;
968       side = iside;
969 #else
970       mid = (1.f/32768)*imid;
971       side = (1.f/32768)*iside;
972 #endif
973
974       /* Give more bits to low-energy MDCTs than they would otherwise deserve */
975       if (B0>1 && (itheta&0x3fff))
976       {
977          if (itheta > 8192)
978             /* Rough approximation for pre-echo masking */
979             delta -= delta>>(4-LM);
980          else
981             /* Corresponds to a forward-masking slope of 1.5 dB per 10 ms */
982             delta = IMIN(0, delta + (N<<BITRES>>(5-LM)));
983       }
984       mbits = IMAX(0, IMIN(b, (b-delta)/2));
985       sbits = b-mbits;
986       ctx->remaining_bits -= qalloc;
987
988       if (lowband)
989          next_lowband2 = lowband+N; /* >32-bit split case */
990
991       rebalance = ctx->remaining_bits;
992       if (mbits >= sbits)
993       {
994          cm = quant_partition(ctx, X, N, mbits, B, lowband, LM,
995                MULT16_16_P15(gain,mid), fill);
996          rebalance = mbits - (rebalance-ctx->remaining_bits);
997          if (rebalance > 3<<BITRES && itheta!=0)
998             sbits += rebalance - (3<<BITRES);
999          cm |= quant_partition(ctx, Y, N, sbits, B, next_lowband2, LM,
1000                MULT16_16_P15(gain,side), fill>>B)<<(B0>>1);
1001       } else {
1002          cm = quant_partition(ctx, Y, N, sbits, B, next_lowband2, LM,
1003                MULT16_16_P15(gain,side), fill>>B)<<(B0>>1);
1004          rebalance = sbits - (rebalance-ctx->remaining_bits);
1005          if (rebalance > 3<<BITRES && itheta!=16384)
1006             mbits += rebalance - (3<<BITRES);
1007          cm |= quant_partition(ctx, X, N, mbits, B, lowband, LM,
1008                MULT16_16_P15(gain,mid), fill);
1009       }
1010    } else {
1011       /* This is the basic no-split case */
1012       q = bits2pulses(m, i, LM, b);
1013       curr_bits = pulses2bits(m, i, LM, q);
1014       ctx->remaining_bits -= curr_bits;
1015
1016       /* Ensures we can never bust the budget */
1017       while (ctx->remaining_bits < 0 && q > 0)
1018       {
1019          ctx->remaining_bits += curr_bits;
1020          q--;
1021          curr_bits = pulses2bits(m, i, LM, q);
1022          ctx->remaining_bits -= curr_bits;
1023       }
1024
1025       if (q!=0)
1026       {
1027          int K = get_pulses(q);
1028
1029          /* Finally do the actual quantization */
1030          if (encode)
1031          {
1032             cm = alg_quant(X, N, K, spread, B, ec, gain, ctx->resynth);
1033          } else {
1034             cm = alg_unquant(X, N, K, spread, B, ec, gain);
1035          }
1036       } else {
1037          /* If there's no pulse, fill the band anyway */
1038          int j;
1039          if (ctx->resynth)
1040          {
1041             unsigned cm_mask;
1042             /* B can be as large as 16, so this shift might overflow an int on a
1043                16-bit platform; use a long to get defined behavior.*/
1044             cm_mask = (unsigned)(1UL<<B)-1;
1045             fill &= cm_mask;
1046             if (!fill)
1047             {
1048                OPUS_CLEAR(X, N);
1049             } else {
1050                if (lowband == NULL)
1051                {
1052                   /* Noise */
1053                   for (j=0;j<N;j++)
1054                   {
1055                      ctx->seed = celt_lcg_rand(ctx->seed);
1056                      X[j] = (celt_norm)((opus_int32)ctx->seed>>20);
1057                   }
1058                   cm = cm_mask;
1059                } else {
1060                   /* Folded spectrum */
1061                   for (j=0;j<N;j++)
1062                   {
1063                      opus_val16 tmp;
1064                      ctx->seed = celt_lcg_rand(ctx->seed);
1065                      /* About 48 dB below the "normal" folding level */
1066                      tmp = QCONST16(1.0f/256, 10);
1067                      tmp = (ctx->seed)&0x8000 ? tmp : -tmp;
1068                      X[j] = lowband[j]+tmp;
1069                   }
1070                   cm = fill;
1071                }
1072                renormalise_vector(X, N, gain, ctx->arch);
1073             }
1074          }
1075       }
1076    }
1077
1078    return cm;
1079 }
1080
1081
1082 /* This function is responsible for encoding and decoding a band for the mono case. */
1083 static unsigned quant_band(struct band_ctx *ctx, celt_norm *X,
1084       int N, int b, int B, celt_norm *lowband,
1085       int LM, celt_norm *lowband_out,
1086       opus_val16 gain, celt_norm *lowband_scratch, int fill)
1087 {
1088    int N0=N;
1089    int N_B=N;
1090    int N_B0;
1091    int B0=B;
1092    int time_divide=0;
1093    int recombine=0;
1094    int longBlocks;
1095    unsigned cm=0;
1096    int k;
1097    int encode;
1098    int tf_change;
1099
1100    encode = ctx->encode;
1101    tf_change = ctx->tf_change;
1102
1103    longBlocks = B0==1;
1104
1105    N_B = celt_udiv(N_B, B);
1106
1107    /* Special case for one sample */
1108    if (N==1)
1109    {
1110       return quant_band_n1(ctx, X, NULL, b, lowband_out);
1111    }
1112
1113    if (tf_change>0)
1114       recombine = tf_change;
1115    /* Band recombining to increase frequency resolution */
1116
1117    if (lowband_scratch && lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
1118    {
1119       OPUS_COPY(lowband_scratch, lowband, N);
1120       lowband = lowband_scratch;
1121    }
1122
1123    for (k=0;k<recombine;k++)
1124    {
1125       static const unsigned char bit_interleave_table[16]={
1126             0,1,1,1,2,3,3,3,2,3,3,3,2,3,3,3
1127       };
1128       if (encode)
1129          haar1(X, N>>k, 1<<k);
1130       if (lowband)
1131          haar1(lowband, N>>k, 1<<k);
1132       fill = bit_interleave_table[fill&0xF]|bit_interleave_table[fill>>4]<<2;
1133    }
1134    B>>=recombine;
1135    N_B<<=recombine;
1136
1137    /* Increasing the time resolution */
1138    while ((N_B&1) == 0 && tf_change<0)
1139    {
1140       if (encode)
1141          haar1(X, N_B, B);
1142       if (lowband)
1143          haar1(lowband, N_B, B);
1144       fill |= fill<<B;
1145       B <<= 1;
1146       N_B >>= 1;
1147       time_divide++;
1148       tf_change++;
1149    }
1150    B0=B;
1151    N_B0 = N_B;
1152
1153    /* Reorganize the samples in time order instead of frequency order */
1154    if (B0>1)
1155    {
1156       if (encode)
1157          deinterleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
1158       if (lowband)
1159          deinterleave_hadamard(lowband, N_B>>recombine, B0<<recombine, longBlocks);
1160    }
1161
1162    cm = quant_partition(ctx, X, N, b, B, lowband, LM, gain, fill);
1163
1164    /* This code is used by the decoder and by the resynthesis-enabled encoder */
1165    if (ctx->resynth)
1166    {
1167       /* Undo the sample reorganization going from time order to frequency order */
1168       if (B0>1)
1169          interleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
1170
1171       /* Undo time-freq changes that we did earlier */
1172       N_B = N_B0;
1173       B = B0;
1174       for (k=0;k<time_divide;k++)
1175       {
1176          B >>= 1;
1177          N_B <<= 1;
1178          cm |= cm>>B;
1179          haar1(X, N_B, B);
1180       }
1181
1182       for (k=0;k<recombine;k++)
1183       {
1184          static const unsigned char bit_deinterleave_table[16]={
1185                0x00,0x03,0x0C,0x0F,0x30,0x33,0x3C,0x3F,
1186                0xC0,0xC3,0xCC,0xCF,0xF0,0xF3,0xFC,0xFF
1187          };
1188          cm = bit_deinterleave_table[cm];
1189          haar1(X, N0>>k, 1<<k);
1190       }
1191       B<<=recombine;
1192
1193       /* Scale output for later folding */
1194       if (lowband_out)
1195       {
1196          int j;
1197          opus_val16 n;
1198          n = celt_sqrt(SHL32(EXTEND32(N0),22));
1199          for (j=0;j<N0;j++)
1200             lowband_out[j] = MULT16_16_Q15(n,X[j]);
1201       }
1202       cm &= (1<<B)-1;
1203    }
1204    return cm;
1205 }
1206
1207
1208 /* This function is responsible for encoding and decoding a band for the stereo case. */
1209 static unsigned quant_band_stereo(struct band_ctx *ctx, celt_norm *X, celt_norm *Y,
1210       int N, int b, int B, celt_norm *lowband,
1211       int LM, celt_norm *lowband_out,
1212       celt_norm *lowband_scratch, int fill)
1213 {
1214    int imid=0, iside=0;
1215    int inv = 0;
1216    opus_val16 mid=0, side=0;
1217    unsigned cm=0;
1218    int mbits, sbits, delta;
1219    int itheta;
1220    int qalloc;
1221    struct split_ctx sctx;
1222    int orig_fill;
1223    int encode;
1224    ec_ctx *ec;
1225
1226    encode = ctx->encode;
1227    ec = ctx->ec;
1228
1229    /* Special case for one sample */
1230    if (N==1)
1231    {
1232       return quant_band_n1(ctx, X, Y, b, lowband_out);
1233    }
1234
1235    orig_fill = fill;
1236
1237    compute_theta(ctx, &sctx, X, Y, N, &b, B, B, LM, 1, &fill);
1238    inv = sctx.inv;
1239    imid = sctx.imid;
1240    iside = sctx.iside;
1241    delta = sctx.delta;
1242    itheta = sctx.itheta;
1243    qalloc = sctx.qalloc;
1244 #ifdef FIXED_POINT
1245    mid = imid;
1246    side = iside;
1247 #else
1248    mid = (1.f/32768)*imid;
1249    side = (1.f/32768)*iside;
1250 #endif
1251
1252    /* This is a special case for N=2 that only works for stereo and takes
1253       advantage of the fact that mid and side are orthogonal to encode
1254       the side with just one bit. */
1255    if (N==2)
1256    {
1257       int c;
1258       int sign=0;
1259       celt_norm *x2, *y2;
1260       mbits = b;
1261       sbits = 0;
1262       /* Only need one bit for the side. */
1263       if (itheta != 0 && itheta != 16384)
1264          sbits = 1<<BITRES;
1265       mbits -= sbits;
1266       c = itheta > 8192;
1267       ctx->remaining_bits -= qalloc+sbits;
1268
1269       x2 = c ? Y : X;
1270       y2 = c ? X : Y;
1271       if (sbits)
1272       {
1273          if (encode)
1274          {
1275             /* Here we only need to encode a sign for the side. */
1276             sign = x2[0]*y2[1] - x2[1]*y2[0] < 0;
1277             ec_enc_bits(ec, sign, 1);
1278          } else {
1279             sign = ec_dec_bits(ec, 1);
1280          }
1281       }
1282       sign = 1-2*sign;
1283       /* We use orig_fill here because we want to fold the side, but if
1284          itheta==16384, we'll have cleared the low bits of fill. */
1285       cm = quant_band(ctx, x2, N, mbits, B, lowband, LM, lowband_out, Q15ONE,
1286             lowband_scratch, orig_fill);
1287       /* We don't split N=2 bands, so cm is either 1 or 0 (for a fold-collapse),
1288          and there's no need to worry about mixing with the other channel. */
1289       y2[0] = -sign*x2[1];
1290       y2[1] = sign*x2[0];
1291       if (ctx->resynth)
1292       {
1293          celt_norm tmp;
1294          X[0] = MULT16_16_Q15(mid, X[0]);
1295          X[1] = MULT16_16_Q15(mid, X[1]);
1296          Y[0] = MULT16_16_Q15(side, Y[0]);
1297          Y[1] = MULT16_16_Q15(side, Y[1]);
1298          tmp = X[0];
1299          X[0] = SUB16(tmp,Y[0]);
1300          Y[0] = ADD16(tmp,Y[0]);
1301          tmp = X[1];
1302          X[1] = SUB16(tmp,Y[1]);
1303          Y[1] = ADD16(tmp,Y[1]);
1304       }
1305    } else {
1306       /* "Normal" split code */
1307       opus_int32 rebalance;
1308
1309       mbits = IMAX(0, IMIN(b, (b-delta)/2));
1310       sbits = b-mbits;
1311       ctx->remaining_bits -= qalloc;
1312
1313       rebalance = ctx->remaining_bits;
1314       if (mbits >= sbits)
1315       {
1316          /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
1317             mid for folding later. */
1318          cm = quant_band(ctx, X, N, mbits, B, lowband, LM, lowband_out, Q15ONE,
1319                lowband_scratch, fill);
1320          rebalance = mbits - (rebalance-ctx->remaining_bits);
1321          if (rebalance > 3<<BITRES && itheta!=0)
1322             sbits += rebalance - (3<<BITRES);
1323
1324          /* For a stereo split, the high bits of fill are always zero, so no
1325             folding will be done to the side. */
1326          cm |= quant_band(ctx, Y, N, sbits, B, NULL, LM, NULL, side, NULL, fill>>B);
1327       } else {
1328          /* For a stereo split, the high bits of fill are always zero, so no
1329             folding will be done to the side. */
1330          cm = quant_band(ctx, Y, N, sbits, B, NULL, LM, NULL, side, NULL, fill>>B);
1331          rebalance = sbits - (rebalance-ctx->remaining_bits);
1332          if (rebalance > 3<<BITRES && itheta!=16384)
1333             mbits += rebalance - (3<<BITRES);
1334          /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
1335             mid for folding later. */
1336          cm |= quant_band(ctx, X, N, mbits, B, lowband, LM, lowband_out, Q15ONE,
1337                lowband_scratch, fill);
1338       }
1339    }
1340
1341
1342    /* This code is used by the decoder and by the resynthesis-enabled encoder */
1343    if (ctx->resynth)
1344    {
1345       if (N!=2)
1346          stereo_merge(X, Y, mid, N, ctx->arch);
1347       if (inv)
1348       {
1349          int j;
1350          for (j=0;j<N;j++)
1351             Y[j] = -Y[j];
1352       }
1353    }
1354    return cm;
1355 }
1356
1357
1358 void quant_all_bands(int encode, const CELTMode *m, int start, int end,
1359       celt_norm *X_, celt_norm *Y_, unsigned char *collapse_masks,
1360       const celt_ener *bandE, int *pulses, int shortBlocks, int spread,
1361       int dual_stereo, int intensity, int *tf_res, opus_int32 total_bits,
1362       opus_int32 balance, ec_ctx *ec, int LM, int codedBands,
1363       opus_uint32 *seed, int complexity, int arch)
1364 {
1365    int i;
1366    opus_int32 remaining_bits;
1367    const opus_int16 * OPUS_RESTRICT eBands = m->eBands;
1368    celt_norm * OPUS_RESTRICT norm, * OPUS_RESTRICT norm2;
1369    VARDECL(celt_norm, _norm);
1370    VARDECL(celt_norm, _lowband_scratch);
1371    VARDECL(celt_norm, X_save);
1372    VARDECL(celt_norm, Y_save);
1373    VARDECL(celt_norm, X_save2);
1374    VARDECL(celt_norm, Y_save2);
1375    VARDECL(celt_norm, norm_save2);
1376    int resynth_alloc;
1377    celt_norm *lowband_scratch;
1378    int B;
1379    int M;
1380    int lowband_offset;
1381    int update_lowband = 1;
1382    int C = Y_ != NULL ? 2 : 1;
1383    int norm_offset;
1384    int theta_rdo = encode && Y_!=NULL && !dual_stereo && complexity>=8;
1385 #ifdef RESYNTH
1386    int resynth = 1;
1387 #else
1388    int resynth = !encode || theta_rdo;
1389 #endif
1390    struct band_ctx ctx;
1391    SAVE_STACK;
1392
1393    M = 1<<LM;
1394    B = shortBlocks ? M : 1;
1395    norm_offset = M*eBands[start];
1396    /* No need to allocate norm for the last band because we don't need an
1397       output in that band. */
1398    ALLOC(_norm, C*(M*eBands[m->nbEBands-1]-norm_offset), celt_norm);
1399    norm = _norm;
1400    norm2 = norm + M*eBands[m->nbEBands-1]-norm_offset;
1401
1402    /* For decoding, we can use the last band as scratch space because we don't need that
1403       scratch space for the last band and we don't care about the data there until we're
1404       decoding the last band. */
1405    if (encode && resynth)
1406       resynth_alloc = M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]);
1407    else
1408       resynth_alloc = ALLOC_NONE;
1409    ALLOC(_lowband_scratch, resynth_alloc, celt_norm);
1410    if (encode && resynth)
1411       lowband_scratch = _lowband_scratch;
1412    else
1413       lowband_scratch = X_+M*eBands[m->nbEBands-1];
1414    ALLOC(X_save, resynth_alloc, celt_norm);
1415    ALLOC(Y_save, resynth_alloc, celt_norm);
1416    ALLOC(X_save2, resynth_alloc, celt_norm);
1417    ALLOC(Y_save2, resynth_alloc, celt_norm);
1418    ALLOC(norm_save2, resynth_alloc, celt_norm);
1419
1420    lowband_offset = 0;
1421    ctx.bandE = bandE;
1422    ctx.ec = ec;
1423    ctx.encode = encode;
1424    ctx.intensity = intensity;
1425    ctx.m = m;
1426    ctx.seed = *seed;
1427    ctx.spread = spread;
1428    ctx.arch = arch;
1429    ctx.resynth = resynth;
1430    ctx.theta_round = 0;
1431    for (i=start;i<end;i++)
1432    {
1433       opus_int32 tell;
1434       int b;
1435       int N;
1436       opus_int32 curr_balance;
1437       int effective_lowband=-1;
1438       celt_norm * OPUS_RESTRICT X, * OPUS_RESTRICT Y;
1439       int tf_change=0;
1440       unsigned x_cm;
1441       unsigned y_cm;
1442       int last;
1443
1444       ctx.i = i;
1445       last = (i==end-1);
1446
1447       X = X_+M*eBands[i];
1448       if (Y_!=NULL)
1449          Y = Y_+M*eBands[i];
1450       else
1451          Y = NULL;
1452       N = M*eBands[i+1]-M*eBands[i];
1453       tell = ec_tell_frac(ec);
1454
1455       /* Compute how many bits we want to allocate to this band */
1456       if (i != start)
1457          balance -= tell;
1458       remaining_bits = total_bits-tell-1;
1459       ctx.remaining_bits = remaining_bits;
1460       if (i <= codedBands-1)
1461       {
1462          curr_balance = celt_sudiv(balance, IMIN(3, codedBands-i));
1463          b = IMAX(0, IMIN(16383, IMIN(remaining_bits+1,pulses[i]+curr_balance)));
1464       } else {
1465          b = 0;
1466       }
1467
1468       if (resynth && M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband_offset==0))
1469             lowband_offset = i;
1470
1471       tf_change = tf_res[i];
1472       ctx.tf_change = tf_change;
1473       if (i>=m->effEBands)
1474       {
1475          X=norm;
1476          if (Y_!=NULL)
1477             Y = norm;
1478          lowband_scratch = NULL;
1479       }
1480       if (last && !theta_rdo)
1481          lowband_scratch = NULL;
1482
1483       /* Get a conservative estimate of the collapse_mask's for the bands we're
1484          going to be folding from. */
1485       if (lowband_offset != 0 && (spread!=SPREAD_AGGRESSIVE || B>1 || tf_change<0))
1486       {
1487          int fold_start;
1488          int fold_end;
1489          int fold_i;
1490          /* This ensures we never repeat spectral content within one band */
1491          effective_lowband = IMAX(0, M*eBands[lowband_offset]-norm_offset-N);
1492          fold_start = lowband_offset;
1493          while(M*eBands[--fold_start] > effective_lowband+norm_offset);
1494          fold_end = lowband_offset-1;
1495          while(M*eBands[++fold_end] < effective_lowband+norm_offset+N);
1496          x_cm = y_cm = 0;
1497          fold_i = fold_start; do {
1498            x_cm |= collapse_masks[fold_i*C+0];
1499            y_cm |= collapse_masks[fold_i*C+C-1];
1500          } while (++fold_i<fold_end);
1501       }
1502       /* Otherwise, we'll be using the LCG to fold, so all blocks will (almost
1503          always) be non-zero. */
1504       else
1505          x_cm = y_cm = (1<<B)-1;
1506
1507       if (dual_stereo && i==intensity)
1508       {
1509          int j;
1510
1511          /* Switch off dual stereo to do intensity. */
1512          dual_stereo = 0;
1513          if (resynth)
1514             for (j=0;j<M*eBands[i]-norm_offset;j++)
1515                norm[j] = HALF32(norm[j]+norm2[j]);
1516       }
1517       if (dual_stereo)
1518       {
1519          x_cm = quant_band(&ctx, X, N, b/2, B,
1520                effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1521                last?NULL:norm+M*eBands[i]-norm_offset, Q15ONE, lowband_scratch, x_cm);
1522          y_cm = quant_band(&ctx, Y, N, b/2, B,
1523                effective_lowband != -1 ? norm2+effective_lowband : NULL, LM,
1524                last?NULL:norm2+M*eBands[i]-norm_offset, Q15ONE, lowband_scratch, y_cm);
1525       } else {
1526          if (Y!=NULL)
1527          {
1528             if (theta_rdo && i < intensity)
1529             {
1530                ec_ctx ec_save, ec_save2;
1531                struct band_ctx ctx_save, ctx_save2;
1532                opus_val32 dist0, dist1;
1533                unsigned cm, cm2;
1534                int nstart_bytes, nend_bytes, save_bytes;
1535                unsigned char *bytes_buf;
1536                unsigned char bytes_save[1275];
1537                opus_val16 w[2];
1538                compute_channel_weights(bandE[i], bandE[i+m->nbEBands], w);
1539                /* Make a copy. */
1540                cm = x_cm|y_cm;
1541                ec_save = *ec;
1542                ctx_save = ctx;
1543                OPUS_COPY(X_save, X, N);
1544                OPUS_COPY(Y_save, Y, N);
1545                /* Encode and round down. */
1546                ctx.theta_round = -1;
1547                x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
1548                      effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1549                      last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
1550                dist0 = MULT16_32_Q15(w[0], celt_inner_prod(X_save, X, N, arch)) + MULT16_32_Q15(w[1], celt_inner_prod(Y_save, Y, N, arch));
1551
1552                /* Save first result. */
1553                cm2 = x_cm;
1554                ec_save2 = *ec;
1555                ctx_save2 = ctx;
1556                OPUS_COPY(X_save2, X, N);
1557                OPUS_COPY(Y_save2, Y, N);
1558                if (!last)
1559                   OPUS_COPY(norm_save2, norm+M*eBands[i]-norm_offset, N);
1560                nstart_bytes = ec_save.offs;
1561                nend_bytes = ec_save.storage;
1562                bytes_buf = ec_save.buf+nstart_bytes;
1563                save_bytes = nend_bytes-nstart_bytes;
1564                OPUS_COPY(bytes_save, bytes_buf, save_bytes);
1565
1566                /* Restore */
1567                *ec = ec_save;
1568                ctx = ctx_save;
1569                OPUS_COPY(X, X_save, N);
1570                OPUS_COPY(Y, Y_save, N);
1571                /* Encode and round up. */
1572                ctx.theta_round = 1;
1573                x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
1574                      effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1575                      last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
1576                dist1 = MULT16_32_Q15(w[0], celt_inner_prod(X_save, X, N, arch)) + MULT16_32_Q15(w[1], celt_inner_prod(Y_save, Y, N, arch));
1577                if (dist0 >= dist1) {
1578                   x_cm = cm2;
1579                   *ec = ec_save2;
1580                   ctx = ctx_save2;
1581                   OPUS_COPY(X, X_save2, N);
1582                   OPUS_COPY(Y, Y_save2, N);
1583                   if (!last)
1584                      OPUS_COPY(norm+M*eBands[i]-norm_offset, norm_save2, N);
1585                   OPUS_COPY(bytes_buf, bytes_save, save_bytes);
1586                }
1587             } else {
1588                ctx.theta_round = 0;
1589                x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
1590                      effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1591                      last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, x_cm|y_cm);
1592             }
1593          } else {
1594             x_cm = quant_band(&ctx, X, N, b, B,
1595                   effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
1596                   last?NULL:norm+M*eBands[i]-norm_offset, Q15ONE, lowband_scratch, x_cm|y_cm);
1597          }
1598          y_cm = x_cm;
1599       }
1600       collapse_masks[i*C+0] = (unsigned char)x_cm;
1601       collapse_masks[i*C+C-1] = (unsigned char)y_cm;
1602       balance += pulses[i] + tell;
1603
1604       /* Update the folding position only as long as we have 1 bit/sample depth. */
1605       update_lowband = b>(N<<BITRES);
1606    }
1607    *seed = ctx.seed;
1608
1609    RESTORE_STACK;
1610 }
1611