2e002f1dcf736c41a7c412f33480160bc48776fd
[opus.git] / libcelt / celt.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2010 Xiph.Org Foundation
3    Copyright (c) 2008 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #define CELT_C
39
40 #include "os_support.h"
41 #include "mdct.h"
42 #include <math.h>
43 #include "celt.h"
44 #include "pitch.h"
45 #include "bands.h"
46 #include "modes.h"
47 #include "entcode.h"
48 #include "quant_bands.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54 #include "plc.h"
55
56 #ifdef FIXED_POINT
57 static const celt_word16 transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135f, 0.0337639f, 0.0748914f, 0.1304955f,
63    0.1986827f, 0.2771308f, 0.3631685f, 0.4538658f,
64    0.5461342f, 0.6368315f, 0.7228692f, 0.8013173f,
65    0.8695045f, 0.9251086f, 0.9662361f, 0.9914865f};
66 #endif
67
68 static const int trim_cdf[7] = {0, 4, 10, 23, 119, 125, 128};
69 static const int trim_coef[6] = {4, 6, 7, 8, 10, 12};
70
71 /** Encoder state 
72  @brief Encoder state
73  */
74 struct CELTEncoder {
75    const CELTMode *mode;     /**< Mode used by the encoder */
76    int overlap;
77    int channels;
78    
79    int force_intra;
80    int complexity;
81    int start, end;
82
83    celt_int32 vbr_rate_norm; /* Target number of 16th bits per frame */
84
85    /* Everything beyond this point gets cleared on a reset */
86 #define ENCODER_RESET_START frame_max
87
88    celt_word32 frame_max;
89    int fold_decision;
90    int delayedIntra;
91    int tonal_average;
92
93    /* VBR-related parameters */
94    celt_int32 vbr_reservoir;
95    celt_int32 vbr_drift;
96    celt_int32 vbr_offset;
97    celt_int32 vbr_count;
98
99    celt_word32 preemph_memE[2];
100    celt_word32 preemph_memD[2];
101
102    celt_sig in_mem[1]; /* Size = channels*mode->overlap */
103    /* celt_sig overlap_mem[],  Size = channels*mode->overlap */
104    /* celt_word16 oldEBands[], Size = channels*mode->nbEBands */
105 };
106
107 int celt_encoder_get_size(const CELTMode *mode, int channels)
108 {
109    int size = sizeof(struct CELTEncoder)
110          + (2*channels*mode->overlap-1)*sizeof(celt_sig)
111          + channels*mode->nbEBands*sizeof(celt_word16);
112    return size;
113 }
114
115 CELTEncoder *celt_encoder_create(const CELTMode *mode, int channels, int *error)
116 {
117    return celt_encoder_init(
118          (CELTEncoder *)celt_alloc(celt_encoder_get_size(mode, channels)),
119          mode, channels, error);
120 }
121
122 CELTEncoder *celt_encoder_init(CELTEncoder *st, const CELTMode *mode, int channels, int *error)
123 {
124    if (channels < 0 || channels > 2)
125    {
126       if (error)
127          *error = CELT_BAD_ARG;
128       return NULL;
129    }
130
131    if (st==NULL)
132    {
133       if (error)
134          *error = CELT_ALLOC_FAIL;
135       return NULL;
136    }
137
138    CELT_MEMSET((char*)st, 0, celt_encoder_get_size(mode, channels));
139    
140    st->mode = mode;
141    st->overlap = mode->overlap;
142    st->channels = channels;
143
144    st->start = 0;
145    st->end = st->mode->effEBands;
146
147    st->vbr_rate_norm = 0;
148    st->force_intra  = 0;
149    st->delayedIntra = 1;
150    st->tonal_average = 256;
151    st->fold_decision = 1;
152    st->complexity = 5;
153
154    if (error)
155       *error = CELT_OK;
156    return st;
157 }
158
159 void celt_encoder_destroy(CELTEncoder *st)
160 {
161    celt_free(st);
162 }
163
164 static inline celt_int16 FLOAT2INT16(float x)
165 {
166    x = x*CELT_SIG_SCALE;
167    x = MAX32(x, -32768);
168    x = MIN32(x, 32767);
169    return (celt_int16)float2int(x);
170 }
171
172 static inline celt_word16 SIG2WORD16(celt_sig x)
173 {
174 #ifdef FIXED_POINT
175    x = PSHR32(x, SIG_SHIFT);
176    x = MAX32(x, -32768);
177    x = MIN32(x, 32767);
178    return EXTRACT16(x);
179 #else
180    return (celt_word16)x;
181 #endif
182 }
183
184 static int transient_analysis(const celt_word32 * restrict in, int len, int C,
185                               int *transient_time, int *transient_shift,
186                               celt_word32 *frame_max, int overlap)
187 {
188    int i, n;
189    celt_word32 ratio;
190    celt_word32 threshold;
191    VARDECL(celt_word32, begin);
192    VARDECL(celt_word16, tmp);
193    celt_word32 mem0=0,mem1=0;
194    SAVE_STACK;
195    ALLOC(tmp, len, celt_word16);
196    ALLOC(begin, len+1, celt_word32);
197
198    if (C==1)
199    {
200       for (i=0;i<len;i++)
201          tmp[i] = SHR32(in[i],SIG_SHIFT);
202    } else {
203       for (i=0;i<len;i++)
204          tmp[i] = SHR32(ADD32(in[C*i],in[C*i+1]), SIG_SHIFT+1);
205    }
206
207    /* High-pass filter: (1 - 2*z^-1 + z^-2) / (1 - z^-1 + .5*z^-2) */
208    for (i=0;i<len;i++)
209    {
210       celt_word32 x,y;
211       x = tmp[i];
212       y = ADD32(mem0, x);
213 #ifdef FIXED_POINT
214       mem0 = mem1 + y - SHL32(x,1);
215       mem1 = x - SHR32(y,1);
216 #else
217       mem0 = mem1 + y - 2*x;
218       mem1 = x - .5*y;
219 #endif
220       tmp[i] = EXTRACT16(SHR(y,2));
221    }
222    /* First few samples are bad because we don't propagate the memory */
223    for (i=0;i<24;i++)
224       tmp[i] = 0;
225
226    begin[0] = 0;
227    for (i=0;i<len;i++)
228       begin[i+1] = MAX32(begin[i], ABS32(tmp[i]));
229
230    n = -1;
231
232    threshold = MULT16_32_Q15(QCONST16(.4f,15),begin[len]);
233    /* If the following condition isn't met, there's just no way
234       we'll have a transient*/
235    if (*frame_max < threshold)
236    {
237       /* It's likely we have a transient, now find it */
238       for (i=8;i<len-8;i++)
239       {
240          if (begin[i+1] < threshold)
241             n=i;
242       }
243    }
244    if (n<32)
245    {
246       n = -1;
247       ratio = 0;
248    } else {
249       ratio = DIV32(begin[len],1+MAX32(*frame_max, begin[n-16]));
250    }
251
252    if (ratio > 45)
253       *transient_shift = 3;
254    else
255       *transient_shift = 0;
256    
257    *transient_time = n;
258    *frame_max = begin[len-overlap];
259    /* Only consider the last 7.5 ms for the next transient */
260    if (len>360+overlap)
261    {
262       *frame_max = 0;
263       for (i=len-360-overlap;i<len;i++)
264          *frame_max = MAX32(*frame_max, ABS32(tmp[i]));
265    }
266    RESTORE_STACK;
267    return ratio > 0;
268 }
269
270 /** Apply window and compute the MDCT for all sub-frames and 
271     all channels in a frame */
272 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * restrict in, celt_sig * restrict out, int _C, int LM)
273 {
274    const int C = CHANNELS(_C);
275    if (C==1 && !shortBlocks)
276    {
277       const int overlap = OVERLAP(mode);
278       clt_mdct_forward(&mode->mdct, in, out, mode->window, overlap, mode->maxLM-LM);
279    } else {
280       const int overlap = OVERLAP(mode);
281       int N = mode->shortMdctSize<<LM;
282       int B = 1;
283       int b, c;
284       VARDECL(celt_word32, x);
285       VARDECL(celt_word32, tmp);
286       SAVE_STACK;
287       if (shortBlocks)
288       {
289          /*lookup = &mode->mdct[0];*/
290          N = mode->shortMdctSize;
291          B = shortBlocks;
292       }
293       ALLOC(x, N+overlap, celt_word32);
294       ALLOC(tmp, N, celt_word32);
295       for (c=0;c<C;c++)
296       {
297          for (b=0;b<B;b++)
298          {
299             int j;
300             for (j=0;j<N+overlap;j++)
301                x[j] = in[C*(b*N+j)+c];
302             clt_mdct_forward(&mode->mdct, x, tmp, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
303             /* Interleaving the sub-frames */
304             for (j=0;j<N;j++)
305                out[(j*B+b)+c*N*B] = tmp[j];
306          }
307       }
308       RESTORE_STACK;
309    }
310 }
311
312 /** Compute the IMDCT and apply window for all sub-frames and 
313     all channels in a frame */
314 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X,
315       int transient_time, int transient_shift, celt_sig * restrict out_mem[],
316       celt_sig * restrict overlap_mem[], int _C, int LM)
317 {
318    int c;
319    const int C = CHANNELS(_C);
320    const int N = mode->shortMdctSize<<LM;
321    const int overlap = OVERLAP(mode);
322    for (c=0;c<C;c++)
323    {
324       int j;
325          VARDECL(celt_word32, x);
326          VARDECL(celt_word32, tmp);
327          int b;
328          int N2 = N;
329          int B = 1;
330          SAVE_STACK;
331          
332          ALLOC(x, N+overlap, celt_word32);
333          ALLOC(tmp, N, celt_word32);
334
335          if (shortBlocks)
336          {
337             N2 = mode->shortMdctSize;
338             B = shortBlocks;
339          }
340          /* Prevents problems from the imdct doing the overlap-add */
341          CELT_MEMSET(x, 0, overlap);
342
343          for (b=0;b<B;b++)
344          {
345             /* De-interleaving the sub-frames */
346             for (j=0;j<N2;j++)
347                tmp[j] = X[(j*B+b)+c*N2*B];
348             clt_mdct_backward(&mode->mdct, tmp, x+N2*b, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
349          }
350
351          if (transient_shift > 0)
352          {
353 #ifdef FIXED_POINT
354             for (j=0;j<16;j++)
355                x[transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[transient_time+j-16],transient_shift));
356             for (j=transient_time;j<N+overlap;j++)
357                x[j] = SHL32(x[j], transient_shift);
358 #else
359             for (j=0;j<16;j++)
360                x[transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
361             for (j=transient_time;j<N+overlap;j++)
362                x[j] *= 1<<transient_shift;
363 #endif
364          }
365          for (j=0;j<overlap;j++)
366             out_mem[c][j] = x[j] + overlap_mem[c][j];
367          for (;j<N;j++)
368             out_mem[c][j] = x[j];
369          for (j=0;j<overlap;j++)
370             overlap_mem[c][j] = x[N+j];
371          RESTORE_STACK;
372    }
373 }
374
375 static void deemphasis(celt_sig *in[], celt_word16 *pcm, int N, int _C, const celt_word16 *coef, celt_sig *mem)
376 {
377    const int C = CHANNELS(_C);
378    int c;
379    for (c=0;c<C;c++)
380    {
381       int j;
382       celt_sig * restrict x;
383       celt_word16  * restrict y;
384       celt_sig m = mem[c];
385       x =in[c];
386       y = pcm+c;
387       for (j=0;j<N;j++)
388       {
389          celt_sig tmp = *x + m;
390          m = MULT16_32_Q15(coef[0], tmp)
391            - MULT16_32_Q15(coef[1], *x);
392          tmp = SHL32(MULT16_32_Q15(coef[3], tmp), 2);
393          *y = SCALEOUT(SIG2WORD16(tmp));
394          x++;
395          y+=C;
396       }
397       mem[c] = m;
398    }
399 }
400
401 static void mdct_shape(const CELTMode *mode, celt_norm *X, int start,
402                        int end, int N,
403                        int mdct_weight_shift, int end_band, int _C, int renorm, int M)
404 {
405    int m, i, c;
406    const int C = CHANNELS(_C);
407    for (c=0;c<C;c++)
408       for (m=start;m<end;m++)
409          for (i=m+c*N;i<(c+1)*N;i+=M)
410 #ifdef FIXED_POINT
411             X[i] = SHR16(X[i], mdct_weight_shift);
412 #else
413             X[i] = (1.f/(1<<mdct_weight_shift))*X[i];
414 #endif
415    if (renorm)
416       renormalise_bands(mode, X, end_band, C, M);
417 }
418
419 static const signed char tf_select_table[4][8] = {
420       {0, -1, 0, -1,    0,-1, 0,-1},
421       {0, -1, 0, -2,    1, 0, 1 -1},
422       {0, -2, 0, -3,    2, 0, 1 -1},
423       {0, -2, 0, -3,    2, 0, 1 -1},
424 };
425
426 static int tf_analysis(const CELTMode *m, celt_word16 *bandLogE, celt_word16 *oldBandE,
427       int len, int C, int isTransient, int *tf_res, int nbCompressedBytes, celt_norm *X,
428       int N0, int LM, int *tf_sum)
429 {
430    int i;
431    VARDECL(int, metric);
432    int cost0;
433    int cost1;
434    VARDECL(int, path0);
435    VARDECL(int, path1);
436    VARDECL(celt_norm, tmp);
437    int lambda;
438    int tf_select=0;
439    SAVE_STACK;
440
441    /* FIXME: Should check number of bytes *left* */
442    if (nbCompressedBytes<15*C)
443    {
444       for (i=0;i<len;i++)
445          tf_res[i] = 0;
446       return 0;
447    }
448    if (nbCompressedBytes<40)
449       lambda = 10;
450    else if (nbCompressedBytes<60)
451       lambda = 4;
452    else if (nbCompressedBytes<100)
453       lambda = 2;
454    else
455       lambda = 1;
456
457    ALLOC(metric, len, int);
458    ALLOC(tmp, (m->eBands[len]-m->eBands[len-1])<<LM, celt_norm);
459    ALLOC(path0, len, int);
460    ALLOC(path1, len, int);
461
462    *tf_sum = 0;
463    for (i=0;i<len;i++)
464    {
465       int j, k, N;
466       celt_word32 L1, best_L1;
467       int best_level=0;
468       N = (m->eBands[i+1]-m->eBands[i])<<LM;
469       for (j=0;j<N;j++)
470          tmp[j] = X[j+(m->eBands[i]<<LM)];
471       if (C==2)
472          for (j=0;j<N;j++)
473             tmp[j] = ADD16(tmp[j],X[N0+j+(m->eBands[i]<<LM)]);
474       L1=0;
475       for (j=0;j<N;j++)
476          L1 += ABS16(tmp[j]);
477       /* Biasing towards better freq resolution (because of spreading) */
478       if (isTransient)
479          L1 += MULT16_32_Q15(QCONST16(.08,15), L1);
480       else
481          L1 -= MULT16_32_Q15(QCONST16(.08,15), L1);
482       best_L1 = L1;
483       /*printf ("%f ", L1);*/
484       for (k=0;k<LM;k++)
485       {
486          if (isTransient)
487             haar1(tmp, N>>(LM-k), 1<<(LM-k));
488          else
489             haar1(tmp, N>>k, 1<<k);
490
491          L1=0;
492          for (j=0;j<N;j++)
493             L1 += ABS16(tmp[j]);
494
495          /*printf ("%f ", L1);*/
496          if (L1 < best_L1)
497          {
498             best_L1 = L1;
499             best_level = k+1;
500          }
501       }
502       /*printf ("%d ", isTransient ? LM-best_level : best_level);*/
503       if (isTransient)
504          metric[i] = best_level;
505       else
506          metric[i] = -best_level;
507       *tf_sum += metric[i];
508    }
509    /*printf("\n");*/
510    /* FIXME: Figure out how to set this */
511    tf_select = 1;
512
513    cost0 = 0;
514    cost1 = lambda;
515    /* Viterbi forward pass */
516    for (i=1;i<len;i++)
517    {
518       int curr0, curr1;
519       int from0, from1;
520
521       from0 = cost0;
522       from1 = cost1 + lambda;
523       if (from0 < from1)
524       {
525          curr0 = from0;
526          path0[i]= 0;
527       } else {
528          curr0 = from1;
529          path0[i]= 1;
530       }
531
532       from0 = cost0 + lambda;
533       from1 = cost1;
534       if (from0 < from1)
535       {
536          curr1 = from0;
537          path1[i]= 0;
538       } else {
539          curr1 = from1;
540          path1[i]= 1;
541       }
542       cost0 = curr0 + abs(metric[i]-tf_select_table[LM][4*isTransient+2*tf_select+0]);
543       cost1 = curr1 + abs(metric[i]-tf_select_table[LM][4*isTransient+2*tf_select+1]);
544    }
545    tf_res[len-1] = cost0 < cost1 ? 0 : 1;
546    /* Viterbi backward pass to check the decisions */
547    for (i=len-2;i>=0;i--)
548    {
549       if (tf_res[i+1] == 1)
550          tf_res[i] = path1[i+1];
551       else
552          tf_res[i] = path0[i+1];
553    }
554    RESTORE_STACK;
555    return tf_select;
556 }
557
558 static void tf_encode(int start, int end, int isTransient, int *tf_res, int nbCompressedBytes, int LM, int tf_select, ec_enc *enc)
559 {
560    int curr, i;
561    ec_enc_bit_prob(enc, tf_res[start], isTransient ? 16384 : 4096);
562    curr = tf_res[start];
563    for (i=start+1;i<end;i++)
564    {
565       ec_enc_bit_prob(enc, tf_res[i] ^ curr, isTransient ? 4096 : 2048);
566       curr = tf_res[i];
567    }
568    ec_enc_bits(enc, tf_select, 1);
569    for (i=start;i<end;i++)
570       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
571    /*printf("%d %d ", isTransient, tf_select); for(i=0;i<end;i++)printf("%d ", tf_res[i]);printf("\n");*/
572 }
573
574 static void tf_decode(int start, int end, int C, int isTransient, int *tf_res, int nbCompressedBytes, int LM, ec_dec *dec)
575 {
576    int i, curr, tf_select;
577    tf_res[start] = ec_dec_bit_prob(dec, isTransient ? 16384 : 4096);
578    curr = tf_res[start];
579    for (i=start+1;i<end;i++)
580    {
581       tf_res[i] = ec_dec_bit_prob(dec, isTransient ? 4096 : 2048) ^ curr;
582       curr = tf_res[i];
583    }
584    tf_select = ec_dec_bits(dec, 1);
585    for (i=start;i<end;i++)
586       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
587 }
588
589 #ifdef FIXED_POINT
590 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
591 {
592 #else
593 int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
594 {
595 #endif
596    int i, c, N, NN;
597    int bits;
598    int has_fold=1;
599    ec_byte_buffer buf;
600    ec_enc         _enc;
601    VARDECL(celt_sig, in);
602    VARDECL(celt_sig, freq);
603    VARDECL(celt_norm, X);
604    VARDECL(celt_ener, bandE);
605    VARDECL(celt_word16, bandLogE);
606    VARDECL(int, fine_quant);
607    VARDECL(celt_word16, error);
608    VARDECL(int, pulses);
609    VARDECL(int, offsets);
610    VARDECL(int, fine_priority);
611    VARDECL(int, tf_res);
612    celt_sig *_overlap_mem;
613    celt_word16 *oldBandE;
614    int shortBlocks=0;
615    int isTransient=0;
616    int transient_time, transient_time_quant;
617    int transient_shift;
618    int resynth;
619    const int C = CHANNELS(st->channels);
620    int mdct_weight_shift = 0;
621    int mdct_weight_pos=0;
622    int LM, M;
623    int tf_select;
624    int nbFilledBytes, nbAvailableBytes;
625    int effEnd;
626    int codedBands;
627    int tf_sum;
628    int alloc_trim;
629    SAVE_STACK;
630
631    if (nbCompressedBytes<0 || pcm==NULL)
632      return CELT_BAD_ARG;
633
634    for (LM=0;LM<4;LM++)
635       if (st->mode->shortMdctSize<<LM==frame_size)
636          break;
637    if (LM>=MAX_CONFIG_SIZES)
638       return CELT_BAD_ARG;
639    M=1<<LM;
640
641    _overlap_mem = st->in_mem+C*(st->overlap);
642    oldBandE = (celt_word16*)(st->in_mem+2*C*(st->overlap));
643
644    if (enc==NULL)
645    {
646       ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
647       ec_enc_init(&_enc,&buf);
648       enc = &_enc;
649       nbFilledBytes=0;
650    } else {
651       nbFilledBytes=(ec_enc_tell(enc, 0)+4)>>3;
652    }
653    nbAvailableBytes = nbCompressedBytes - nbFilledBytes;
654
655    effEnd = st->end;
656    if (effEnd > st->mode->effEBands)
657       effEnd = st->mode->effEBands;
658
659    N = M*st->mode->shortMdctSize;
660    ALLOC(in, C*(N+st->overlap), celt_sig);
661
662    CELT_COPY(in, st->in_mem, C*st->overlap);
663    for (c=0;c<C;c++)
664    {
665       const celt_word16 * restrict pcmp = pcm+c;
666       celt_sig * restrict inp = in+C*st->overlap+c;
667       for (i=0;i<N;i++)
668       {
669          /* Apply pre-emphasis */
670          celt_sig tmp = MULT16_16(st->mode->preemph[2], SCALEIN(*pcmp));
671          *inp = tmp + st->preemph_memE[c];
672          st->preemph_memE[c] = MULT16_32_Q15(st->mode->preemph[1], *inp)
673                              - MULT16_32_Q15(st->mode->preemph[0], tmp);
674          inp += C;
675          pcmp += C;
676       }
677    }
678    CELT_COPY(st->in_mem, in+C*N, C*st->overlap);
679
680    /* Transient handling */
681    transient_time = -1;
682    transient_time_quant = -1;
683    transient_shift = 0;
684    isTransient = 0;
685
686    resynth = optional_resynthesis!=NULL;
687
688    if (st->complexity > 1 && LM>0)
689    {
690       isTransient = M > 1 &&
691          transient_analysis(in, N+st->overlap, C, &transient_time,
692                             &transient_shift, &st->frame_max, st->overlap);
693    } else {
694       isTransient = 0;
695    }
696    if (isTransient)
697    {
698 #ifndef FIXED_POINT
699       float gain_1;
700 #endif
701       /* Apply the inverse shaping window */
702       if (transient_shift)
703       {
704          transient_time_quant = transient_time*(celt_int32)8000/st->mode->Fs;
705          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
706 #ifdef FIXED_POINT
707          for (c=0;c<C;c++)
708             for (i=0;i<16;i++)
709                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
710          for (c=0;c<C;c++)
711             for (i=transient_time;i<N+st->overlap;i++)
712                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
713 #else
714          for (c=0;c<C;c++)
715             for (i=0;i<16;i++)
716                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
717          gain_1 = 1.f/(1<<transient_shift);
718          for (c=0;c<C;c++)
719             for (i=transient_time;i<N+st->overlap;i++)
720                in[C*i+c] *= gain_1;
721 #endif
722       }
723       has_fold = 1;
724    }
725
726    if (isTransient)
727       shortBlocks = M;
728    else
729       shortBlocks = 0;
730
731    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
732    ALLOC(bandE,st->mode->nbEBands*C, celt_ener);
733    ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16);
734    /* Compute MDCTs */
735    compute_mdcts(st->mode, shortBlocks, in, freq, C, LM);
736
737    ALLOC(X, C*N, celt_norm);         /**< Interleaved normalised MDCTs */
738
739    compute_band_energies(st->mode, freq, bandE, effEnd, C, M);
740
741    amp2Log2(st->mode, effEnd, st->end, bandE, bandLogE, C);
742
743    /* Band normalisation */
744    normalise_bands(st->mode, freq, X, bandE, effEnd, C, M);
745
746    NN = M*st->mode->eBands[effEnd];
747    if (shortBlocks && !transient_shift)
748    {
749       celt_word32 sum[8]={1,1,1,1,1,1,1,1};
750       int m;
751       for (c=0;c<C;c++)
752       {
753          m=0;
754          do {
755             celt_word32 tmp=0;
756             for (i=m+c*N;i<c*N+NN;i+=M)
757                tmp += ABS32(X[i]);
758             sum[m++] += tmp;
759          } while (m<M);
760       }
761       m=0;
762 #ifdef FIXED_POINT
763       do {
764          if (SHR32(sum[m+1],3) > sum[m])
765          {
766             mdct_weight_shift=2;
767             mdct_weight_pos = m;
768          } else if (SHR32(sum[m+1],1) > sum[m] && mdct_weight_shift < 2)
769          {
770             mdct_weight_shift=1;
771             mdct_weight_pos = m;
772          }
773          m++;
774       } while (m<M-1);
775 #else
776       do {
777          if (sum[m+1] > 8*sum[m])
778          {
779             mdct_weight_shift=2;
780             mdct_weight_pos = m;
781          } else if (sum[m+1] > 2*sum[m] && mdct_weight_shift < 2)
782          {
783             mdct_weight_shift=1;
784             mdct_weight_pos = m;
785          }
786          m++;
787       } while (m<M-1);
788 #endif
789       if (mdct_weight_shift)
790          mdct_shape(st->mode, X, mdct_weight_pos+1, M, N, mdct_weight_shift, effEnd, C, 0, M);
791    }
792
793    ALLOC(tf_res, st->mode->nbEBands, int);
794    /* Needs to be before coarse energy quantization because otherwise the energy gets modified */
795    tf_select = tf_analysis(st->mode, bandLogE, oldBandE, effEnd, C, isTransient, tf_res, nbAvailableBytes, X, N, LM, &tf_sum);
796    for (i=effEnd;i<st->end;i++)
797       tf_res[i] = tf_res[effEnd-1];
798
799    ALLOC(error, C*st->mode->nbEBands, celt_word16);
800    quant_coarse_energy(st->mode, st->start, st->end, effEnd, bandLogE,
801          oldBandE, nbCompressedBytes*8, st->mode->prob,
802          error, enc, C, LM, nbAvailableBytes, st->force_intra,
803          &st->delayedIntra, st->complexity >= 4);
804
805    if (LM > 0)
806       ec_enc_bit_prob(enc, shortBlocks!=0, 8192);
807
808    if (shortBlocks)
809    {
810       if (transient_shift)
811       {
812          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
813          ec_enc_uint(enc, transient_shift, 4);
814          ec_enc_uint(enc, transient_time_quant, max_time);
815       } else {
816          ec_enc_uint(enc, mdct_weight_shift, 4);
817          if (mdct_weight_shift && M!=2)
818             ec_enc_uint(enc, mdct_weight_pos, M-1);
819       }
820    }
821
822    tf_encode(st->start, st->end, isTransient, tf_res, nbAvailableBytes, LM, tf_select, enc);
823
824    if (shortBlocks || st->complexity < 3)
825    {
826       if (st->complexity == 0)
827       {
828          has_fold = 0;
829          st->fold_decision = 3;
830       } else {
831          has_fold = 1;
832          st->fold_decision = 1;
833       }
834    } else {
835       has_fold = folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision, effEnd, C, M);
836    }
837    ec_enc_bit_prob(enc, has_fold>>1, 8192);
838    ec_enc_bit_prob(enc, has_fold&1, (has_fold>>1) ? 32768 : 49152);
839
840    ALLOC(offsets, st->mode->nbEBands, int);
841
842    for (i=0;i<st->mode->nbEBands;i++)
843       offsets[i] = 0;
844    /* Dynamic allocation code */
845    /* Make sure that dynamic allocation can't make us bust the budget */
846    if (nbCompressedBytes > 30)
847    {
848       int t1, t2;
849       if (LM <= 1)
850       {
851          t1 = 3;
852          t2 = 5;
853       } else {
854          t1 = 2;
855          t2 = 4;
856       }
857       for (i=1;i<st->mode->nbEBands-1;i++)
858       {
859          if (2*bandLogE[i]-bandLogE[i-1]-bandLogE[i+1] > SHL16(t1,DB_SHIFT))
860             offsets[i] += 1;
861          if (2*bandLogE[i]-bandLogE[i-1]-bandLogE[i+1] > SHL16(t2,DB_SHIFT))
862             offsets[i] += 1;
863       }
864    }
865    for (i=0;i<st->mode->nbEBands;i++)
866    {
867       int j;
868       ec_enc_bit_prob(enc, offsets[i]!=0, 1024);
869       if (offsets[i]!=0)
870       {
871          for (j=0;j<offsets[i]-1;j++)
872             ec_enc_bit_prob(enc, 1, 32768);
873          ec_enc_bit_prob(enc, 0, 32768);
874       }
875       offsets[i] *= (6<<BITRES);
876    }
877    {
878       int trim_index = 3;
879
880       /*if (isTransient)
881          trim_index--;
882       if (has_fold==0)
883          trim_index--;
884       if (C==2)
885          trim_index--;*/
886       alloc_trim = trim_coef[trim_index];
887       ec_encode_bin(enc, trim_cdf[trim_index], trim_cdf[trim_index+1], 7);
888    }
889
890    /* Variable bitrate */
891    if (st->vbr_rate_norm>0)
892    {
893      celt_word16 alpha;
894      celt_int32 delta;
895      /* The target rate in 16th bits per frame */
896      celt_int32 vbr_rate;
897      celt_int32 target;
898      celt_int32 vbr_bound, max_allowed;
899
900      vbr_rate = M*st->vbr_rate_norm;
901
902      /* Computes the max bit-rate allowed in VBR more to avoid busting the budget */
903      vbr_bound = vbr_rate;
904      max_allowed = (vbr_rate + vbr_bound - st->vbr_reservoir)>>(BITRES+3);
905      if (max_allowed < 4)
906         max_allowed = 4;
907      if (max_allowed < nbAvailableBytes)
908         nbAvailableBytes = max_allowed;
909      target=vbr_rate;
910
911      /* Shortblocks get a large boost in bitrate, but since they 
912         are uncommon long blocks are not greatly affected */
913      if (shortBlocks || tf_sum < -2*(st->end-st->start))
914         target*=2;
915      else if (tf_sum < -(st->end-st->start))
916         target = 3*target/2;
917      else if (M > 1)
918         target-=(target+14)/28;
919
920      /* The average energy is removed from the target and the actual 
921         energy added*/
922      target=target+st->vbr_offset-(50<<BITRES)+ec_enc_tell(enc, BITRES);
923
924      /* In VBR mode the frame size must not be reduced so much that it would result in the coarse energy busting its budget */
925      target=IMIN(nbAvailableBytes,target);
926      /* Make the adaptation coef (alpha) higher at the beginning */
927      if (st->vbr_count < 990)
928      {
929         st->vbr_count++;
930         alpha = celt_rcp(SHL32(EXTEND32(st->vbr_count+10),16));
931         /*printf ("%d %d\n", st->vbr_count+10, alpha);*/
932      } else
933         alpha = QCONST16(.001f,15);
934
935      /* By how much did we "miss" the target on that frame */
936      delta = (8<<BITRES)*(celt_int32)target - vbr_rate;
937      /* How many bits have we used in excess of what we're allowed */
938      st->vbr_reservoir += delta;
939      /*printf ("%d\n", st->vbr_reservoir);*/
940
941      /* Compute the offset we need to apply in order to reach the target */
942      st->vbr_drift += (celt_int32)MULT16_32_Q15(alpha,delta-st->vbr_offset-st->vbr_drift);
943      st->vbr_offset = -st->vbr_drift;
944      /*printf ("%d\n", st->vbr_drift);*/
945
946      /* We could use any multiple of vbr_rate as bound (depending on the delay) */
947      if (st->vbr_reservoir < 0)
948      {
949         /* We're under the min value -- increase rate */
950         int adjust = 1-(st->vbr_reservoir-1)/(8<<BITRES);
951         st->vbr_reservoir += adjust*(8<<BITRES);
952         target += adjust;
953         /*printf ("+%d\n", adjust);*/
954      }
955      if (target < nbAvailableBytes)
956         nbAvailableBytes = target;
957      nbCompressedBytes = nbAvailableBytes + nbFilledBytes;
958
959      /* This moves the raw bits to take into account the new compressed size */
960      ec_byte_shrink(&buf, nbCompressedBytes);
961    }
962
963    /* Bit allocation */
964    ALLOC(fine_quant, st->mode->nbEBands, int);
965    ALLOC(pulses, st->mode->nbEBands, int);
966    ALLOC(fine_priority, st->mode->nbEBands, int);
967
968    bits = nbCompressedBytes*8 - ec_enc_tell(enc, 0) - 1;
969    codedBands = compute_allocation(st->mode, st->start, st->end, offsets, alloc_trim, bits, pulses, fine_quant, fine_priority, C, LM);
970
971    quant_fine_energy(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, enc, C);
972
973 #ifdef MEASURE_NORM_MSE
974    float X0[3000];
975    float bandE0[60];
976    for (c=0;c<C;c++)
977       for (i=0;i<N;i++)
978          X0[i+c*N] = X[i+c*N];
979    for (i=0;i<C*st->mode->nbEBands;i++)
980       bandE0[i] = bandE[i];
981 #endif
982
983    /* Residual quantisation */
984    quant_all_bands(1, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, bandE, pulses, shortBlocks, has_fold, tf_res, resynth, nbCompressedBytes*8, enc, LM, codedBands);
985
986    quant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
987
988    /* Re-synthesis of the coded audio if required */
989    if (resynth)
990    {
991       VARDECL(celt_sig, _out_mem);
992       celt_sig *out_mem[2];
993       celt_sig *overlap_mem[2];
994
995       log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
996
997 #ifdef MEASURE_NORM_MSE
998       measure_norm_mse(st->mode, X, X0, bandE, bandE0, M, N, C);
999 #endif
1000
1001       if (mdct_weight_shift)
1002       {
1003          mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
1004       }
1005
1006       /* Synthesis */
1007       denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
1008
1009       for (c=0;c<C;c++)
1010          for (i=0;i<M*st->mode->eBands[st->start];i++)
1011             freq[c*N+i] = 0;
1012       for (c=0;c<C;c++)
1013          for (i=M*st->mode->eBands[st->end];i<N;i++)
1014             freq[c*N+i] = 0;
1015
1016       ALLOC(_out_mem, C*N, celt_sig);
1017
1018       for (c=0;c<C;c++)
1019       {
1020          overlap_mem[c] = _overlap_mem + c*st->overlap;
1021          out_mem[c] = _out_mem+c*N;
1022       }
1023
1024       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
1025             transient_shift, out_mem, overlap_mem, C, LM);
1026
1027       /* De-emphasis and put everything back at the right place 
1028          in the synthesis history */
1029       if (optional_resynthesis != NULL) {
1030          deemphasis(out_mem, optional_resynthesis, N, C, st->mode->preemph, st->preemph_memD);
1031
1032       }
1033    }
1034
1035    /* If there's any room left (can only happen for very high rates),
1036       fill it with zeros */
1037    while (ec_enc_tell(enc,0) + 8 <= nbCompressedBytes*8)
1038       ec_enc_bits(enc, 0, 8);
1039    ec_enc_done(enc);
1040    
1041    RESTORE_STACK;
1042    if (ec_enc_get_error(enc))
1043       return CELT_CORRUPTED_DATA;
1044    else
1045       return nbCompressedBytes;
1046 }
1047
1048 #ifdef FIXED_POINT
1049 #ifndef DISABLE_FLOAT_API
1050 int celt_encode_with_ec_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
1051 {
1052    int j, ret, C, N, LM, M;
1053    VARDECL(celt_int16, in);
1054    SAVE_STACK;
1055
1056    if (pcm==NULL)
1057       return CELT_BAD_ARG;
1058
1059    for (LM=0;LM<4;LM++)
1060       if (st->mode->shortMdctSize<<LM==frame_size)
1061          break;
1062    if (LM>=MAX_CONFIG_SIZES)
1063       return CELT_BAD_ARG;
1064    M=1<<LM;
1065
1066    C = CHANNELS(st->channels);
1067    N = M*st->mode->shortMdctSize;
1068    ALLOC(in, C*N, celt_int16);
1069
1070    for (j=0;j<C*N;j++)
1071      in[j] = FLOAT2INT16(pcm[j]);
1072
1073    if (optional_resynthesis != NULL) {
1074      ret=celt_encode_with_ec(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
1075       for (j=0;j<C*N;j++)
1076          optional_resynthesis[j]=in[j]*(1.f/32768.f);
1077    } else {
1078      ret=celt_encode_with_ec(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
1079    }
1080    RESTORE_STACK;
1081    return ret;
1082
1083 }
1084 #endif /*DISABLE_FLOAT_API*/
1085 #else
1086 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
1087 {
1088    int j, ret, C, N, LM, M;
1089    VARDECL(celt_sig, in);
1090    SAVE_STACK;
1091
1092    if (pcm==NULL)
1093       return CELT_BAD_ARG;
1094
1095    for (LM=0;LM<4;LM++)
1096       if (st->mode->shortMdctSize<<LM==frame_size)
1097          break;
1098    if (LM>=MAX_CONFIG_SIZES)
1099       return CELT_BAD_ARG;
1100    M=1<<LM;
1101
1102    C=CHANNELS(st->channels);
1103    N=M*st->mode->shortMdctSize;
1104    ALLOC(in, C*N, celt_sig);
1105    for (j=0;j<C*N;j++) {
1106      in[j] = SCALEOUT(pcm[j]);
1107    }
1108
1109    if (optional_resynthesis != NULL) {
1110       ret = celt_encode_with_ec_float(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
1111       for (j=0;j<C*N;j++)
1112          optional_resynthesis[j] = FLOAT2INT16(in[j]);
1113    } else {
1114       ret = celt_encode_with_ec_float(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
1115    }
1116    RESTORE_STACK;
1117    return ret;
1118 }
1119 #endif
1120
1121 int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1122 {
1123    return celt_encode_with_ec(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1124 }
1125
1126 #ifndef DISABLE_FLOAT_API
1127 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1128 {
1129    return celt_encode_with_ec_float(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1130 }
1131 #endif /* DISABLE_FLOAT_API */
1132
1133 int celt_encode_resynthesis(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1134 {
1135    return celt_encode_with_ec(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1136 }
1137
1138 #ifndef DISABLE_FLOAT_API
1139 int celt_encode_resynthesis_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1140 {
1141    return celt_encode_with_ec_float(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1142 }
1143 #endif /* DISABLE_FLOAT_API */
1144
1145
1146 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
1147 {
1148    va_list ap;
1149    
1150    va_start(ap, request);
1151    switch (request)
1152    {
1153       case CELT_GET_MODE_REQUEST:
1154       {
1155          const CELTMode ** value = va_arg(ap, const CELTMode**);
1156          if (value==0)
1157             goto bad_arg;
1158          *value=st->mode;
1159       }
1160       break;
1161       case CELT_SET_COMPLEXITY_REQUEST:
1162       {
1163          int value = va_arg(ap, celt_int32);
1164          if (value<0 || value>10)
1165             goto bad_arg;
1166          st->complexity = value;
1167       }
1168       break;
1169       case CELT_SET_START_BAND_REQUEST:
1170       {
1171          celt_int32 value = va_arg(ap, celt_int32);
1172          if (value<0 || value>=st->mode->nbEBands)
1173             goto bad_arg;
1174          st->start = value;
1175       }
1176       break;
1177       case CELT_SET_END_BAND_REQUEST:
1178       {
1179          celt_int32 value = va_arg(ap, celt_int32);
1180          if (value<0 || value>=st->mode->nbEBands)
1181             goto bad_arg;
1182          st->end = value;
1183       }
1184       break;
1185       case CELT_SET_PREDICTION_REQUEST:
1186       {
1187          int value = va_arg(ap, celt_int32);
1188          if (value<0 || value>2)
1189             goto bad_arg;
1190          if (value==0)
1191          {
1192             st->force_intra   = 1;
1193          } else if (value==1) {
1194             st->force_intra   = 0;
1195          } else {
1196             st->force_intra   = 0;
1197          }   
1198       }
1199       break;
1200       case CELT_SET_VBR_RATE_REQUEST:
1201       {
1202          celt_int32 value = va_arg(ap, celt_int32);
1203          int frame_rate;
1204          int N = st->mode->shortMdctSize;
1205          if (value<0)
1206             goto bad_arg;
1207          if (value>3072000)
1208             value = 3072000;
1209          frame_rate = ((st->mode->Fs<<3)+(N>>1))/N;
1210          st->vbr_rate_norm = ((value<<(BITRES+3))+(frame_rate>>1))/frame_rate;
1211       }
1212       break;
1213       case CELT_RESET_STATE:
1214       {
1215          CELT_MEMSET((char*)&st->ENCODER_RESET_START, 0,
1216                celt_encoder_get_size(st->mode, st->channels)-
1217                ((char*)&st->ENCODER_RESET_START - (char*)st));
1218          st->delayedIntra = 1;
1219          st->fold_decision = 1;
1220          st->tonal_average = QCONST16(1.f,8);
1221       }
1222       break;
1223       default:
1224          goto bad_request;
1225    }
1226    va_end(ap);
1227    return CELT_OK;
1228 bad_arg:
1229    va_end(ap);
1230    return CELT_BAD_ARG;
1231 bad_request:
1232    va_end(ap);
1233    return CELT_UNIMPLEMENTED;
1234 }
1235
1236 /**********************************************************************/
1237 /*                                                                    */
1238 /*                             DECODER                                */
1239 /*                                                                    */
1240 /**********************************************************************/
1241 #define DECODE_BUFFER_SIZE 2048
1242
1243 /** Decoder state 
1244  @brief Decoder state
1245  */
1246 struct CELTDecoder {
1247    const CELTMode *mode;
1248    int overlap;
1249    int channels;
1250
1251    int start, end;
1252
1253    /* Everything beyond this point gets cleared on a reset */
1254 #define DECODER_RESET_START last_pitch_index
1255
1256    int last_pitch_index;
1257    int loss_count;
1258
1259    celt_sig preemph_memD[2];
1260    
1261    celt_sig _decode_mem[1]; /* Size = channels*(DECODE_BUFFER_SIZE+mode->overlap) */
1262    /* celt_word16 lpc[],  Size = channels*LPC_ORDER */
1263    /* celt_word16 oldEBands[], Size = channels*mode->nbEBands */
1264 };
1265
1266 int celt_decoder_get_size(const CELTMode *mode, int channels)
1267 {
1268    int size = sizeof(struct CELTDecoder)
1269             + (channels*(DECODE_BUFFER_SIZE+mode->overlap)-1)*sizeof(celt_sig)
1270             + channels*LPC_ORDER*sizeof(celt_word16)
1271             + channels*mode->nbEBands*sizeof(celt_word16);
1272    return size;
1273 }
1274
1275 CELTDecoder *celt_decoder_create(const CELTMode *mode, int channels, int *error)
1276 {
1277    return celt_decoder_init(
1278          (CELTDecoder *)celt_alloc(celt_decoder_get_size(mode, channels)),
1279          mode, channels, error);
1280 }
1281
1282 CELTDecoder *celt_decoder_init(CELTDecoder *st, const CELTMode *mode, int channels, int *error)
1283 {
1284    if (channels < 0 || channels > 2)
1285    {
1286       if (error)
1287          *error = CELT_BAD_ARG;
1288       return NULL;
1289    }
1290
1291    if (st==NULL)
1292    {
1293       if (error)
1294          *error = CELT_ALLOC_FAIL;
1295       return NULL;
1296    }
1297
1298    CELT_MEMSET((char*)st, 0, celt_decoder_get_size(mode, channels));
1299
1300    st->mode = mode;
1301    st->overlap = mode->overlap;
1302    st->channels = channels;
1303
1304    st->start = 0;
1305    st->end = st->mode->effEBands;
1306
1307    st->loss_count = 0;
1308
1309    if (error)
1310       *error = CELT_OK;
1311    return st;
1312 }
1313
1314 void celt_decoder_destroy(CELTDecoder *st)
1315 {
1316    celt_free(st);
1317 }
1318
1319 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16 * restrict pcm, int N, int LM)
1320 {
1321    int c;
1322    int pitch_index;
1323    int overlap = st->mode->overlap;
1324    celt_word16 fade = Q15ONE;
1325    int i, len;
1326    const int C = CHANNELS(st->channels);
1327    int offset;
1328    celt_sig *out_mem[2];
1329    celt_sig *decode_mem[2];
1330    celt_sig *overlap_mem[2];
1331    celt_word16 *lpc;
1332    celt_word16 *oldBandE;
1333    SAVE_STACK;
1334    
1335    for (c=0;c<C;c++)
1336    {
1337       decode_mem[c] = st->_decode_mem + c*(DECODE_BUFFER_SIZE+st->overlap);
1338       out_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE-MAX_PERIOD;
1339       overlap_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE;
1340    }
1341    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1342    oldBandE = lpc+C*LPC_ORDER;
1343
1344    len = N+st->mode->overlap;
1345    
1346    if (st->loss_count == 0)
1347    {
1348       celt_word16 pitch_buf[MAX_PERIOD>>1];
1349       celt_word32 tmp=0;
1350       celt_word32 mem0[2]={0,0};
1351       celt_word16 mem1[2]={0,0};
1352       int len2 = len;
1353       /* FIXME: This is a kludge */
1354       if (len2>MAX_PERIOD>>1)
1355          len2 = MAX_PERIOD>>1;
1356       pitch_downsample(out_mem, pitch_buf, MAX_PERIOD, MAX_PERIOD,
1357                        C, mem0, mem1);
1358       pitch_search(st->mode, pitch_buf+((MAX_PERIOD-len2)>>1), pitch_buf, len2,
1359                    MAX_PERIOD-len2-100, &pitch_index, &tmp, 1<<LM);
1360       pitch_index = MAX_PERIOD-len2-pitch_index;
1361       st->last_pitch_index = pitch_index;
1362    } else {
1363       pitch_index = st->last_pitch_index;
1364       if (st->loss_count < 5)
1365          fade = QCONST16(.8f,15);
1366       else
1367          fade = 0;
1368    }
1369
1370    for (c=0;c<C;c++)
1371    {
1372       /* FIXME: This is more memory than necessary */
1373       celt_word32 e[2*MAX_PERIOD];
1374       celt_word16 exc[2*MAX_PERIOD];
1375       celt_word32 ac[LPC_ORDER+1];
1376       celt_word16 decay = 1;
1377       celt_word32 S1=0;
1378       celt_word16 mem[LPC_ORDER]={0};
1379
1380       offset = MAX_PERIOD-pitch_index;
1381       for (i=0;i<MAX_PERIOD;i++)
1382          exc[i] = ROUND16(out_mem[c][i], SIG_SHIFT);
1383
1384       if (st->loss_count == 0)
1385       {
1386          _celt_autocorr(exc, ac, st->mode->window, st->mode->overlap,
1387                         LPC_ORDER, MAX_PERIOD);
1388
1389          /* Noise floor -40 dB */
1390 #ifdef FIXED_POINT
1391          ac[0] += SHR32(ac[0],13);
1392 #else
1393          ac[0] *= 1.0001f;
1394 #endif
1395          /* Lag windowing */
1396          for (i=1;i<=LPC_ORDER;i++)
1397          {
1398             /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
1399 #ifdef FIXED_POINT
1400             ac[i] -= MULT16_32_Q15(2*i*i, ac[i]);
1401 #else
1402             ac[i] -= ac[i]*(.008f*i)*(.008f*i);
1403 #endif
1404          }
1405
1406          _celt_lpc(lpc+c*LPC_ORDER, ac, LPC_ORDER);
1407       }
1408       fir(exc, lpc+c*LPC_ORDER, exc, MAX_PERIOD, LPC_ORDER, mem);
1409       /*for (i=0;i<MAX_PERIOD;i++)printf("%d ", exc[i]); printf("\n");*/
1410       /* Check if the waveform is decaying (and if so how fast) */
1411       {
1412          celt_word32 E1=1, E2=1;
1413          int period;
1414          if (pitch_index <= MAX_PERIOD/2)
1415             period = pitch_index;
1416          else
1417             period = MAX_PERIOD/2;
1418          for (i=0;i<period;i++)
1419          {
1420             E1 += SHR32(MULT16_16(exc[MAX_PERIOD-period+i],exc[MAX_PERIOD-period+i]),8);
1421             E2 += SHR32(MULT16_16(exc[MAX_PERIOD-2*period+i],exc[MAX_PERIOD-2*period+i]),8);
1422          }
1423          if (E1 > E2)
1424             E1 = E2;
1425          decay = celt_sqrt(frac_div32(SHR(E1,1),E2));
1426       }
1427
1428       /* Copy excitation, taking decay into account */
1429       for (i=0;i<len+st->mode->overlap;i++)
1430       {
1431          if (offset+i >= MAX_PERIOD)
1432          {
1433             offset -= pitch_index;
1434             decay = MULT16_16_Q15(decay, decay);
1435          }
1436          e[i] = SHL32(EXTEND32(MULT16_16_Q15(decay, exc[offset+i])), SIG_SHIFT);
1437          S1 += SHR32(MULT16_16(out_mem[c][offset+i],out_mem[c][offset+i]),8);
1438       }
1439
1440       iir(e, lpc+c*LPC_ORDER, e, len+st->mode->overlap, LPC_ORDER, mem);
1441
1442       {
1443          celt_word32 S2=0;
1444          for (i=0;i<len+overlap;i++)
1445             S2 += SHR32(MULT16_16(e[i],e[i]),8);
1446          /* This checks for an "explosion" in the synthesis */
1447 #ifdef FIXED_POINT
1448          if (!(S1 > SHR32(S2,2)))
1449 #else
1450          /* Float test is written this way to catch NaNs at the same time */
1451          if (!(S1 > 0.2f*S2))
1452 #endif
1453          {
1454             for (i=0;i<len+overlap;i++)
1455                e[i] = 0;
1456          } else if (S1 < S2)
1457          {
1458             celt_word16 ratio = celt_sqrt(frac_div32(SHR32(S1,1)+1,S2+1));
1459             for (i=0;i<len+overlap;i++)
1460                e[i] = MULT16_16_Q15(ratio, e[i]);
1461          }
1462       }
1463
1464       for (i=0;i<MAX_PERIOD+st->mode->overlap-N;i++)
1465          out_mem[c][i] = out_mem[c][N+i];
1466
1467       /* Apply TDAC to the concealed audio so that it blends with the
1468          previous and next frames */
1469       for (i=0;i<overlap/2;i++)
1470       {
1471          celt_word32 tmp1, tmp2;
1472          tmp1 = MULT16_32_Q15(st->mode->window[i          ], e[i          ]) -
1473                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[overlap-i-1]);
1474          tmp2 = MULT16_32_Q15(st->mode->window[i],           e[N+overlap-1-i]) +
1475                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[N+i          ]);
1476          tmp1 = MULT16_32_Q15(fade, tmp1);
1477          tmp2 = MULT16_32_Q15(fade, tmp2);
1478          out_mem[c][MAX_PERIOD+i] = MULT16_32_Q15(st->mode->window[overlap-i-1], tmp2);
1479          out_mem[c][MAX_PERIOD+overlap-i-1] = MULT16_32_Q15(st->mode->window[i], tmp2);
1480          out_mem[c][MAX_PERIOD-N+i] += MULT16_32_Q15(st->mode->window[i], tmp1);
1481          out_mem[c][MAX_PERIOD-N+overlap-i-1] -= MULT16_32_Q15(st->mode->window[overlap-i-1], tmp1);
1482       }
1483       for (i=0;i<N-overlap;i++)
1484          out_mem[c][MAX_PERIOD-N+overlap+i] = MULT16_32_Q15(fade, e[overlap+i]);
1485    }
1486
1487    deemphasis(out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1488    
1489    st->loss_count++;
1490
1491    RESTORE_STACK;
1492 }
1493
1494 #ifdef FIXED_POINT
1495 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1496 {
1497 #else
1498 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig * restrict pcm, int frame_size, ec_dec *dec)
1499 {
1500 #endif
1501    int c, i, N;
1502    int has_fold;
1503    int bits;
1504    ec_dec _dec;
1505    ec_byte_buffer buf;
1506    VARDECL(celt_sig, freq);
1507    VARDECL(celt_norm, X);
1508    VARDECL(celt_ener, bandE);
1509    VARDECL(int, fine_quant);
1510    VARDECL(int, pulses);
1511    VARDECL(int, offsets);
1512    VARDECL(int, fine_priority);
1513    VARDECL(int, tf_res);
1514    celt_sig *out_mem[2];
1515    celt_sig *decode_mem[2];
1516    celt_sig *overlap_mem[2];
1517    celt_sig *out_syn[2];
1518    celt_word16 *lpc;
1519    celt_word16 *oldBandE;
1520
1521    int shortBlocks;
1522    int isTransient;
1523    int intra_ener;
1524    int transient_time;
1525    int transient_shift;
1526    int mdct_weight_shift=0;
1527    const int C = CHANNELS(st->channels);
1528    int mdct_weight_pos=0;
1529    int LM, M;
1530    int nbFilledBytes, nbAvailableBytes;
1531    int effEnd;
1532    int codedBands;
1533    int alloc_trim;
1534    SAVE_STACK;
1535
1536    if (pcm==NULL)
1537       return CELT_BAD_ARG;
1538
1539    for (LM=0;LM<4;LM++)
1540       if (st->mode->shortMdctSize<<LM==frame_size)
1541          break;
1542    if (LM>=MAX_CONFIG_SIZES)
1543       return CELT_BAD_ARG;
1544    M=1<<LM;
1545
1546    for (c=0;c<C;c++)
1547    {
1548       decode_mem[c] = st->_decode_mem + c*(DECODE_BUFFER_SIZE+st->overlap);
1549       out_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE-MAX_PERIOD;
1550       overlap_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE;
1551    }
1552    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1553    oldBandE = lpc+C*LPC_ORDER;
1554
1555    N = M*st->mode->shortMdctSize;
1556
1557    effEnd = st->end;
1558    if (effEnd > st->mode->effEBands)
1559       effEnd = st->mode->effEBands;
1560
1561    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
1562    ALLOC(X, C*N, celt_norm);   /**< Interleaved normalised MDCTs */
1563    ALLOC(bandE, st->mode->nbEBands*C, celt_ener);
1564    for (c=0;c<C;c++)
1565       for (i=0;i<M*st->mode->eBands[st->start];i++)
1566          X[c*N+i] = 0;
1567    for (c=0;c<C;c++)
1568       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1569          X[c*N+i] = 0;
1570
1571    if (data == NULL)
1572    {
1573       celt_decode_lost(st, pcm, N, LM);
1574       RESTORE_STACK;
1575       return CELT_OK;
1576    }
1577    if (len<0) {
1578      RESTORE_STACK;
1579      return CELT_BAD_ARG;
1580    }
1581    
1582    if (dec == NULL)
1583    {
1584       ec_byte_readinit(&buf,(unsigned char*)data,len);
1585       ec_dec_init(&_dec,&buf);
1586       dec = &_dec;
1587       nbFilledBytes = 0;
1588    } else {
1589       nbFilledBytes = (ec_dec_tell(dec, 0)+4)>>3;
1590    }
1591    nbAvailableBytes = len-nbFilledBytes;
1592
1593    /* Decode the global flags (first symbols in the stream) */
1594    intra_ener = ec_dec_bit_prob(dec, 8192);
1595    /* Get band energies */
1596    unquant_coarse_energy(st->mode, st->start, st->end, bandE, oldBandE,
1597          intra_ener, st->mode->prob, dec, C, LM);
1598
1599    if (LM > 0)
1600       isTransient = ec_dec_bit_prob(dec, 8192);
1601    else
1602       isTransient = 0;
1603
1604    if (isTransient)
1605       shortBlocks = M;
1606    else
1607       shortBlocks = 0;
1608
1609    if (isTransient)
1610    {
1611       transient_shift = ec_dec_uint(dec, 4);
1612       if (transient_shift == 3)
1613       {
1614          int transient_time_quant;
1615          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
1616          transient_time_quant = ec_dec_uint(dec, max_time);
1617          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
1618       } else {
1619          mdct_weight_shift = transient_shift;
1620          if (mdct_weight_shift && M>2)
1621             mdct_weight_pos = ec_dec_uint(dec, M-1);
1622          transient_shift = 0;
1623          transient_time = 0;
1624       }
1625    } else {
1626       transient_time = -1;
1627       transient_shift = 0;
1628    }
1629
1630    ALLOC(tf_res, st->mode->nbEBands, int);
1631    tf_decode(st->start, st->end, C, isTransient, tf_res, nbAvailableBytes, LM, dec);
1632
1633    has_fold = ec_dec_bit_prob(dec, 8192)<<1;
1634    has_fold |= ec_dec_bit_prob(dec, (has_fold>>1) ? 32768 : 49152);
1635
1636    ALLOC(pulses, st->mode->nbEBands, int);
1637    ALLOC(offsets, st->mode->nbEBands, int);
1638    ALLOC(fine_priority, st->mode->nbEBands, int);
1639
1640    for (i=0;i<st->mode->nbEBands;i++)
1641       offsets[i] = 0;
1642    for (i=0;i<st->mode->nbEBands;i++)
1643    {
1644       if (ec_dec_bit_prob(dec, 1024))
1645       {
1646          while (ec_dec_bit_prob(dec, 32768))
1647             offsets[i]++;
1648          offsets[i]++;
1649          offsets[i] *= (6<<BITRES);
1650       }
1651    }
1652
1653    ALLOC(fine_quant, st->mode->nbEBands, int);
1654    {
1655       int fl;
1656       int trim_index=0;
1657       fl = ec_decode_bin(dec, 7);
1658       while (trim_cdf[trim_index+1] <= fl)
1659          trim_index++;
1660       ec_dec_update(dec, trim_cdf[trim_index], trim_cdf[trim_index+1], 128);
1661       alloc_trim = trim_coef[trim_index];
1662    }
1663
1664    bits = len*8 - ec_dec_tell(dec, 0) - 1;
1665    codedBands = compute_allocation(st->mode, st->start, st->end, offsets, alloc_trim, bits, pulses, fine_quant, fine_priority, C, LM);
1666    
1667    unquant_fine_energy(st->mode, st->start, st->end, bandE, oldBandE, fine_quant, dec, C);
1668
1669    /* Decode fixed codebook */
1670    quant_all_bands(0, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, NULL, pulses, shortBlocks, has_fold, tf_res, 1, len*8, dec, LM, codedBands);
1671
1672    unquant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE,
1673          fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
1674
1675    log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
1676
1677    if (mdct_weight_shift)
1678    {
1679       mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
1680    }
1681
1682    /* Synthesis */
1683    denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
1684
1685    CELT_MOVE(decode_mem[0], decode_mem[0]+N, DECODE_BUFFER_SIZE-N);
1686    if (C==2)
1687       CELT_MOVE(decode_mem[1], decode_mem[1]+N, DECODE_BUFFER_SIZE-N);
1688
1689    for (c=0;c<C;c++)
1690       for (i=0;i<M*st->mode->eBands[st->start];i++)
1691          freq[c*N+i] = 0;
1692    for (c=0;c<C;c++)
1693       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1694          freq[c*N+i] = 0;
1695
1696    out_syn[0] = out_mem[0]+MAX_PERIOD-N;
1697    if (C==2)
1698       out_syn[1] = out_mem[1]+MAX_PERIOD-N;
1699
1700    /* Compute inverse MDCTs */
1701    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
1702          transient_shift, out_syn, overlap_mem, C, LM);
1703
1704    deemphasis(out_syn, pcm, N, C, st->mode->preemph, st->preemph_memD);
1705    st->loss_count = 0;
1706    RESTORE_STACK;
1707    if (ec_dec_get_error(dec))
1708       return CELT_CORRUPTED_DATA;
1709    else
1710       return CELT_OK;
1711 }
1712
1713 #ifdef FIXED_POINT
1714 #ifndef DISABLE_FLOAT_API
1715 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size, ec_dec *dec)
1716 {
1717    int j, ret, C, N, LM, M;
1718    VARDECL(celt_int16, out);
1719    SAVE_STACK;
1720
1721    if (pcm==NULL)
1722       return CELT_BAD_ARG;
1723
1724    for (LM=0;LM<4;LM++)
1725       if (st->mode->shortMdctSize<<LM==frame_size)
1726          break;
1727    if (LM>=MAX_CONFIG_SIZES)
1728       return CELT_BAD_ARG;
1729    M=1<<LM;
1730
1731    C = CHANNELS(st->channels);
1732    N = M*st->mode->shortMdctSize;
1733    
1734    ALLOC(out, C*N, celt_int16);
1735    ret=celt_decode_with_ec(st, data, len, out, frame_size, dec);
1736    if (ret==0)
1737       for (j=0;j<C*N;j++)
1738          pcm[j]=out[j]*(1.f/32768.f);
1739      
1740    RESTORE_STACK;
1741    return ret;
1742 }
1743 #endif /*DISABLE_FLOAT_API*/
1744 #else
1745 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1746 {
1747    int j, ret, C, N, LM, M;
1748    VARDECL(celt_sig, out);
1749    SAVE_STACK;
1750
1751    if (pcm==NULL)
1752       return CELT_BAD_ARG;
1753
1754    for (LM=0;LM<4;LM++)
1755       if (st->mode->shortMdctSize<<LM==frame_size)
1756          break;
1757    if (LM>=MAX_CONFIG_SIZES)
1758       return CELT_BAD_ARG;
1759    M=1<<LM;
1760
1761    C = CHANNELS(st->channels);
1762    N = M*st->mode->shortMdctSize;
1763    ALLOC(out, C*N, celt_sig);
1764
1765    ret=celt_decode_with_ec_float(st, data, len, out, frame_size, dec);
1766
1767    if (ret==0)
1768       for (j=0;j<C*N;j++)
1769          pcm[j] = FLOAT2INT16 (out[j]);
1770    
1771    RESTORE_STACK;
1772    return ret;
1773 }
1774 #endif
1775
1776 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size)
1777 {
1778    return celt_decode_with_ec(st, data, len, pcm, frame_size, NULL);
1779 }
1780
1781 #ifndef DISABLE_FLOAT_API
1782 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size)
1783 {
1784    return celt_decode_with_ec_float(st, data, len, pcm, frame_size, NULL);
1785 }
1786 #endif /* DISABLE_FLOAT_API */
1787
1788 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
1789 {
1790    va_list ap;
1791
1792    va_start(ap, request);
1793    switch (request)
1794    {
1795       case CELT_GET_MODE_REQUEST:
1796       {
1797          const CELTMode ** value = va_arg(ap, const CELTMode**);
1798          if (value==0)
1799             goto bad_arg;
1800          *value=st->mode;
1801       }
1802       break;
1803       case CELT_SET_START_BAND_REQUEST:
1804       {
1805          celt_int32 value = va_arg(ap, celt_int32);
1806          if (value<0 || value>=st->mode->nbEBands)
1807             goto bad_arg;
1808          st->start = value;
1809       }
1810       break;
1811       case CELT_SET_END_BAND_REQUEST:
1812       {
1813          celt_int32 value = va_arg(ap, celt_int32);
1814          if (value<0 || value>=st->mode->nbEBands)
1815             goto bad_arg;
1816          st->end = value;
1817       }
1818       break;
1819       case CELT_RESET_STATE:
1820       {
1821          CELT_MEMSET((char*)&st->DECODER_RESET_START, 0,
1822                celt_decoder_get_size(st->mode, st->channels)-
1823                ((char*)&st->DECODER_RESET_START - (char*)st));
1824       }
1825       break;
1826       default:
1827          goto bad_request;
1828    }
1829    va_end(ap);
1830    return CELT_OK;
1831 bad_arg:
1832    va_end(ap);
1833    return CELT_BAD_ARG;
1834 bad_request:
1835       va_end(ap);
1836   return CELT_UNIMPLEMENTED;
1837 }
1838
1839 const char *celt_strerror(int error)
1840 {
1841    static const char *error_strings[8] = {
1842       "success",
1843       "invalid argument",
1844       "invalid mode",
1845       "internal error",
1846       "corrupted stream",
1847       "request not implemented",
1848       "invalid state",
1849       "memory allocation failed"
1850    };
1851    if (error > 0 || error < -7)
1852       return "unknown error";
1853    else 
1854       return error_strings[-error];
1855 }
1856