31aee8467bb11170ac91acc8227822b0f330463f
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 celt_uint32 lcg_rand(celt_uint32 seed)
49 {
50    return 1664525 * seed + 1013904223;
51 }
52
53 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
54    with this approximation is important because it has an impact on the bit allocation */
55 static celt_int16 bitexact_cos(celt_int16 x)
56 {
57    celt_int32 tmp;
58    celt_int16 x2;
59    tmp = (4096+((celt_int32)(x)*(x)))>>13;
60    if (tmp > 32767)
61       tmp = 32767;
62    x2 = tmp;
63    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
64    if (x2 > 32766)
65       x2 = 32766;
66    return 1+x2;
67 }
68
69 static int bitexact_log2tan(int isin,int icos)
70 {
71    int lc;
72    int ls;
73    lc=EC_ILOG(icos);
74    ls=EC_ILOG(isin);
75    icos<<=15-lc;
76    isin<<=15-ls;
77    return (ls-lc<<11)
78          +FRAC_MUL16(isin, FRAC_MUL16(isin, -2597) + 7932)
79          -FRAC_MUL16(icos, FRAC_MUL16(icos, -2597) + 7932);
80 }
81
82 #ifdef FIXED_POINT
83 /* Compute the amplitude (sqrt energy) in each of the bands */
84 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
85 {
86    int i, c, N;
87    const celt_int16 *eBands = m->eBands;
88    const int C = CHANNELS(_C);
89    N = M*m->shortMdctSize;
90    c=0; do {
91       for (i=0;i<end;i++)
92       {
93          int j;
94          celt_word32 maxval=0;
95          celt_word32 sum = 0;
96          
97          j=M*eBands[i]; do {
98             maxval = MAX32(maxval, X[j+c*N]);
99             maxval = MAX32(maxval, -X[j+c*N]);
100          } while (++j<M*eBands[i+1]);
101          
102          if (maxval > 0)
103          {
104             int shift = celt_ilog2(maxval)-10;
105             j=M*eBands[i]; do {
106                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
107                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
108             } while (++j<M*eBands[i+1]);
109             /* We're adding one here to make damn sure we never end up with a pitch vector that's
110                larger than unity norm */
111             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
112          } else {
113             bank[i+c*m->nbEBands] = EPSILON;
114          }
115          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
116       }
117    } while (++c<C);
118    /*printf ("\n");*/
119 }
120
121 /* Normalise each band such that the energy is one. */
122 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
123 {
124    int i, c, N;
125    const celt_int16 *eBands = m->eBands;
126    const int C = CHANNELS(_C);
127    N = M*m->shortMdctSize;
128    c=0; do {
129       i=0; do {
130          celt_word16 g;
131          int j,shift;
132          celt_word16 E;
133          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
134          E = VSHR32(bank[i+c*m->nbEBands], shift);
135          g = EXTRACT16(celt_rcp(SHL32(E,3)));
136          j=M*eBands[i]; do {
137             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
138          } while (++j<M*eBands[i+1]);
139       } while (++i<end);
140    } while (++c<C);
141 }
142
143 #else /* FIXED_POINT */
144 /* Compute the amplitude (sqrt energy) in each of the bands */
145 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
146 {
147    int i, c, N;
148    const celt_int16 *eBands = m->eBands;
149    const int C = CHANNELS(_C);
150    N = M*m->shortMdctSize;
151    c=0; do {
152       for (i=0;i<end;i++)
153       {
154          int j;
155          celt_word32 sum = 1e-27f;
156          for (j=M*eBands[i];j<M*eBands[i+1];j++)
157             sum += X[j+c*N]*X[j+c*N];
158          bank[i+c*m->nbEBands] = celt_sqrt(sum);
159          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
160       }
161    } while (++c<C);
162    /*printf ("\n");*/
163 }
164
165 /* Normalise each band such that the energy is one. */
166 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
167 {
168    int i, c, N;
169    const celt_int16 *eBands = m->eBands;
170    const int C = CHANNELS(_C);
171    N = M*m->shortMdctSize;
172    c=0; do {
173       for (i=0;i<end;i++)
174       {
175          int j;
176          celt_word16 g = 1.f/(1e-27f+bank[i+c*m->nbEBands]);
177          for (j=M*eBands[i];j<M*eBands[i+1];j++)
178             X[j+c*N] = freq[j+c*N]*g;
179       }
180    } while (++c<C);
181 }
182
183 #endif /* FIXED_POINT */
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    c=0; do {
194       celt_sig * restrict f;
195       const celt_norm * restrict x;
196       f = freq+c*N;
197       x = X+c*N;
198       for (i=0;i<end;i++)
199       {
200          int j, band_end;
201          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
202          j=M*eBands[i];
203          band_end = M*eBands[i+1];
204          do {
205             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
206             x++;
207          } while (++j<band_end);
208       }
209       for (i=M*eBands[m->nbEBands];i<N;i++)
210          *f++ = 0;
211    } while (++c<C);
212 }
213
214 /* This prevents energy collapse for transients with multiple short MDCTs */
215 void anti_collapse(const CELTMode *m, celt_norm *_X, unsigned char *collapse_masks, int LM, int C, int size,
216       int start, int end, celt_word16 *logE, celt_word16 *prev1logE,
217       celt_word16 *prev2logE, int *pulses, celt_uint32 seed)
218 {
219    int c, i, j, k;
220    for (i=start;i<end;i++)
221    {
222       int N0;
223       celt_word16 thresh, sqrt_1;
224       int depth;
225 #ifdef FIXED_POINT
226       int shift;
227 #endif
228
229       N0 = m->eBands[i+1]-m->eBands[i];
230       /* depth in 1/8 bits */
231       depth = (1+pulses[i])/(m->eBands[i+1]-m->eBands[i]<<LM);
232
233 #ifdef FIXED_POINT
234       thresh = MULT16_32_Q15(QCONST16(0.3f, 15), MIN32(32767,SHR32(celt_exp2(-SHL16(depth, 11-BITRES)),1) ));
235       {
236          celt_word32 t;
237          t = N0<<LM;
238          shift = celt_ilog2(t)>>1;
239          t = SHL32(t, (7-shift)<<1);
240          sqrt_1 = celt_rsqrt_norm(t);
241       }
242 #else
243       thresh = .3f*celt_exp2(-.125f*depth);
244       sqrt_1 = celt_rsqrt(N0<<LM);
245 #endif
246
247       c=0; do
248       {
249          celt_norm *X;
250          celt_word16 Ediff;
251          celt_word16 r;
252          Ediff = logE[c*m->nbEBands+i]-MIN16(prev1logE[c*m->nbEBands+i],prev2logE[c*m->nbEBands+i]);
253          Ediff = MAX16(0, Ediff);
254
255 #ifdef FIXED_POINT
256          if (Ediff < 16384)
257             r = 2*MIN16(16383,SHR32(celt_exp2(-SHL16(Ediff, 11-DB_SHIFT)),1));
258          else
259             r = 0;
260          if (LM==3)
261             r = MULT16_16_Q15(QCONST16(.70710678f,15), r);
262          r = SHR16(MIN16(thresh, r),1);
263          r = SHR32(MULT16_16_Q15(sqrt_1, r),shift);
264 #else
265          /* r needs to be multiplied by 2 or 2*sqrt(2) depending on LM because
266             short blocks don't have the same energy as long */
267          r = 2.f*celt_exp2(-Ediff);
268          if (LM==3)
269             r *= .70710678f;
270          r = MIN16(thresh, r);
271          r = r*sqrt_1;
272 #endif
273          X = _X+c*size+(m->eBands[i]<<LM);
274          for (k=0;k<1<<LM;k++)
275          {
276             /* Detect collapse */
277             if (!(collapse_masks[i*C+c]&1<<k))
278             {
279                /* Fill with noise */
280                for (j=0;j<N0;j++)
281                {
282                   seed = lcg_rand(seed);
283                   X[(j<<LM)+k] = (seed&0x8000 ? r : -r);
284                }
285             }
286          }
287          /* We just added some energy, so we need to renormalise */
288          renormalise_vector(X, N0<<LM, Q15ONE);
289       } while (++c<C);
290    }
291
292 }
293
294
295 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
296 {
297    int i = bandID;
298    int j;
299    celt_word16 a1, a2;
300    celt_word16 left, right;
301    celt_word16 norm;
302 #ifdef FIXED_POINT
303    int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
304 #endif
305    left = VSHR32(bank[i],shift);
306    right = VSHR32(bank[i+m->nbEBands],shift);
307    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
308    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
309    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
310    for (j=0;j<N;j++)
311    {
312       celt_norm r, l;
313       l = X[j];
314       r = Y[j];
315       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
316       /* Side is not encoded, no need to calculate */
317    }
318 }
319
320 static void stereo_split(celt_norm *X, celt_norm *Y, int N)
321 {
322    int j;
323    for (j=0;j<N;j++)
324    {
325       celt_norm r, l;
326       l = MULT16_16_Q15(QCONST16(.70710678f,15), X[j]);
327       r = MULT16_16_Q15(QCONST16(.70710678f,15), Y[j]);
328       X[j] = l+r;
329       Y[j] = r-l;
330    }
331 }
332
333 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, int N)
334 {
335    int j;
336    celt_word32 xp=0, side=0;
337    celt_word32 El, Er;
338    celt_word16 mid2;
339 #ifdef FIXED_POINT
340    int kl, kr;
341 #endif
342    celt_word32 t, lgain, rgain;
343
344    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
345    for (j=0;j<N;j++)
346    {
347       xp = MAC16_16(xp, X[j], Y[j]);
348       side = MAC16_16(side, Y[j], Y[j]);
349    }
350    /* Compensating for the mid normalization */
351    xp = MULT16_32_Q15(mid, xp);
352    /* mid and side are in Q15, not Q14 like X and Y */
353    mid2 = SHR32(mid, 1);
354    El = MULT16_16(mid2, mid2) + side - 2*xp;
355    Er = MULT16_16(mid2, mid2) + side + 2*xp;
356    if (Er < QCONST32(6e-4f, 28) || El < QCONST32(6e-4f, 28))
357    {
358       for (j=0;j<N;j++)
359          Y[j] = X[j];
360       return;
361    }
362
363 #ifdef FIXED_POINT
364    kl = celt_ilog2(El)>>1;
365    kr = celt_ilog2(Er)>>1;
366 #endif
367    t = VSHR32(El, (kl-7)<<1);
368    lgain = celt_rsqrt_norm(t);
369    t = VSHR32(Er, (kr-7)<<1);
370    rgain = celt_rsqrt_norm(t);
371
372 #ifdef FIXED_POINT
373    if (kl < 7)
374       kl = 7;
375    if (kr < 7)
376       kr = 7;
377 #endif
378
379    for (j=0;j<N;j++)
380    {
381       celt_norm r, l;
382       /* Apply mid scaling (side is already scaled) */
383       l = MULT16_16_Q15(mid, X[j]);
384       r = Y[j];
385       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, SUB16(l,r)), kl+1));
386       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, ADD16(l,r)), kr+1));
387    }
388 }
389
390 /* Decide whether we should spread the pulses in the current frame */
391 int spreading_decision(const CELTMode *m, celt_norm *X, int *average,
392       int last_decision, int *hf_average, int *tapset_decision, int update_hf,
393       int end, int _C, int M)
394 {
395    int i, c, N0;
396    int sum = 0, nbBands=0;
397    const int C = CHANNELS(_C);
398    const celt_int16 * restrict eBands = m->eBands;
399    int decision;
400    int hf_sum=0;
401    
402    N0 = M*m->shortMdctSize;
403
404    if (M*(eBands[end]-eBands[end-1]) <= 8)
405       return SPREAD_NONE;
406    c=0; do {
407       for (i=0;i<end;i++)
408       {
409          int j, N, tmp=0;
410          int tcount[3] = {0};
411          celt_norm * restrict x = X+M*eBands[i]+c*N0;
412          N = M*(eBands[i+1]-eBands[i]);
413          if (N<=8)
414             continue;
415          /* Compute rough CDF of |x[j]| */
416          for (j=0;j<N;j++)
417          {
418             celt_word32 x2N; /* Q13 */
419
420             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
421             if (x2N < QCONST16(0.25f,13))
422                tcount[0]++;
423             if (x2N < QCONST16(0.0625f,13))
424                tcount[1]++;
425             if (x2N < QCONST16(0.015625f,13))
426                tcount[2]++;
427          }
428
429          /* Only include four last bands (8 kHz and up) */
430          if (i>m->nbEBands-4)
431             hf_sum += 32*(tcount[1]+tcount[0])/N;
432          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
433          sum += tmp*256;
434          nbBands++;
435       }
436    } while (++c<C);
437
438    if (update_hf)
439    {
440       if (hf_sum)
441          hf_sum /= C*(4-m->nbEBands+end);
442       *hf_average = (*hf_average+hf_sum)>>1;
443       hf_sum = *hf_average;
444       if (*tapset_decision==2)
445          hf_sum += 4;
446       else if (*tapset_decision==0)
447          hf_sum -= 4;
448       if (hf_sum > 22)
449          *tapset_decision=2;
450       else if (hf_sum > 18)
451          *tapset_decision=1;
452       else
453          *tapset_decision=0;
454    }
455    /*printf("%d %d %d\n", hf_sum, *hf_average, *tapset_decision);*/
456    sum /= nbBands;
457    /* Recursive averaging */
458    sum = (sum+*average)>>1;
459    *average = sum;
460    /* Hysteresis */
461    sum = (3*sum + (((3-last_decision)<<7) + 64) + 2)>>2;
462    if (sum < 80)
463    {
464       decision = SPREAD_AGGRESSIVE;
465    } else if (sum < 256)
466    {
467       decision = SPREAD_NORMAL;
468    } else if (sum < 384)
469    {
470       decision = SPREAD_LIGHT;
471    } else {
472       decision = SPREAD_NONE;
473    }
474    return decision;
475 }
476
477 #ifdef MEASURE_NORM_MSE
478
479 float MSE[30] = {0};
480 int nbMSEBands = 0;
481 int MSECount[30] = {0};
482
483 void dump_norm_mse(void)
484 {
485    int i;
486    for (i=0;i<nbMSEBands;i++)
487    {
488       printf ("%g ", MSE[i]/MSECount[i]);
489    }
490    printf ("\n");
491 }
492
493 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
494 {
495    static int init = 0;
496    int i;
497    if (!init)
498    {
499       atexit(dump_norm_mse);
500       init = 1;
501    }
502    for (i=0;i<m->nbEBands;i++)
503    {
504       int j;
505       int c;
506       float g;
507       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
508          continue;
509       c=0; do {
510          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
511          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
512             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
513       } while (++c<C);
514       MSECount[i]+=C;
515    }
516    nbMSEBands = m->nbEBands;
517 }
518
519 #endif
520
521 /* Indexing table for converting from natural Hadamard to ordery Hadamard
522    This is essentially a bit-reversed Gray, on top of which we've added
523    an inversion of the order because we want the DC at the end rather than
524    the beginning. The lines are for N=2, 4, 8, 16 */
525 static const int ordery_table[] = {
526        1,  0,
527        3,  0,  2,  1,
528        7,  0,  4,  3,  6,  1,  5,  2,
529       15,  0,  8,  7, 12,  3, 11,  4, 14,  1,  9,  6, 13,  2, 10,  5,
530 };
531
532 static void deinterleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
533 {
534    int i,j;
535    VARDECL(celt_norm, tmp);
536    int N;
537    SAVE_STACK;
538    N = N0*stride;
539    ALLOC(tmp, N, celt_norm);
540    if (hadamard)
541    {
542       const int *ordery = ordery_table+stride-2;
543       for (i=0;i<stride;i++)
544       {
545          for (j=0;j<N0;j++)
546             tmp[ordery[i]*N0+j] = X[j*stride+i];
547       }
548    } else {
549       for (i=0;i<stride;i++)
550          for (j=0;j<N0;j++)
551             tmp[i*N0+j] = X[j*stride+i];
552    }
553    for (j=0;j<N;j++)
554       X[j] = tmp[j];
555    RESTORE_STACK;
556 }
557
558 static void interleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
559 {
560    int i,j;
561    VARDECL(celt_norm, tmp);
562    int N;
563    SAVE_STACK;
564    N = N0*stride;
565    ALLOC(tmp, N, celt_norm);
566    if (hadamard)
567    {
568       const int *ordery = ordery_table+stride-2;
569       for (i=0;i<stride;i++)
570          for (j=0;j<N0;j++)
571             tmp[j*stride+i] = X[ordery[i]*N0+j];
572    } else {
573       for (i=0;i<stride;i++)
574          for (j=0;j<N0;j++)
575             tmp[j*stride+i] = X[i*N0+j];
576    }
577    for (j=0;j<N;j++)
578       X[j] = tmp[j];
579    RESTORE_STACK;
580 }
581
582 void haar1(celt_norm *X, int N0, int stride)
583 {
584    int i, j;
585    N0 >>= 1;
586    for (i=0;i<stride;i++)
587       for (j=0;j<N0;j++)
588       {
589          celt_norm tmp1, tmp2;
590          tmp1 = MULT16_16_Q15(QCONST16(.70710678f,15), X[stride*2*j+i]);
591          tmp2 = MULT16_16_Q15(QCONST16(.70710678f,15), X[stride*(2*j+1)+i]);
592          X[stride*2*j+i] = tmp1 + tmp2;
593          X[stride*(2*j+1)+i] = tmp1 - tmp2;
594       }
595 }
596
597 static int compute_qn(int N, int b, int offset, int stereo)
598 {
599    static const celt_int16 exp2_table8[8] =
600       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
601    int qn, qb;
602    int N2 = 2*N-1;
603    if (stereo && N==2)
604       N2--;
605    qb = IMIN((b>>1)-(1<<BITRES), (b+N2*offset)/N2);
606
607    qb = IMAX(0, IMIN(8<<BITRES, qb));
608
609    if (qb<(1<<BITRES>>1)) {
610       qn = 1;
611    } else {
612       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
613       qn = (qn+1)>>1<<1;
614    }
615    celt_assert(qn <= 256);
616    return qn;
617 }
618
619 /* This function is responsible for encoding and decoding a band for both
620    the mono and stereo case. Even in the mono case, it can split the band
621    in two and transmit the energy difference with the two half-bands. It
622    can be called recursively so bands can end up being split in 8 parts. */
623 static unsigned quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
624       int N, int b, int spread, int B, int intensity, int tf_change, celt_norm *lowband, int resynth, void *ec,
625       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
626       celt_uint32 *seed, celt_word16 gain, celt_norm *lowband_scratch, int fill)
627 {
628    int q;
629    int curr_bits;
630    int stereo, split;
631    int imid=0, iside=0;
632    int N0=N;
633    int N_B=N;
634    int N_B0;
635    int B0=B;
636    int time_divide=0;
637    int recombine=0;
638    int inv = 0;
639    celt_word16 mid=0, side=0;
640    int longBlocks;
641    unsigned cm=0;
642
643    longBlocks = B0==1;
644
645    N_B /= B;
646    N_B0 = N_B;
647
648    split = stereo = Y != NULL;
649
650    /* Special case for one sample */
651    if (N==1)
652    {
653       int c;
654       celt_norm *x = X;
655       c=0; do {
656          int sign=0;
657          if (*remaining_bits>=1<<BITRES)
658          {
659             if (encode)
660             {
661                sign = x[0]<0;
662                ec_enc_bits((ec_enc*)ec, sign, 1);
663             } else {
664                sign = ec_dec_bits((ec_dec*)ec, 1);
665             }
666             *remaining_bits -= 1<<BITRES;
667             b-=1<<BITRES;
668          }
669          if (resynth)
670             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
671          x = Y;
672       } while (++c<1+stereo);
673       if (lowband_out)
674          lowband_out[0] = SHR16(X[0],4);
675       return 1;
676    }
677
678    if (!stereo && level == 0)
679    {
680       int k;
681       if (tf_change>0)
682          recombine = tf_change;
683       /* Band recombining to increase frequency resolution */
684
685       if (lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
686       {
687          int j;
688          for (j=0;j<N;j++)
689             lowband_scratch[j] = lowband[j];
690          lowband = lowband_scratch;
691       }
692
693       for (k=0;k<recombine;k++)
694       {
695          if (encode)
696             haar1(X, N>>k, 1<<k);
697          if (lowband)
698             haar1(lowband, N>>k, 1<<k);
699          fill |= fill<<(1<<k);
700       }
701       B>>=recombine;
702       N_B<<=recombine;
703
704       /* Increasing the time resolution */
705       while ((N_B&1) == 0 && tf_change<0)
706       {
707          if (encode)
708             haar1(X, N_B, B);
709          if (lowband)
710             haar1(lowband, N_B, B);
711          fill |= fill<<B;
712          B <<= 1;
713          N_B >>= 1;
714          time_divide++;
715          tf_change++;
716       }
717       B0=B;
718       N_B0 = N_B;
719
720       /* Reorganize the samples in time order instead of frequency order */
721       if (B0>1)
722       {
723          if (encode)
724             deinterleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
725          if (lowband)
726             deinterleave_hadamard(lowband, N_B>>recombine, B0<<recombine, longBlocks);
727       }
728    }
729
730    /* If we need more than 32 bits, try splitting the band in two. */
731    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
732    {
733       if (LM>0 || (N&1)==0)
734       {
735          N >>= 1;
736          Y = X+N;
737          split = 1;
738          LM -= 1;
739          if (B==1)
740             fill |= fill<<1;
741          B = (B+1)>>1;
742       }
743    }
744
745    if (split)
746    {
747       int qn;
748       int itheta=0;
749       int mbits, sbits, delta;
750       int qalloc;
751       int offset;
752       celt_int32 tell;
753
754       /* Decide on the resolution to give to the split parameter theta */
755       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
756       qn = compute_qn(N, b, offset, stereo);
757       if (stereo && i>=intensity)
758          qn = 1;
759       if (encode)
760       {
761          /* theta is the atan() of the ratio between the (normalized)
762             side and mid. With just that parameter, we can re-scale both
763             mid and side because we know that 1) they have unit norm and
764             2) they are orthogonal. */
765          itheta = stereo_itheta(X, Y, stereo, N);
766       }
767       tell = encode ? ec_enc_tell(ec, BITRES) : ec_dec_tell(ec, BITRES);
768       if (qn!=1)
769       {
770          if (encode)
771             itheta = (itheta*qn+8192)>>14;
772
773          /* Entropy coding of the angle. We use a uniform pdf for the
774             time split, a step for stereo, and a triangular one for the rest. */
775          if (stereo && N>2)
776          {
777             int p0 = 3;
778             int x = itheta;
779             int x0 = qn/2;
780             int ft = p0*(x0+1) + x0;
781             /* Use a probability of p0 up to itheta=8192 and then use 1 after */
782             if (encode)
783             {
784                ec_encode((ec_enc*)ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
785             } else {
786                int fs;
787                fs=ec_decode(ec,ft);
788                if (fs<(x0+1)*p0)
789                   x=fs/p0;
790                else
791                   x=x0+1+(fs-(x0+1)*p0);
792                ec_dec_update(ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
793                itheta = x;
794             }
795          } else if (B0>1 || stereo) {
796             /* Uniform pdf */
797             if (encode)
798                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
799             else
800                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
801          } else {
802             int fs=1, ft;
803             ft = ((qn>>1)+1)*((qn>>1)+1);
804             if (encode)
805             {
806                int fl;
807
808                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
809                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
810                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
811
812                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
813             } else {
814                /* Triangular pdf */
815                int fl=0;
816                int fm;
817                fm = ec_decode((ec_dec*)ec, ft);
818
819                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
820                {
821                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
822                   fs = itheta + 1;
823                   fl = itheta*(itheta + 1)>>1;
824                }
825                else
826                {
827                   itheta = (2*(qn + 1)
828                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
829                   fs = qn + 1 - itheta;
830                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
831                }
832
833                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
834             }
835          }
836          itheta = (celt_int32)itheta*16384/qn;
837          if (encode && stereo)
838          {
839             if (itheta==0)
840                intensity_stereo(m, X, Y, bandE, i, N);
841             else
842                stereo_split(X, Y, N);
843          }
844          /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
845                   Let's do that at higher complexity */
846       } else if (stereo) {
847          if (encode)
848          {
849             inv = itheta > 8192;
850             if (inv)
851             {
852                int j;
853                for (j=0;j<N;j++)
854                   Y[j] = -Y[j];
855             }
856             intensity_stereo(m, X, Y, bandE, i, N);
857          }
858          if (b>2<<BITRES && *remaining_bits > 2<<BITRES)
859          {
860             if (encode)
861                ec_enc_bit_logp(ec, inv, 2);
862             else
863                inv = ec_dec_bit_logp(ec, 2);
864          } else
865             inv = 0;
866          itheta = 0;
867       }
868       qalloc = (encode ? ec_enc_tell(ec, BITRES) : ec_dec_tell(ec, BITRES))
869                - tell;
870       b -= qalloc;
871
872       if (itheta == 0)
873       {
874          imid = 32767;
875          iside = 0;
876          fill &= (1<<B)-1;
877          delta = -16384;
878       } else if (itheta == 16384)
879       {
880          imid = 0;
881          iside = 32767;
882          fill &= (1<<B)-1<<B;
883          delta = 16384;
884       } else {
885          imid = bitexact_cos(itheta);
886          iside = bitexact_cos(16384-itheta);
887          /* This is the mid vs side allocation that minimizes squared error
888             in that band. */
889          delta = FRAC_MUL16(N-1<<7,bitexact_log2tan(iside,imid));
890       }
891
892 #ifdef FIXED_POINT
893       mid = imid;
894       side = iside;
895 #else
896       mid = (1.f/32768)*imid;
897       side = (1.f/32768)*iside;
898 #endif
899
900       /* This is a special case for N=2 that only works for stereo and takes
901          advantage of the fact that mid and side are orthogonal to encode
902          the side with just one bit. */
903       if (N==2 && stereo)
904       {
905          int c;
906          int sign=0;
907          celt_norm *x2, *y2;
908          mbits = b;
909          sbits = 0;
910          /* Only need one bit for the side */
911          if (itheta != 0 && itheta != 16384)
912             sbits = 1<<BITRES;
913          mbits -= sbits;
914          c = itheta > 8192;
915          *remaining_bits -= qalloc+sbits;
916
917          x2 = c ? Y : X;
918          y2 = c ? X : Y;
919          if (sbits)
920          {
921             if (encode)
922             {
923                /* Here we only need to encode a sign for the side */
924                sign = x2[0]*y2[1] - x2[1]*y2[0] < 0;
925                ec_enc_bits((ec_enc*)ec, sign, 1);
926             } else {
927                sign = ec_dec_bits((ec_dec*)ec, 1);
928             }
929          }
930          sign = 1-2*sign;
931          cm = quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, intensity, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level, seed, gain, lowband_scratch, fill);
932          /* We don't split N=2 bands, so cm is either 1 or 0 (for a fold-collapse),
933              and there's no need to worry about mixing with the other channel. */
934          y2[0] = -sign*x2[1];
935          y2[1] = sign*x2[0];
936          if (resynth)
937          {
938             celt_norm tmp;
939             X[0] = MULT16_16_Q15(mid, X[0]);
940             X[1] = MULT16_16_Q15(mid, X[1]);
941             Y[0] = MULT16_16_Q15(side, Y[0]);
942             Y[1] = MULT16_16_Q15(side, Y[1]);
943             tmp = X[0];
944             X[0] = SUB16(tmp,Y[0]);
945             Y[0] = ADD16(tmp,Y[0]);
946             tmp = X[1];
947             X[1] = SUB16(tmp,Y[1]);
948             Y[1] = ADD16(tmp,Y[1]);
949          }
950       } else {
951          /* "Normal" split code */
952          celt_norm *next_lowband2=NULL;
953          celt_norm *next_lowband_out1=NULL;
954          int next_level=0;
955
956          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
957          if (B0>1 && !stereo)
958          {
959             if (itheta > 8192)
960                /* Rough approximation for pre-echo masking */
961                delta -= delta>>(4-LM);
962             else
963                /* Corresponds to a forward-masking slope of 1.5 dB per 10 ms */
964                delta = IMIN(0, delta + (N<<BITRES>>(5-LM)));
965          }
966          mbits = IMAX(0, IMIN(b, (b-delta)/2));
967          sbits = b-mbits;
968          *remaining_bits -= qalloc;
969
970          if (lowband && !stereo)
971             next_lowband2 = lowband+N; /* >32-bit split case */
972
973          /* Only stereo needs to pass on lowband_out. Otherwise, it's
974             handled at the end */
975          if (stereo)
976             next_lowband_out1 = lowband_out;
977          else
978             next_level = level+1;
979
980          /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
981             mid for folding later */
982          cm = quant_band(encode, m, i, X, NULL, N, mbits, spread, B, intensity, tf_change,
983                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
984                NULL, next_level, seed, stereo ? Q15ONE : MULT16_16_P15(gain,mid), lowband_scratch, fill);
985          /* For a stereo split, the high bits of fill are always zero, so no
986              folding will be done to the side. */
987          cm |= quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, intensity, tf_change,
988                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
989                NULL, next_level, seed, MULT16_16_P15(gain,side), NULL, fill>>B)<<B;
990       }
991
992    } else {
993       /* This is the basic no-split case */
994       q = bits2pulses(m, i, LM, b);
995       curr_bits = pulses2bits(m, i, LM, q);
996       *remaining_bits -= curr_bits;
997
998       /* Ensures we can never bust the budget */
999       while (*remaining_bits < 0 && q > 0)
1000       {
1001          *remaining_bits += curr_bits;
1002          q--;
1003          curr_bits = pulses2bits(m, i, LM, q);
1004          *remaining_bits -= curr_bits;
1005       }
1006
1007       if (q!=0)
1008       {
1009          int K = get_pulses(q);
1010
1011          /* Finally do the actual quantization */
1012          if (encode)
1013             cm = alg_quant(X, N, K, spread, B, resynth, (ec_enc*)ec, gain);
1014          else
1015             cm = alg_unquant(X, N, K, spread, B, (ec_dec*)ec, gain);
1016       } else {
1017          /* If there's no pulse, fill the band anyway */
1018          int j;
1019          if (resynth)
1020          {
1021             if (!fill)
1022             {
1023                for (j=0;j<N;j++)
1024                   X[j] = 0;
1025             } else {
1026                if (lowband == NULL || (spread==SPREAD_AGGRESSIVE && B<=1))
1027                {
1028                   /* Noise */
1029                   for (j=0;j<N;j++)
1030                   {
1031                      *seed = lcg_rand(*seed);
1032                      X[j] = (celt_int32)(*seed)>>20;
1033                   }
1034                   cm = (1<<B)-1;
1035                } else {
1036                   /* Folded spectrum */
1037                   for (j=0;j<N;j++)
1038                      X[j] = lowband[j];
1039                   cm = fill;
1040                }
1041                renormalise_vector(X, N, gain);
1042             }
1043          }
1044       }
1045    }
1046
1047    /* This code is used by the decoder and by the resynthesis-enabled encoder */
1048    if (resynth)
1049    {
1050       if (stereo)
1051       {
1052          if (N!=2)
1053          {
1054             cm |= cm>>B;
1055             stereo_merge(X, Y, mid, N);
1056          }
1057          if (inv)
1058          {
1059             int j;
1060             for (j=0;j<N;j++)
1061                Y[j] = -Y[j];
1062          }
1063       } else if (level == 0)
1064       {
1065          int k;
1066
1067          /* Undo the sample reorganization going from time order to frequency order */
1068          if (B0>1)
1069             interleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
1070
1071          /* Undo time-freq changes that we did earlier */
1072          N_B = N_B0;
1073          B = B0;
1074          for (k=0;k<time_divide;k++)
1075          {
1076             B >>= 1;
1077             N_B <<= 1;
1078             cm |= cm>>B;
1079             haar1(X, N_B, B);
1080          }
1081
1082          for (k=0;k<recombine;k++)
1083          {
1084             cm |= cm<<(1<<k);
1085             haar1(X, N0>>k, 1<<k);
1086          }
1087          B<<=recombine;
1088          N_B>>=recombine;
1089
1090          /* Scale output for later folding */
1091          if (lowband_out)
1092          {
1093             int j;
1094             celt_word16 n;
1095             n = celt_sqrt(SHL32(EXTEND32(N0),22));
1096             for (j=0;j<N0;j++)
1097                lowband_out[j] = MULT16_16_Q15(n,X[j]);
1098          }
1099       }
1100    }
1101    return cm;
1102 }
1103
1104 void quant_all_bands(int encode, const CELTMode *m, int start, int end,
1105       celt_norm *_X, celt_norm *_Y, unsigned char *collapse_masks, const celt_ener *bandE, int *pulses,
1106       int shortBlocks, int spread, int dual_stereo, int intensity, int *tf_res, int resynth,
1107       int total_bits, void *ec, int LM, int codedBands, ec_uint32 *seed)
1108 {
1109    int i;
1110    celt_int32 balance;
1111    celt_int32 remaining_bits;
1112    const celt_int16 * restrict eBands = m->eBands;
1113    celt_norm * restrict norm, * restrict norm2;
1114    VARDECL(celt_norm, _norm);
1115    VARDECL(celt_norm, lowband_scratch);
1116    int B;
1117    int M;
1118    int lowband_offset;
1119    int update_lowband = 1;
1120    int C = _Y != NULL ? 2 : 1;
1121    SAVE_STACK;
1122
1123    M = 1<<LM;
1124    B = shortBlocks ? M : 1;
1125    ALLOC(_norm, C*M*eBands[m->nbEBands], celt_norm);
1126    ALLOC(lowband_scratch, M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]), celt_norm);
1127    norm = _norm;
1128    norm2 = norm + M*eBands[m->nbEBands];
1129
1130    balance = 0;
1131    lowband_offset = 0;
1132    for (i=start;i<end;i++)
1133    {
1134       celt_int32 tell;
1135       int b;
1136       int N;
1137       celt_int32 curr_balance;
1138       int effective_lowband=-1;
1139       celt_norm * restrict X, * restrict Y;
1140       int tf_change=0;
1141       unsigned x_cm;
1142       unsigned y_cm;
1143
1144       X = _X+M*eBands[i];
1145       if (_Y!=NULL)
1146          Y = _Y+M*eBands[i];
1147       else
1148          Y = NULL;
1149       N = M*eBands[i+1]-M*eBands[i];
1150       if (encode)
1151          tell = ec_enc_tell((ec_enc*)ec, BITRES);
1152       else
1153          tell = ec_dec_tell((ec_dec*)ec, BITRES);
1154
1155       /* Compute how many bits we want to allocate to this band */
1156       if (i != start)
1157          balance -= tell;
1158       remaining_bits = ((celt_int32)total_bits<<BITRES)-tell-1- (shortBlocks&&LM>=2 ? (1<<BITRES) : 0);
1159       if (i <= codedBands-1)
1160       {
1161          curr_balance = balance / IMIN(3, codedBands-i);
1162          b = IMAX(0, IMIN(16384, IMIN(remaining_bits+1,pulses[i]+curr_balance)));
1163       } else {
1164          b = 0;
1165       }
1166
1167       if (resynth && M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband_offset==0))
1168             lowband_offset = i;
1169
1170       tf_change = tf_res[i];
1171       if (i>=m->effEBands)
1172       {
1173          X=norm;
1174          if (_Y!=NULL)
1175             Y = norm;
1176       }
1177
1178       /* This ensures we never repeat spectral content within one band */
1179       if (lowband_offset != 0)
1180          effective_lowband = IMAX(M*eBands[start], M*eBands[lowband_offset]-N);
1181
1182       /* Get a conservative estimate of the collapse_mask's for the bands we're
1183           going to be folding from. */
1184       if (lowband_offset != 0 && (spread!=SPREAD_AGGRESSIVE || B>1))
1185       {
1186          int fold_start;
1187          int fold_end;
1188          int fold_i;
1189          fold_start = lowband_offset;
1190          while(M*eBands[--fold_start] > effective_lowband);
1191          fold_end = lowband_offset-1;
1192          while(M*eBands[++fold_end] < effective_lowband+N);
1193          x_cm = y_cm = 0;
1194          fold_i = fold_start; do {
1195            x_cm |= collapse_masks[fold_i*C+0];
1196            y_cm |= collapse_masks[fold_i*C+C-1];
1197          } while (++fold_i<fold_end);
1198       }
1199       /* Otherwise, we'll be using the LCG to fold, so all blocks will (almost
1200           always) be non-zero.*/
1201       else
1202          x_cm = y_cm = (1<<B)-1;
1203
1204       if (dual_stereo && i==intensity)
1205       {
1206          int j;
1207
1208          /* Switch off dual stereo to do intensity */
1209          dual_stereo = 0;
1210          for (j=M*eBands[start];j<M*eBands[i];j++)
1211             norm[j] = HALF32(norm[j]+norm2[j]);
1212       }
1213       if (dual_stereo)
1214       {
1215          x_cm = quant_band(encode, m, i, X, NULL, N, b/2, spread, B, intensity, tf_change,
1216                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1217                norm+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, x_cm);
1218          y_cm = quant_band(encode, m, i, Y, NULL, N, b/2, spread, B, intensity, tf_change,
1219                effective_lowband != -1 ? norm2+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1220                norm2+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, y_cm);
1221       } else {
1222          x_cm = quant_band(encode, m, i, X, Y, N, b, spread, B, intensity, tf_change,
1223                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1224                norm+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, x_cm|y_cm);
1225          y_cm = x_cm;
1226       }
1227       collapse_masks[i*C+0] = (unsigned char)(x_cm&(1<<B)-1);
1228       collapse_masks[i*C+C-1] = (unsigned char)(y_cm&(1<<B)-1);
1229       balance += pulses[i] + tell;
1230
1231       /* Update the folding position only as long as we have 1 bit/sample depth */
1232       update_lowband = (b>>BITRES)>N;
1233    }
1234    RESTORE_STACK;
1235 }
1236