Make collapse-detection bitexact.
[opus.git] / libcelt / bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008-2009 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #include <math.h>
39 #include "bands.h"
40 #include "modes.h"
41 #include "vq.h"
42 #include "cwrs.h"
43 #include "stack_alloc.h"
44 #include "os_support.h"
45 #include "mathops.h"
46 #include "rate.h"
47
48 static celt_uint32 lcg_rand(celt_uint32 seed)
49 {
50    return 1664525 * seed + 1013904223;
51 }
52
53 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
54    with this approximation is important because it has an impact on the bit allocation */
55 static celt_int16 bitexact_cos(celt_int16 x)
56 {
57    celt_int32 tmp;
58    celt_int16 x2;
59    tmp = (4096+((celt_int32)(x)*(x)))>>13;
60    if (tmp > 32767)
61       tmp = 32767;
62    x2 = tmp;
63    x2 = (32767-x2) + FRAC_MUL16(x2, (-7651 + FRAC_MUL16(x2, (8277 + FRAC_MUL16(-626, x2)))));
64    if (x2 > 32766)
65       x2 = 32766;
66    return 1+x2;
67 }
68
69 static int bitexact_log2tan(int isin,int icos)
70 {
71    int lc;
72    int ls;
73    lc=EC_ILOG(icos);
74    ls=EC_ILOG(isin);
75    icos<<=15-lc;
76    isin<<=15-ls;
77    return (ls-lc<<11)
78          +FRAC_MUL16(isin, FRAC_MUL16(isin, -2597) + 7932)
79          -FRAC_MUL16(icos, FRAC_MUL16(icos, -2597) + 7932);
80 }
81
82 #ifdef FIXED_POINT
83 /* Compute the amplitude (sqrt energy) in each of the bands */
84 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
85 {
86    int i, c, N;
87    const celt_int16 *eBands = m->eBands;
88    const int C = CHANNELS(_C);
89    N = M*m->shortMdctSize;
90    c=0; do {
91       for (i=0;i<end;i++)
92       {
93          int j;
94          celt_word32 maxval=0;
95          celt_word32 sum = 0;
96          
97          j=M*eBands[i]; do {
98             maxval = MAX32(maxval, X[j+c*N]);
99             maxval = MAX32(maxval, -X[j+c*N]);
100          } while (++j<M*eBands[i+1]);
101          
102          if (maxval > 0)
103          {
104             int shift = celt_ilog2(maxval)-10;
105             j=M*eBands[i]; do {
106                sum = MAC16_16(sum, EXTRACT16(VSHR32(X[j+c*N],shift)),
107                                    EXTRACT16(VSHR32(X[j+c*N],shift)));
108             } while (++j<M*eBands[i+1]);
109             /* We're adding one here to make damn sure we never end up with a pitch vector that's
110                larger than unity norm */
111             bank[i+c*m->nbEBands] = EPSILON+VSHR32(EXTEND32(celt_sqrt(sum)),-shift);
112          } else {
113             bank[i+c*m->nbEBands] = EPSILON;
114          }
115          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
116       }
117    } while (++c<C);
118    /*printf ("\n");*/
119 }
120
121 /* Normalise each band such that the energy is one. */
122 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
123 {
124    int i, c, N;
125    const celt_int16 *eBands = m->eBands;
126    const int C = CHANNELS(_C);
127    N = M*m->shortMdctSize;
128    c=0; do {
129       i=0; do {
130          celt_word16 g;
131          int j,shift;
132          celt_word16 E;
133          shift = celt_zlog2(bank[i+c*m->nbEBands])-13;
134          E = VSHR32(bank[i+c*m->nbEBands], shift);
135          g = EXTRACT16(celt_rcp(SHL32(E,3)));
136          j=M*eBands[i]; do {
137             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
138          } while (++j<M*eBands[i+1]);
139       } while (++i<end);
140    } while (++c<C);
141 }
142
143 #else /* FIXED_POINT */
144 /* Compute the amplitude (sqrt energy) in each of the bands */
145 void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
146 {
147    int i, c, N;
148    const celt_int16 *eBands = m->eBands;
149    const int C = CHANNELS(_C);
150    N = M*m->shortMdctSize;
151    c=0; do {
152       for (i=0;i<end;i++)
153       {
154          int j;
155          celt_word32 sum = 1e-10f;
156          for (j=M*eBands[i];j<M*eBands[i+1];j++)
157             sum += X[j+c*N]*X[j+c*N];
158          bank[i+c*m->nbEBands] = celt_sqrt(sum);
159          /*printf ("%f ", bank[i+c*m->nbEBands]);*/
160       }
161    } while (++c<C);
162    /*printf ("\n");*/
163 }
164
165 /* Normalise each band such that the energy is one. */
166 void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
167 {
168    int i, c, N;
169    const celt_int16 *eBands = m->eBands;
170    const int C = CHANNELS(_C);
171    N = M*m->shortMdctSize;
172    c=0; do {
173       for (i=0;i<end;i++)
174       {
175          int j;
176          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
177          for (j=M*eBands[i];j<M*eBands[i+1];j++)
178             X[j+c*N] = freq[j+c*N]*g;
179       }
180    } while (++c<C);
181 }
182
183 #endif /* FIXED_POINT */
184
185 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
186 void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
187 {
188    int i, c, N;
189    const celt_int16 *eBands = m->eBands;
190    const int C = CHANNELS(_C);
191    N = M*m->shortMdctSize;
192    celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
193    c=0; do {
194       celt_sig * restrict f;
195       const celt_norm * restrict x;
196       f = freq+c*N;
197       x = X+c*N;
198       for (i=0;i<end;i++)
199       {
200          int j, band_end;
201          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
202          j=M*eBands[i];
203          band_end = M*eBands[i+1];
204          do {
205             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
206             x++;
207          } while (++j<band_end);
208       }
209       for (i=M*eBands[m->nbEBands];i<N;i++)
210          *f++ = 0;
211    } while (++c<C);
212 }
213
214 /* This prevents energy collapse for transients with multiple short MDCTs */
215 void anti_collapse(const CELTMode *m, celt_norm *_X, unsigned char *collapse_masks, int LM, int C, int size,
216       int start, int end, celt_word16 *logE, celt_word16 *prev1logE,
217       celt_word16 *prev2logE, int *pulses, celt_uint32 seed)
218 {
219    int c, i, j, k;
220    c=0; do
221    {
222       for (i=start;i<end;i++)
223       {
224          celt_norm *X;
225          int N0;
226          celt_word16 Ediff;
227          celt_word16 r;
228          celt_word16 thresh;
229          int depth;
230
231          N0 = m->eBands[i+1]-m->eBands[i];
232          Ediff = logE[c*m->nbEBands+i]-MIN16(prev1logE[c*m->nbEBands+i],prev2logE[c*m->nbEBands+i]);
233          Ediff = MAX16(0, Ediff);
234          depth = (1+(pulses[i]>>BITRES))/(m->eBands[i+1]-m->eBands[i]<<LM);
235
236 #ifdef FIXED_POINT
237          thresh = MULT16_32_Q15(QCONST16(0.3f, 15), MIN32(32767,SHR32(celt_exp2(-SHL16(depth, 11)),1) ));
238          if (Ediff < 16384)
239             r = 2*MIN16(16383,SHR32(celt_exp2(-SHL16(Ediff, 11-DB_SHIFT)),1));
240          else
241             r = 0;
242          r = SHR16(MIN16(thresh, r),1);
243          {
244             int shift;
245             celt_word32 t;
246             t = N0<<LM;
247             shift = celt_ilog2(t)>>1;
248             t = SHL32(t, (7-shift)<<1);
249             r = SHR32(MULT16_16_Q15(celt_rsqrt_norm(t), r),shift);
250          }
251 #else
252          thresh = .3f*celt_exp2(-depth);
253          r = 2.f*celt_exp2(-Ediff);
254          r = MIN16(thresh, r);
255          r = r*celt_rsqrt(N0<<LM);
256 #endif
257          X = _X+c*size+(m->eBands[i]<<LM);
258          for (k=0;k<1<<LM;k++)
259          {
260             /* Detect collapse */
261             if (!(collapse_masks[i*C+c]&1<<k))
262             {
263                /* Fill with noise */
264                for (j=0;j<N0;j++)
265                {
266                   seed = lcg_rand(seed);
267                   X[(j<<LM)+k] = (seed&0x8000 ? r : -r);
268                }
269             }
270          }
271          /* We just added some energy, so we need to renormalise */
272          renormalise_vector(X, N0<<LM, Q15ONE);
273       }
274    } while (++c<C);
275
276 }
277
278
279 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
280 {
281    int i = bandID;
282    int j;
283    celt_word16 a1, a2;
284    celt_word16 left, right;
285    celt_word16 norm;
286 #ifdef FIXED_POINT
287    int shift = celt_zlog2(MAX32(bank[i], bank[i+m->nbEBands]))-13;
288 #endif
289    left = VSHR32(bank[i],shift);
290    right = VSHR32(bank[i+m->nbEBands],shift);
291    norm = EPSILON + celt_sqrt(EPSILON+MULT16_16(left,left)+MULT16_16(right,right));
292    a1 = DIV32_16(SHL32(EXTEND32(left),14),norm);
293    a2 = DIV32_16(SHL32(EXTEND32(right),14),norm);
294    for (j=0;j<N;j++)
295    {
296       celt_norm r, l;
297       l = X[j];
298       r = Y[j];
299       X[j] = MULT16_16_Q14(a1,l) + MULT16_16_Q14(a2,r);
300       /* Side is not encoded, no need to calculate */
301    }
302 }
303
304 static void stereo_split(celt_norm *X, celt_norm *Y, int N)
305 {
306    int j;
307    for (j=0;j<N;j++)
308    {
309       celt_norm r, l;
310       l = MULT16_16_Q15(QCONST16(.70710678f,15), X[j]);
311       r = MULT16_16_Q15(QCONST16(.70710678f,15), Y[j]);
312       X[j] = l+r;
313       Y[j] = r-l;
314    }
315 }
316
317 static void stereo_merge(celt_norm *X, celt_norm *Y, celt_word16 mid, int N)
318 {
319    int j;
320    celt_word32 xp=0, side=0;
321    celt_word32 El, Er;
322    celt_word16 mid2;
323 #ifdef FIXED_POINT
324    int kl, kr;
325 #endif
326    celt_word32 t, lgain, rgain;
327
328    /* Compute the norm of X+Y and X-Y as |X|^2 + |Y|^2 +/- sum(xy) */
329    for (j=0;j<N;j++)
330    {
331       xp = MAC16_16(xp, X[j], Y[j]);
332       side = MAC16_16(side, Y[j], Y[j]);
333    }
334    /* Compensating for the mid normalization */
335    xp = MULT16_32_Q15(mid, xp);
336    /* mid and side are in Q15, not Q14 like X and Y */
337    mid2 = SHR32(mid, 1);
338    El = MULT16_16(mid2, mid2) + side - 2*xp;
339    Er = MULT16_16(mid2, mid2) + side + 2*xp;
340    if (Er < EPSILON)
341       Er = EPSILON;
342    if (El < EPSILON)
343       El = EPSILON;
344
345 #ifdef FIXED_POINT
346    kl = celt_ilog2(El)>>1;
347    kr = celt_ilog2(Er)>>1;
348 #endif
349    t = VSHR32(El, (kl-7)<<1);
350    lgain = celt_rsqrt_norm(t);
351    t = VSHR32(Er, (kr-7)<<1);
352    rgain = celt_rsqrt_norm(t);
353
354 #ifdef FIXED_POINT
355    if (kl < 7)
356       kl = 7;
357    if (kr < 7)
358       kr = 7;
359 #endif
360
361    for (j=0;j<N;j++)
362    {
363       celt_norm r, l;
364       /* Apply mid scaling (side is already scaled) */
365       l = MULT16_16_Q15(mid, X[j]);
366       r = Y[j];
367       X[j] = EXTRACT16(PSHR32(MULT16_16(lgain, SUB16(l,r)), kl+1));
368       Y[j] = EXTRACT16(PSHR32(MULT16_16(rgain, ADD16(l,r)), kr+1));
369    }
370 }
371
372 /* Decide whether we should spread the pulses in the current frame */
373 int spreading_decision(const CELTMode *m, celt_norm *X, int *average,
374       int last_decision, int *hf_average, int *tapset_decision, int update_hf,
375       int end, int _C, int M)
376 {
377    int i, c, N0;
378    int sum = 0, nbBands=0;
379    const int C = CHANNELS(_C);
380    const celt_int16 * restrict eBands = m->eBands;
381    int decision;
382    int hf_sum=0;
383    
384    N0 = M*m->shortMdctSize;
385
386    if (M*(eBands[end]-eBands[end-1]) <= 8)
387       return SPREAD_NONE;
388    c=0; do {
389       for (i=0;i<end;i++)
390       {
391          int j, N, tmp=0;
392          int tcount[3] = {0};
393          celt_norm * restrict x = X+M*eBands[i]+c*N0;
394          N = M*(eBands[i+1]-eBands[i]);
395          if (N<=8)
396             continue;
397          /* Compute rough CDF of |x[j]| */
398          for (j=0;j<N;j++)
399          {
400             celt_word32 x2N; /* Q13 */
401
402             x2N = MULT16_16(MULT16_16_Q15(x[j], x[j]), N);
403             if (x2N < QCONST16(0.25f,13))
404                tcount[0]++;
405             if (x2N < QCONST16(0.0625f,13))
406                tcount[1]++;
407             if (x2N < QCONST16(0.015625f,13))
408                tcount[2]++;
409          }
410
411          /* Only include four last bands (8 kHz and up) */
412          if (i>m->nbEBands-4)
413             hf_sum += 32*(tcount[1]+tcount[0])/N;
414          tmp = (2*tcount[2] >= N) + (2*tcount[1] >= N) + (2*tcount[0] >= N);
415          sum += tmp*256;
416          nbBands++;
417       }
418    } while (++c<C);
419
420    if (update_hf)
421    {
422       if (hf_sum)
423          hf_sum /= C*(4-m->nbEBands+end);
424       *hf_average = (*hf_average+hf_sum)>>1;
425       hf_sum = *hf_average;
426       if (*tapset_decision==2)
427          hf_sum += 4;
428       else if (*tapset_decision==0)
429          hf_sum -= 4;
430       if (hf_sum > 22)
431          *tapset_decision=2;
432       else if (hf_sum > 18)
433          *tapset_decision=1;
434       else
435          *tapset_decision=0;
436    }
437    /*printf("%d %d %d\n", hf_sum, *hf_average, *tapset_decision);*/
438    sum /= nbBands;
439    /* Recursive averaging */
440    sum = (sum+*average)>>1;
441    *average = sum;
442    /* Hysteresis */
443    sum = (3*sum + (((3-last_decision)<<7) + 64) + 2)>>2;
444    if (sum < 80)
445    {
446       decision = SPREAD_AGGRESSIVE;
447    } else if (sum < 256)
448    {
449       decision = SPREAD_NORMAL;
450    } else if (sum < 384)
451    {
452       decision = SPREAD_LIGHT;
453    } else {
454       decision = SPREAD_NONE;
455    }
456    return decision;
457 }
458
459 #ifdef MEASURE_NORM_MSE
460
461 float MSE[30] = {0};
462 int nbMSEBands = 0;
463 int MSECount[30] = {0};
464
465 void dump_norm_mse(void)
466 {
467    int i;
468    for (i=0;i<nbMSEBands;i++)
469    {
470       printf ("%g ", MSE[i]/MSECount[i]);
471    }
472    printf ("\n");
473 }
474
475 void measure_norm_mse(const CELTMode *m, float *X, float *X0, float *bandE, float *bandE0, int M, int N, int C)
476 {
477    static int init = 0;
478    int i;
479    if (!init)
480    {
481       atexit(dump_norm_mse);
482       init = 1;
483    }
484    for (i=0;i<m->nbEBands;i++)
485    {
486       int j;
487       int c;
488       float g;
489       if (bandE0[i]<10 || (C==2 && bandE0[i+m->nbEBands]<1))
490          continue;
491       c=0; do {
492          g = bandE[i+c*m->nbEBands]/(1e-15+bandE0[i+c*m->nbEBands]);
493          for (j=M*m->eBands[i];j<M*m->eBands[i+1];j++)
494             MSE[i] += (g*X[j+c*N]-X0[j+c*N])*(g*X[j+c*N]-X0[j+c*N]);
495       } while (++c<C);
496       MSECount[i]+=C;
497    }
498    nbMSEBands = m->nbEBands;
499 }
500
501 #endif
502
503 /* Indexing table for converting from natural Hadamard to ordery Hadamard
504    This is essentially a bit-reversed Gray, on top of which we've added
505    an inversion of the order because we want the DC at the end rather than
506    the beginning. The lines are for N=2, 4, 8, 16 */
507 static const int ordery_table[] = {
508        1,  0,
509        3,  0,  2,  1,
510        7,  0,  4,  3,  6,  1,  5,  2,
511       15,  0,  8,  7, 12,  3, 11,  4, 14,  1,  9,  6, 13,  2, 10,  5,
512 };
513
514 static void deinterleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
515 {
516    int i,j;
517    VARDECL(celt_norm, tmp);
518    int N;
519    SAVE_STACK;
520    N = N0*stride;
521    ALLOC(tmp, N, celt_norm);
522    if (hadamard)
523    {
524       const int *ordery = ordery_table+stride-2;
525       for (i=0;i<stride;i++)
526       {
527          for (j=0;j<N0;j++)
528             tmp[ordery[i]*N0+j] = X[j*stride+i];
529       }
530    } else {
531       for (i=0;i<stride;i++)
532          for (j=0;j<N0;j++)
533             tmp[i*N0+j] = X[j*stride+i];
534    }
535    for (j=0;j<N;j++)
536       X[j] = tmp[j];
537    RESTORE_STACK;
538 }
539
540 static void interleave_hadamard(celt_norm *X, int N0, int stride, int hadamard)
541 {
542    int i,j;
543    VARDECL(celt_norm, tmp);
544    int N;
545    SAVE_STACK;
546    N = N0*stride;
547    ALLOC(tmp, N, celt_norm);
548    if (hadamard)
549    {
550       const int *ordery = ordery_table+stride-2;
551       for (i=0;i<stride;i++)
552          for (j=0;j<N0;j++)
553             tmp[j*stride+i] = X[ordery[i]*N0+j];
554    } else {
555       for (i=0;i<stride;i++)
556          for (j=0;j<N0;j++)
557             tmp[j*stride+i] = X[i*N0+j];
558    }
559    for (j=0;j<N;j++)
560       X[j] = tmp[j];
561    RESTORE_STACK;
562 }
563
564 void haar1(celt_norm *X, int N0, int stride)
565 {
566    int i, j;
567    N0 >>= 1;
568    for (i=0;i<stride;i++)
569       for (j=0;j<N0;j++)
570       {
571          celt_norm tmp1, tmp2;
572          tmp1 = MULT16_16_Q15(QCONST16(.70710678f,15), X[stride*2*j+i]);
573          tmp2 = MULT16_16_Q15(QCONST16(.70710678f,15), X[stride*(2*j+1)+i]);
574          X[stride*2*j+i] = tmp1 + tmp2;
575          X[stride*(2*j+1)+i] = tmp1 - tmp2;
576       }
577 }
578
579 static int compute_qn(int N, int b, int offset, int stereo)
580 {
581    static const celt_int16 exp2_table8[8] =
582       {16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048};
583    int qn, qb;
584    int N2 = 2*N-1;
585    if (stereo && N==2)
586       N2--;
587    qb = IMIN((b>>1)-(1<<BITRES), (b+N2*offset)/N2);
588
589    qb = IMAX(0, IMIN(8<<BITRES, qb));
590
591    if (qb<(1<<BITRES>>1)) {
592       qn = 1;
593    } else {
594       qn = exp2_table8[qb&0x7]>>(14-(qb>>BITRES));
595       qn = (qn+1)>>1<<1;
596    }
597    celt_assert(qn <= 256);
598    return qn;
599 }
600
601 /* This function is responsible for encoding and decoding a band for both
602    the mono and stereo case. Even in the mono case, it can split the band
603    in two and transmit the energy difference with the two half-bands. It
604    can be called recursively so bands can end up being split in 8 parts. */
605 static unsigned quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
606       int N, int b, int spread, int B, int intensity, int tf_change, celt_norm *lowband, int resynth, void *ec,
607       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
608       celt_int32 *seed, celt_word16 gain, celt_norm *lowband_scratch, int fill)
609 {
610    int q;
611    int curr_bits;
612    int stereo, split;
613    int imid=0, iside=0;
614    int N0=N;
615    int N_B=N;
616    int N_B0;
617    int B0=B;
618    int time_divide=0;
619    int recombine=0;
620    int inv = 0;
621    celt_word16 mid=0, side=0;
622    int longBlocks;
623    unsigned cm;
624
625    longBlocks = B0==1;
626
627    N_B /= B;
628    N_B0 = N_B;
629
630    split = stereo = Y != NULL;
631
632    /* Special case for one sample */
633    if (N==1)
634    {
635       int c;
636       celt_norm *x = X;
637       c=0; do {
638          int sign=0;
639          if (*remaining_bits>=1<<BITRES)
640          {
641             if (encode)
642             {
643                sign = x[0]<0;
644                ec_enc_bits((ec_enc*)ec, sign, 1);
645             } else {
646                sign = ec_dec_bits((ec_dec*)ec, 1);
647             }
648             *remaining_bits -= 1<<BITRES;
649             b-=1<<BITRES;
650          }
651          if (resynth)
652             x[0] = sign ? -NORM_SCALING : NORM_SCALING;
653          x = Y;
654       } while (++c<1+stereo);
655       if (lowband_out)
656          lowband_out[0] = SHR16(X[0],4);
657       return 1;
658    }
659
660    if (!stereo && level == 0)
661    {
662       int k;
663       if (tf_change>0)
664          recombine = tf_change;
665       /* Band recombining to increase frequency resolution */
666
667       if (lowband && (recombine || ((N_B&1) == 0 && tf_change<0) || B0>1))
668       {
669          int j;
670          for (j=0;j<N;j++)
671             lowband_scratch[j] = lowband[j];
672          lowband = lowband_scratch;
673       }
674
675       for (k=0;k<recombine;k++)
676       {
677          if (encode)
678             haar1(X, N>>k, 1<<k);
679          if (lowband)
680             haar1(lowband, N>>k, 1<<k);
681          fill |= fill<<(1<<k);
682       }
683       B>>=recombine;
684       N_B<<=recombine;
685
686       /* Increasing the time resolution */
687       while ((N_B&1) == 0 && tf_change<0)
688       {
689          if (encode)
690             haar1(X, N_B, B);
691          if (lowband)
692             haar1(lowband, N_B, B);
693          fill |= fill<<B;
694          B <<= 1;
695          N_B >>= 1;
696          time_divide++;
697          tf_change++;
698       }
699       B0=B;
700       N_B0 = N_B;
701
702       /* Reorganize the samples in time order instead of frequency order */
703       if (B0>1)
704       {
705          if (encode)
706             deinterleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
707          if (lowband)
708             deinterleave_hadamard(lowband, N_B>>recombine, B0<<recombine, longBlocks);
709       }
710    }
711
712    /* If we need more than 32 bits, try splitting the band in two. */
713    if (!stereo && LM != -1 && b > 32<<BITRES && N>2)
714    {
715       if (LM>0 || (N&1)==0)
716       {
717          N >>= 1;
718          Y = X+N;
719          split = 1;
720          LM -= 1;
721          if (B==1)
722             fill |= fill<<1;
723          B = (B+1)>>1;
724       }
725    }
726
727    if (split)
728    {
729       int qn;
730       int itheta=0;
731       int mbits, sbits, delta;
732       int qalloc;
733       int offset;
734       celt_int32 tell;
735
736       /* Decide on the resolution to give to the split parameter theta */
737       offset = ((m->logN[i]+(LM<<BITRES))>>1) - (stereo ? QTHETA_OFFSET_STEREO : QTHETA_OFFSET);
738       qn = compute_qn(N, b, offset, stereo);
739       if (stereo && i>=intensity)
740          qn = 1;
741       if (encode)
742       {
743          /* theta is the atan() of the ratio between the (normalized)
744             side and mid. With just that parameter, we can re-scale both
745             mid and side because we know that 1) they have unit norm and
746             2) they are orthogonal. */
747          itheta = stereo_itheta(X, Y, stereo, N);
748       }
749       tell = encode ? ec_enc_tell(ec, BITRES) : ec_dec_tell(ec, BITRES);
750       if (qn!=1)
751       {
752          if (encode)
753             itheta = (itheta*qn+8192)>>14;
754
755          /* Entropy coding of the angle. We use a uniform pdf for the
756             time split, a step for stereo, and a triangular one for the rest. */
757          if (stereo && N>2)
758          {
759             int p0 = 3;
760             int x = itheta;
761             int x0 = qn/2;
762             int ft = p0*(x0+1) + x0;
763             /* Use a probability of p0 up to itheta=8192 and then use 1 after */
764             if (encode)
765             {
766                ec_encode((ec_enc*)ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
767             } else {
768                int fs;
769                fs=ec_decode(ec,ft);
770                if (fs<(x0+1)*p0)
771                   x=fs/p0;
772                else
773                   x=x0+1+(fs-(x0+1)*p0);
774                ec_dec_update(ec,x<=x0?p0*x:(x-1-x0)+(x0+1)*p0,x<=x0?p0*(x+1):(x-x0)+(x0+1)*p0,ft);
775                itheta = x;
776             }
777          } else if (B0>1 || stereo) {
778             /* Uniform pdf */
779             if (encode)
780                ec_enc_uint((ec_enc*)ec, itheta, qn+1);
781             else
782                itheta = ec_dec_uint((ec_dec*)ec, qn+1);
783          } else {
784             int fs=1, ft;
785             ft = ((qn>>1)+1)*((qn>>1)+1);
786             if (encode)
787             {
788                int fl;
789
790                fs = itheta <= (qn>>1) ? itheta + 1 : qn + 1 - itheta;
791                fl = itheta <= (qn>>1) ? itheta*(itheta + 1)>>1 :
792                 ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
793
794                ec_encode((ec_enc*)ec, fl, fl+fs, ft);
795             } else {
796                /* Triangular pdf */
797                int fl=0;
798                int fm;
799                fm = ec_decode((ec_dec*)ec, ft);
800
801                if (fm < ((qn>>1)*((qn>>1) + 1)>>1))
802                {
803                   itheta = (isqrt32(8*(celt_uint32)fm + 1) - 1)>>1;
804                   fs = itheta + 1;
805                   fl = itheta*(itheta + 1)>>1;
806                }
807                else
808                {
809                   itheta = (2*(qn + 1)
810                    - isqrt32(8*(celt_uint32)(ft - fm - 1) + 1))>>1;
811                   fs = qn + 1 - itheta;
812                   fl = ft - ((qn + 1 - itheta)*(qn + 2 - itheta)>>1);
813                }
814
815                ec_dec_update((ec_dec*)ec, fl, fl+fs, ft);
816             }
817          }
818          itheta = (celt_int32)itheta*16384/qn;
819          if (encode && stereo)
820          {
821             if (itheta==0)
822                intensity_stereo(m, X, Y, bandE, i, N);
823             else
824                stereo_split(X, Y, N);
825          }
826          /* TODO: Renormalising X and Y *may* help fixed-point a bit at very high rate.
827                   Let's do that at higher complexity */
828       } else if (stereo) {
829          if (encode)
830          {
831             inv = itheta > 8192;
832             if (inv)
833             {
834                int j;
835                for (j=0;j<N;j++)
836                   Y[j] = -Y[j];
837             }
838             intensity_stereo(m, X, Y, bandE, i, N);
839          }
840          if (b>2<<BITRES && *remaining_bits > 2<<BITRES)
841          {
842             if (encode)
843                ec_enc_bit_logp(ec, inv, 2);
844             else
845                inv = ec_dec_bit_logp(ec, 2);
846          } else
847             inv = 0;
848          itheta = 0;
849       }
850       qalloc = (encode ? ec_enc_tell(ec, BITRES) : ec_dec_tell(ec, BITRES))
851                - tell;
852       b -= qalloc;
853
854       if (itheta == 0)
855       {
856          imid = 32767;
857          iside = 0;
858          fill &= (1<<B)-1;
859          delta = -16384;
860       } else if (itheta == 16384)
861       {
862          imid = 0;
863          iside = 32767;
864          fill &= (1<<B)-1<<B;
865          delta = 16384;
866       } else {
867          imid = bitexact_cos(itheta);
868          iside = bitexact_cos(16384-itheta);
869          /* This is the mid vs side allocation that minimizes squared error
870             in that band. */
871          delta = FRAC_MUL16(N-1<<7,bitexact_log2tan(iside,imid));
872       }
873
874 #ifdef FIXED_POINT
875       mid = imid;
876       side = iside;
877 #else
878       mid = (1.f/32768)*imid;
879       side = (1.f/32768)*iside;
880 #endif
881
882       /* This is a special case for N=2 that only works for stereo and takes
883          advantage of the fact that mid and side are orthogonal to encode
884          the side with just one bit. */
885       if (N==2 && stereo)
886       {
887          int c;
888          int sign=0;
889          celt_norm *x2, *y2;
890          mbits = b;
891          sbits = 0;
892          /* Only need one bit for the side */
893          if (itheta != 0 && itheta != 16384)
894             sbits = 1<<BITRES;
895          mbits -= sbits;
896          c = itheta > 8192;
897          *remaining_bits -= qalloc+sbits;
898
899          x2 = c ? Y : X;
900          y2 = c ? X : Y;
901          if (sbits)
902          {
903             if (encode)
904             {
905                /* Here we only need to encode a sign for the side */
906                sign = x2[0]*y2[1] - x2[1]*y2[0] < 0;
907                ec_enc_bits((ec_enc*)ec, sign, 1);
908             } else {
909                sign = ec_dec_bits((ec_dec*)ec, 1);
910             }
911          }
912          sign = 1-2*sign;
913          cm = quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, intensity, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level, seed, gain, lowband_scratch, fill);
914          /* We don't split N=2 bands, so cm is either 1 or 0 (for a fold-collapse),
915              and there's no need to worry about mixing with the other channel. */
916          y2[0] = -sign*x2[1];
917          y2[1] = sign*x2[0];
918          if (resynth)
919          {
920             celt_norm tmp;
921             X[0] = MULT16_16_Q15(mid, X[0]);
922             X[1] = MULT16_16_Q15(mid, X[1]);
923             Y[0] = MULT16_16_Q15(side, Y[0]);
924             Y[1] = MULT16_16_Q15(side, Y[1]);
925             tmp = X[0];
926             X[0] = SUB16(tmp,Y[0]);
927             Y[0] = ADD16(tmp,Y[0]);
928             tmp = X[1];
929             X[1] = SUB16(tmp,Y[1]);
930             Y[1] = ADD16(tmp,Y[1]);
931          }
932       } else {
933          /* "Normal" split code */
934          celt_norm *next_lowband2=NULL;
935          celt_norm *next_lowband_out1=NULL;
936          int next_level=0;
937
938          /* Give more bits to low-energy MDCTs than they would otherwise deserve */
939          if (B0>1 && !stereo)
940          {
941             if (itheta > 8192)
942                /* Rough approximation for pre-echo masking */
943                delta -= delta>>(4-LM);
944             else
945                /* Corresponds to a forward-masking slope of 1.5 dB per 10 ms */
946                delta = IMIN(0, delta + (N<<BITRES>>(5-LM)));
947          }
948          mbits = IMAX(0, IMIN(b, (b-delta)/2));
949          sbits = b-mbits;
950          *remaining_bits -= qalloc;
951
952          if (lowband && !stereo)
953             next_lowband2 = lowband+N; /* >32-bit split case */
954
955          /* Only stereo needs to pass on lowband_out. Otherwise, it's
956             handled at the end */
957          if (stereo)
958             next_lowband_out1 = lowband_out;
959          else
960             next_level = level+1;
961
962          /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
963             mid for folding later */
964          cm = quant_band(encode, m, i, X, NULL, N, mbits, spread, B, intensity, tf_change,
965                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
966                NULL, next_level, seed, stereo ? Q15ONE : MULT16_16_P15(gain,mid), lowband_scratch, fill);
967          /* For a stereo split, the high bits of fill are always zero, so no
968              folding will be done to the side. */
969          cm |= quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, intensity, tf_change,
970                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
971                NULL, next_level, seed, MULT16_16_P15(gain,side), NULL, fill>>B)<<B;
972       }
973
974    } else {
975       /* This is the basic no-split case */
976       q = bits2pulses(m, i, LM, b);
977       curr_bits = pulses2bits(m, i, LM, q);
978       *remaining_bits -= curr_bits;
979
980       /* Ensures we can never bust the budget */
981       while (*remaining_bits < 0 && q > 0)
982       {
983          *remaining_bits += curr_bits;
984          q--;
985          curr_bits = pulses2bits(m, i, LM, q);
986          *remaining_bits -= curr_bits;
987       }
988
989       if (q!=0)
990       {
991          int K = get_pulses(q);
992
993          /* Finally do the actual quantization */
994          if (encode)
995             cm = alg_quant(X, N, K, spread, B, lowband, resynth, (ec_enc*)ec, seed, gain);
996          else
997             cm = alg_unquant(X, N, K, spread, B, lowband, (ec_dec*)ec, seed, gain);
998       } else {
999          /* If there's no pulse, fill the band anyway */
1000          int j;
1001          if (resynth)
1002          {
1003             if (!fill)
1004             {
1005                for (j=0;j<N;j++)
1006                   X[j] = 0;
1007                cm = 0;
1008             } else {
1009                if (lowband == NULL || (spread==SPREAD_AGGRESSIVE && B<=1))
1010                {
1011                   /* Noise */
1012                   for (j=0;j<N;j++)
1013                   {
1014                      *seed = lcg_rand(*seed);
1015                      X[j] = (celt_int32)(*seed)>>20;
1016                   }
1017                   cm = (1<<B)-1;
1018                } else {
1019                   /* Folded spectrum */
1020                   for (j=0;j<N;j++)
1021                      X[j] = lowband[j];
1022                   cm = fill;
1023                }
1024                renormalise_vector(X, N, gain);
1025             }
1026          }
1027       }
1028    }
1029
1030    /* This code is used by the decoder and by the resynthesis-enabled encoder */
1031    if (resynth)
1032    {
1033       if (stereo)
1034       {
1035          if (N!=2)
1036          {
1037             cm |= cm>>B;
1038             stereo_merge(X, Y, mid, N);
1039          }
1040          if (inv)
1041          {
1042             int j;
1043             for (j=0;j<N;j++)
1044                Y[j] = -Y[j];
1045          }
1046       } else if (level == 0)
1047       {
1048          int k;
1049
1050          /* Undo the sample reorganization going from time order to frequency order */
1051          if (B0>1)
1052             interleave_hadamard(X, N_B>>recombine, B0<<recombine, longBlocks);
1053
1054          /* Undo time-freq changes that we did earlier */
1055          N_B = N_B0;
1056          B = B0;
1057          for (k=0;k<time_divide;k++)
1058          {
1059             B >>= 1;
1060             N_B <<= 1;
1061             cm |= cm>>B;
1062             haar1(X, N_B, B);
1063          }
1064
1065          for (k=0;k<recombine;k++)
1066          {
1067             cm |= cm<<(1<<k);
1068             haar1(X, N0>>k, 1<<k);
1069          }
1070          B<<=recombine;
1071          N_B>>=recombine;
1072
1073          /* Scale output for later folding */
1074          if (lowband_out)
1075          {
1076             int j;
1077             celt_word16 n;
1078             n = celt_sqrt(SHL32(EXTEND32(N0),22));
1079             for (j=0;j<N0;j++)
1080                lowband_out[j] = MULT16_16_Q15(n,X[j]);
1081          }
1082       }
1083    }
1084    return cm;
1085 }
1086
1087 void quant_all_bands(int encode, const CELTMode *m, int start, int end,
1088       celt_norm *_X, celt_norm *_Y, unsigned char *collapse_masks, const celt_ener *bandE, int *pulses,
1089       int shortBlocks, int spread, int dual_stereo, int intensity, int *tf_res, int resynth,
1090       int total_bits, void *ec, int LM, int codedBands)
1091 {
1092    int i;
1093    celt_int32 balance;
1094    celt_int32 remaining_bits;
1095    const celt_int16 * restrict eBands = m->eBands;
1096    celt_norm * restrict norm, * restrict norm2;
1097    VARDECL(celt_norm, _norm);
1098    VARDECL(celt_norm, lowband_scratch);
1099    int B;
1100    int M;
1101    celt_int32 seed;
1102    int lowband_offset;
1103    int update_lowband = 1;
1104    int C = _Y != NULL ? 2 : 1;
1105    SAVE_STACK;
1106
1107    M = 1<<LM;
1108    B = shortBlocks ? M : 1;
1109    ALLOC(_norm, C*M*eBands[m->nbEBands], celt_norm);
1110    ALLOC(lowband_scratch, M*(eBands[m->nbEBands]-eBands[m->nbEBands-1]), celt_norm);
1111    norm = _norm;
1112    norm2 = norm + M*eBands[m->nbEBands];
1113
1114    if (encode)
1115       seed = ((ec_enc*)ec)->rng;
1116    else
1117       seed = ((ec_dec*)ec)->rng;
1118    balance = 0;
1119    lowband_offset = 0;
1120    for (i=start;i<end;i++)
1121    {
1122       celt_int32 tell;
1123       int b;
1124       int N;
1125       celt_int32 curr_balance;
1126       int effective_lowband=-1;
1127       celt_norm * restrict X, * restrict Y;
1128       int tf_change=0;
1129       unsigned x_cm;
1130       unsigned y_cm;
1131
1132       X = _X+M*eBands[i];
1133       if (_Y!=NULL)
1134          Y = _Y+M*eBands[i];
1135       else
1136          Y = NULL;
1137       N = M*eBands[i+1]-M*eBands[i];
1138       if (encode)
1139          tell = ec_enc_tell((ec_enc*)ec, BITRES);
1140       else
1141          tell = ec_dec_tell((ec_dec*)ec, BITRES);
1142
1143       /* Compute how many bits we want to allocate to this band */
1144       if (i != start)
1145          balance -= tell;
1146       remaining_bits = ((celt_int32)total_bits<<BITRES)-tell-1- (shortBlocks&&LM>=2 ? (1<<BITRES) : 0);
1147       if (i <= codedBands-1)
1148       {
1149          curr_balance = balance / IMIN(3, codedBands-i);
1150          b = IMAX(0, IMIN(16384, IMIN(remaining_bits+1,pulses[i]+curr_balance)));
1151       } else {
1152          b = 0;
1153       }
1154
1155       if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband_offset==0))
1156             lowband_offset = i;
1157
1158       tf_change = tf_res[i];
1159       if (i>=m->effEBands)
1160       {
1161          X=norm;
1162          if (_Y!=NULL)
1163             Y = norm;
1164       }
1165
1166       /* This ensures we never repeat spectral content within one band */
1167       if (lowband_offset != 0)
1168          effective_lowband = IMAX(M*eBands[start], M*eBands[lowband_offset]-N);
1169
1170       /* Get a conservative estimate of the collapse_mask's for the bands we're
1171           going to be folding from. */
1172       if (lowband_offset != 0 && (spread!=SPREAD_AGGRESSIVE || B>1))
1173       {
1174          int fold_start;
1175          int fold_end;
1176          int fold_i;
1177          fold_start = lowband_offset;
1178          while(M*eBands[--fold_start] > effective_lowband);
1179          fold_end = lowband_offset-1;
1180          while(M*eBands[++fold_end] < effective_lowband+N);
1181          x_cm = y_cm = 0;
1182          fold_i = fold_start; do {
1183            x_cm |= collapse_masks[fold_i*C+0];
1184            y_cm |= collapse_masks[fold_i*C+1];
1185          } while (++fold_i<fold_end);
1186       }
1187       /* Otherwise, we'll be using the LCG to fold, so all blocks will (almost
1188           always) be non-zero.*/
1189       else
1190          x_cm = y_cm = (1<<B)-1;
1191
1192       if (dual_stereo && i==intensity)
1193       {
1194          int j;
1195
1196          /* Switch off dual stereo to do intensity */
1197          dual_stereo = 0;
1198          for (j=M*eBands[start];j<M*eBands[i];j++)
1199             norm[j] = HALF32(norm[j]+norm2[j]);
1200       }
1201       if (dual_stereo)
1202       {
1203          x_cm = quant_band(encode, m, i, X, NULL, N, b/2, spread, B, intensity, tf_change,
1204                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1205                norm+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch, x_cm);
1206          y_cm = quant_band(encode, m, i, Y, NULL, N, b/2, spread, B, intensity, tf_change,
1207                effective_lowband != -1 ? norm2+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1208                norm2+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch, y_cm);
1209          collapse_masks[i*2+0] = (unsigned char)(x_cm&(1<<B)-1);
1210          collapse_masks[i*2+1] = (unsigned char)(y_cm&(1<<B)-1);
1211       } else {
1212          x_cm = quant_band(encode, m, i, X, Y, N, b, spread, B, intensity, tf_change,
1213                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
1214                norm+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch, x_cm|y_cm);
1215          collapse_masks[i*C+1] = collapse_masks[i*C+0] = (unsigned char)(x_cm&(1<<B)-1);
1216       }
1217       balance += pulses[i] + tell;
1218
1219       /* Update the folding position only as long as we have 1 bit/sample depth */
1220       update_lowband = (b>>BITRES)>N;
1221    }
1222    RESTORE_STACK;
1223 }
1224