Removing out_mem from the encoder state.
[opus.git] / libcelt / celt.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #define CELT_C
39
40 #include "os_support.h"
41 #include "mdct.h"
42 #include <math.h>
43 #include "celt.h"
44 #include "pitch.h"
45 #include "bands.h"
46 #include "modes.h"
47 #include "entcode.h"
48 #include "quant_bands.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54 #include "plc.h"
55
56 #ifdef FIXED_POINT
57 static const celt_word16 transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135f, 0.0337639f, 0.0748914f, 0.1304955f,
63    0.1986827f, 0.2771308f, 0.3631685f, 0.4538658f,
64    0.5461342f, 0.6368315f, 0.7228692f, 0.8013173f,
65    0.8695045f, 0.9251086f, 0.9662361f, 0.9914865f};
66 #endif
67
68 #define ENCODERVALID   0x4c434554
69 #define ENCODERPARTIAL 0x5445434c
70 #define ENCODERFREED   0x4c004500
71    
72 /** Encoder state 
73  @brief Encoder state
74  */
75 struct CELTEncoder {
76    celt_uint32 marker;
77    const CELTMode *mode;     /**< Mode used by the encoder */
78    int overlap;
79    int channels;
80    
81    int force_intra;
82    int delayedIntra;
83    celt_word16 tonal_average;
84    int fold_decision;
85    celt_word16 gain_prod;
86    celt_word32 frame_max;
87    int start, end;
88
89    /* VBR-related parameters */
90    celt_int32 vbr_reservoir;
91    celt_int32 vbr_drift;
92    celt_int32 vbr_offset;
93    celt_int32 vbr_count;
94
95    celt_int32 vbr_rate_norm; /* Target number of 16th bits per frame */
96    celt_word32 preemph_memE[2];
97    celt_word32 preemph_memD[2];
98
99    celt_sig *in_mem;
100    celt_sig *overlap_mem[2];
101
102    celt_word16 *oldBandE;
103 };
104
105 static int check_encoder(const CELTEncoder *st) 
106 {
107    if (st==NULL)
108    {
109       celt_warning("NULL passed as an encoder structure");  
110       return CELT_INVALID_STATE;
111    }
112    if (st->marker == ENCODERVALID)
113       return CELT_OK;
114    if (st->marker == ENCODERFREED)
115       celt_warning("Referencing an encoder that has already been freed");
116    else
117       celt_warning("This is not a valid CELT encoder structure");
118    return CELT_INVALID_STATE;
119 }
120
121 CELTEncoder *celt_encoder_create(const CELTMode *mode, int channels, int *error)
122 {
123    int C;
124    CELTEncoder *st;
125
126    if (check_mode(mode) != CELT_OK)
127    {
128       if (error)
129          *error = CELT_INVALID_MODE;
130       return NULL;
131    }
132
133    if (channels < 0 || channels > 2)
134    {
135       celt_warning("Only mono and stereo supported");
136       if (error)
137          *error = CELT_BAD_ARG;
138       return NULL;
139    }
140
141    C = channels;
142    st = celt_alloc(sizeof(CELTEncoder));
143    
144    if (st==NULL)
145    {
146       if (error)
147          *error = CELT_ALLOC_FAIL;
148       return NULL;
149    }
150    st->marker = ENCODERPARTIAL;
151    st->mode = mode;
152    st->overlap = mode->overlap;
153    st->channels = channels;
154
155    st->start = 0;
156    st->end = st->mode->effEBands;
157
158    st->vbr_rate_norm = 0;
159    st->force_intra  = 0;
160    st->delayedIntra = 1;
161    st->tonal_average = QCONST16(1.f,8);
162    st->fold_decision = 1;
163
164    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig));
165    st->overlap_mem[0] = celt_alloc(st->overlap*C*sizeof(celt_sig));
166    if (C==2)
167       st->overlap_mem[1] = st->overlap_mem[0] + st->overlap;
168
169    st->oldBandE = (celt_word16*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16));
170
171    if ((st->in_mem!=NULL) && (st->overlap_mem[0]!=NULL) && (st->oldBandE!=NULL))
172    {
173       if (error)
174          *error = CELT_OK;
175       st->marker   = ENCODERVALID;
176       return st;
177    }
178    /* If the setup fails for some reason deallocate it. */
179    celt_encoder_destroy(st);  
180    if (error)
181       *error = CELT_ALLOC_FAIL;
182    return NULL;
183 }
184
185 void celt_encoder_destroy(CELTEncoder *st)
186 {
187    if (st == NULL)
188    {
189       celt_warning("NULL passed to celt_encoder_destroy");
190       return;
191    }
192
193    if (st->marker == ENCODERFREED)
194    {
195       celt_warning("Freeing an encoder which has already been freed"); 
196       return;
197    }
198
199    if (st->marker != ENCODERVALID && st->marker != ENCODERPARTIAL)
200    {
201       celt_warning("This is not a valid CELT encoder structure");
202       return;
203    }
204    /*Check_mode is non-fatal here because we can still free
205     the encoder memory even if the mode is bad, although calling
206     the free functions in this order is a violation of the API.*/
207    check_mode(st->mode);
208    
209    celt_free(st->in_mem);
210    celt_free(st->overlap_mem[0]);
211    celt_free(st->oldBandE);
212    
213    st->marker = ENCODERFREED;
214    
215    celt_free(st);
216 }
217
218 static inline celt_int16 FLOAT2INT16(float x)
219 {
220    x = x*CELT_SIG_SCALE;
221    x = MAX32(x, -32768);
222    x = MIN32(x, 32767);
223    return (celt_int16)float2int(x);
224 }
225
226 static inline celt_word16 SIG2WORD16(celt_sig x)
227 {
228 #ifdef FIXED_POINT
229    x = PSHR32(x, SIG_SHIFT);
230    x = MAX32(x, -32768);
231    x = MIN32(x, 32767);
232    return EXTRACT16(x);
233 #else
234    return (celt_word16)x;
235 #endif
236 }
237
238 static int transient_analysis(const celt_word32 * restrict in, int len, int C,
239                               int *transient_time, int *transient_shift,
240                               celt_word32 *frame_max, int overlap)
241 {
242    int i, n;
243    celt_word32 ratio;
244    celt_word32 threshold;
245    VARDECL(celt_word32, begin);
246    SAVE_STACK;
247    ALLOC(begin, len+1, celt_word32);
248    begin[0] = 0;
249    if (C==1)
250    {
251       for (i=0;i<len;i++)
252          begin[i+1] = MAX32(begin[i], ABS32(in[i]));
253    } else {
254       for (i=0;i<len;i++)
255          begin[i+1] = MAX32(begin[i], MAX32(ABS32(in[C*i]),
256                                             ABS32(in[C*i+1])));
257    }
258    n = -1;
259
260    threshold = MULT16_32_Q15(QCONST16(.4f,15),begin[len]);
261    /* If the following condition isn't met, there's just no way
262       we'll have a transient*/
263    if (*frame_max < threshold)
264    {
265       /* It's likely we have a transient, now find it */
266       for (i=8;i<len-8;i++)
267       {
268          if (begin[i+1] < threshold)
269             n=i;
270       }
271    }
272    if (n<32)
273    {
274       n = -1;
275       ratio = 0;
276    } else {
277       ratio = DIV32(begin[len],1+MAX32(*frame_max, begin[n-16]));
278    }
279
280    if (ratio > 45)
281       *transient_shift = 3;
282    else
283       *transient_shift = 0;
284    
285    *transient_time = n;
286    *frame_max = begin[len-overlap];
287
288    RESTORE_STACK;
289    return ratio > 0;
290 }
291
292 /** Apply window and compute the MDCT for all sub-frames and 
293     all channels in a frame */
294 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * restrict in, celt_sig * restrict out, int _C, int LM)
295 {
296    const int C = CHANNELS(_C);
297    if (C==1 && !shortBlocks)
298    {
299       const int overlap = OVERLAP(mode);
300       clt_mdct_forward(&mode->mdct, in, out, mode->window, overlap, mode->maxLM-LM);
301    } else {
302       const int overlap = OVERLAP(mode);
303       int N = mode->shortMdctSize<<LM;
304       int B = 1;
305       int b, c;
306       VARDECL(celt_word32, x);
307       VARDECL(celt_word32, tmp);
308       SAVE_STACK;
309       if (shortBlocks)
310       {
311          /*lookup = &mode->mdct[0];*/
312          N = mode->shortMdctSize;
313          B = shortBlocks;
314       }
315       ALLOC(x, N+overlap, celt_word32);
316       ALLOC(tmp, N, celt_word32);
317       for (c=0;c<C;c++)
318       {
319          for (b=0;b<B;b++)
320          {
321             int j;
322             for (j=0;j<N+overlap;j++)
323                x[j] = in[C*(b*N+j)+c];
324             clt_mdct_forward(&mode->mdct, x, tmp, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
325             /* Interleaving the sub-frames */
326             for (j=0;j<N;j++)
327                out[(j*B+b)+c*N*B] = tmp[j];
328          }
329       }
330       RESTORE_STACK;
331    }
332 }
333
334 /** Compute the IMDCT and apply window for all sub-frames and 
335     all channels in a frame */
336 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X,
337       int transient_time, int transient_shift, celt_sig * restrict out_mem[],
338       celt_sig * restrict overlap_mem[], int _C, int LM)
339 {
340    int c, N4;
341    const int C = CHANNELS(_C);
342    const int N = mode->shortMdctSize<<LM;
343    const int overlap = OVERLAP(mode);
344    N4 = (N-overlap)>>1;
345    for (c=0;c<C;c++)
346    {
347       int j;
348       if (0 && transient_shift==0 && C==1 && !shortBlocks) {
349          clt_mdct_backward(&mode->mdct, X, out_mem[0]+(MAX_PERIOD-N-N4), mode->window, overlap, mode->maxLM-LM);
350       } else {
351          VARDECL(celt_word32, x);
352          VARDECL(celt_word32, tmp);
353          int b;
354          int N2 = N;
355          int B = 1;
356          int n4offset=0;
357          SAVE_STACK;
358          
359          ALLOC(x, 2*N, celt_word32);
360          ALLOC(tmp, N, celt_word32);
361
362          if (shortBlocks)
363          {
364             /*lookup = &mode->mdct[0];*/
365             N2 = mode->shortMdctSize;
366             B = shortBlocks;
367             n4offset = N4;
368          }
369          /* Prevents problems from the imdct doing the overlap-add */
370          CELT_MEMSET(x+N4, 0, N2);
371
372          for (b=0;b<B;b++)
373          {
374             /* De-interleaving the sub-frames */
375             for (j=0;j<N2;j++)
376                tmp[j] = X[(j*B+b)+c*N2*B];
377             clt_mdct_backward(&mode->mdct, tmp, x+n4offset+N2*b, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
378          }
379
380          if (transient_shift > 0)
381          {
382 #ifdef FIXED_POINT
383             for (j=0;j<16;j++)
384                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
385             for (j=transient_time;j<N+overlap;j++)
386                x[N4+j] = SHL32(x[N4+j], transient_shift);
387 #else
388             for (j=0;j<16;j++)
389                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
390             for (j=transient_time;j<N+overlap;j++)
391                x[N4+j] *= 1<<transient_shift;
392 #endif
393          }
394          /* The first and last part would need to be set to zero 
395             if we actually wanted to use them. */
396          for (j=0;j<overlap;j++)
397             out_mem[c][MAX_PERIOD-N+j] = x[j+N4] + overlap_mem[c][j];
398          for (;j<N;j++)
399             out_mem[c][MAX_PERIOD-N+j] = x[j+N4];
400          for (j=0;j<overlap;j++)
401             overlap_mem[c][j] = x[N+j+N4];
402          RESTORE_STACK;
403       }
404    }
405 }
406
407 static void deemphasis(celt_sig *in[], celt_word16 *pcm, int N, int _C, const celt_word16 *coef, celt_sig *mem)
408 {
409    const int C = CHANNELS(_C);
410    int c;
411    for (c=0;c<C;c++)
412    {
413       int j;
414       celt_sig * restrict x;
415       celt_word16  * restrict y;
416       celt_sig m = mem[c];
417       x = &in[c][MAX_PERIOD-N];
418       y = pcm+c;
419       for (j=0;j<N;j++)
420       {
421          celt_sig tmp = *x + m;
422          m = MULT16_32_Q15(coef[0], tmp)
423            - MULT16_32_Q15(coef[1], *x);
424          tmp = SHL32(MULT16_32_Q15(coef[3], tmp), 2);
425          *y = SCALEOUT(SIG2WORD16(tmp));
426          x++;
427          y+=C;
428       }
429       mem[c] = m;
430    }
431 }
432
433 static void mdct_shape(const CELTMode *mode, celt_norm *X, int start,
434                        int end, int N,
435                        int mdct_weight_shift, int end_band, int _C, int renorm, int M)
436 {
437    int m, i, c;
438    const int C = CHANNELS(_C);
439    for (c=0;c<C;c++)
440       for (m=start;m<end;m++)
441          for (i=m+c*N;i<(c+1)*N;i+=M)
442 #ifdef FIXED_POINT
443             X[i] = SHR16(X[i], mdct_weight_shift);
444 #else
445             X[i] = (1.f/(1<<mdct_weight_shift))*X[i];
446 #endif
447    if (renorm)
448       renormalise_bands(mode, X, end_band, C, M);
449 }
450
451 static signed char tf_select_table[4][8] = {
452       {0, -1, 0, -1,    0,-1, 0,-1},
453       {0, -1, 0, -2,    1, 0, 1 -1},
454       {0, -2, 0, -3,    2, 0, 1 -1},
455       {0, -2, 0, -3,    2, 0, 1 -1},
456 };
457
458 static int tf_analysis(celt_word16 *bandLogE, celt_word16 *oldBandE, int len, int C, int isTransient, int *tf_res, int nbCompressedBytes)
459 {
460    int i;
461    celt_word16 threshold;
462    VARDECL(celt_word16, metric);
463    celt_word32 average=0;
464    celt_word32 cost0;
465    celt_word32 cost1;
466    VARDECL(int, path0);
467    VARDECL(int, path1);
468    celt_word16 lambda;
469    int tf_select=0;
470    SAVE_STACK;
471
472    /* FIXME: Should check number of bytes *left* */
473    if (nbCompressedBytes<15*C)
474    {
475       for (i=0;i<len;i++)
476          tf_res[i] = 0;
477       return 0;
478    }
479    if (nbCompressedBytes<40)
480       lambda = QCONST16(5.f, DB_SHIFT);
481    else if (nbCompressedBytes<60)
482       lambda = QCONST16(2.f, DB_SHIFT);
483    else if (nbCompressedBytes<100)
484       lambda = QCONST16(1.f, DB_SHIFT);
485    else
486       lambda = QCONST16(.5f, DB_SHIFT);
487
488    ALLOC(metric, len, celt_word16);
489    ALLOC(path0, len, int);
490    ALLOC(path1, len, int);
491    for (i=0;i<len;i++)
492    {
493       metric[i] = SUB16(bandLogE[i], oldBandE[i]);
494       average += metric[i];
495    }
496    if (C==2)
497    {
498       average = 0;
499       for (i=0;i<len;i++)
500       {
501          metric[i] = HALF32(metric[i]) + HALF32(SUB16(bandLogE[i+len], oldBandE[i+len]));
502          average += metric[i];
503       }
504    }
505    average = DIV32(average, len);
506    /*if (!isTransient)
507       printf ("%f\n", average);*/
508    if (isTransient)
509    {
510       threshold = QCONST16(1.f,DB_SHIFT);
511       tf_select = average > QCONST16(3.f,DB_SHIFT);
512    } else {
513       threshold = QCONST16(.5f,DB_SHIFT);
514       tf_select = average > QCONST16(1.f,DB_SHIFT);
515    }
516    cost0 = 0;
517    cost1 = lambda;
518    /* Viterbi forward pass */
519    for (i=1;i<len;i++)
520    {
521       celt_word32 curr0, curr1;
522       celt_word32 from0, from1;
523
524       from0 = cost0;
525       from1 = cost1 + lambda;
526       if (from0 < from1)
527       {
528          curr0 = from0;
529          path0[i]= 0;
530       } else {
531          curr0 = from1;
532          path0[i]= 1;
533       }
534
535       from0 = cost0 + lambda;
536       from1 = cost1;
537       if (from0 < from1)
538       {
539          curr1 = from0;
540          path1[i]= 0;
541       } else {
542          curr1 = from1;
543          path1[i]= 1;
544       }
545       cost0 = curr0 + (metric[i]-threshold);
546       cost1 = curr1;
547    }
548    tf_res[len-1] = cost0 < cost1 ? 0 : 1;
549    /* Viterbi backward pass to check the decisions */
550    for (i=len-2;i>=0;i--)
551    {
552       if (tf_res[i+1] == 1)
553          tf_res[i] = path1[i+1];
554       else
555          tf_res[i] = path0[i+1];
556    }
557    RESTORE_STACK;
558    return tf_select;
559 }
560
561 static void tf_encode(int start, int end, int isTransient, int *tf_res, int nbCompressedBytes, int LM, int tf_select, ec_enc *enc)
562 {
563    int curr, i;
564    if (8*nbCompressedBytes - ec_enc_tell(enc, 0) < 100)
565    {
566       for (i=start;i<end;i++)
567          tf_res[i] = isTransient;
568    } else {
569       ec_enc_bit_prob(enc, tf_res[start], isTransient ? 16384 : 4096);
570       curr = tf_res[start];
571       for (i=start+1;i<end;i++)
572       {
573          ec_enc_bit_prob(enc, tf_res[i] ^ curr, isTransient ? 4096 : 2048);
574          curr = tf_res[i];
575       }
576    }
577    ec_enc_bits(enc, tf_select, 1);
578    for (i=start;i<end;i++)
579       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
580 }
581
582 static void tf_decode(int start, int end, int C, int isTransient, int *tf_res, int nbCompressedBytes, int LM, ec_dec *dec)
583 {
584    int i, curr, tf_select;
585    if (8*nbCompressedBytes - ec_dec_tell(dec, 0) < 100)
586    {
587       for (i=start;i<end;i++)
588          tf_res[i] = isTransient;
589    } else {
590       tf_res[start] = ec_dec_bit_prob(dec, isTransient ? 16384 : 4096);
591       curr = tf_res[start];
592       for (i=start+1;i<end;i++)
593       {
594          tf_res[i] = ec_dec_bit_prob(dec, isTransient ? 4096 : 2048) ^ curr;
595          curr = tf_res[i];
596       }
597    }
598    tf_select = ec_dec_bits(dec, 1);
599    for (i=start;i<end;i++)
600       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
601 }
602
603 #ifdef FIXED_POINT
604 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
605 {
606 #else
607 int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
608 {
609 #endif
610    int i, c, N, NN, N4;
611    int bits;
612    int has_fold=1;
613    ec_byte_buffer buf;
614    ec_enc         _enc;
615    VARDECL(celt_sig, in);
616    VARDECL(celt_sig, freq);
617    VARDECL(celt_norm, X);
618    VARDECL(celt_ener, bandE);
619    VARDECL(celt_word16, bandLogE);
620    VARDECL(int, fine_quant);
621    VARDECL(celt_word16, error);
622    VARDECL(int, pulses);
623    VARDECL(int, offsets);
624    VARDECL(int, fine_priority);
625    VARDECL(int, tf_res);
626    int shortBlocks=0;
627    int isTransient=0;
628    int transient_time, transient_time_quant;
629    int transient_shift;
630    int resynth;
631    const int C = CHANNELS(st->channels);
632    int mdct_weight_shift = 0;
633    int mdct_weight_pos=0;
634    int LM, M;
635    int tf_select;
636    int nbFilledBytes, nbAvailableBytes;
637    int effEnd;
638    SAVE_STACK;
639
640    if (check_encoder(st) != CELT_OK)
641       return CELT_INVALID_STATE;
642
643    if (check_mode(st->mode) != CELT_OK)
644       return CELT_INVALID_MODE;
645
646    if (nbCompressedBytes<0 || pcm==NULL)
647      return CELT_BAD_ARG;
648
649    for (LM=0;LM<4;LM++)
650       if (st->mode->shortMdctSize<<LM==frame_size)
651          break;
652    if (LM>=MAX_CONFIG_SIZES)
653       return CELT_BAD_ARG;
654    M=1<<LM;
655
656    if (enc==NULL)
657    {
658       ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
659       ec_enc_init(&_enc,&buf);
660       enc = &_enc;
661       nbFilledBytes=0;
662    } else {
663       nbFilledBytes=(ec_enc_tell(enc, 0)+4)>>3;
664    }
665    nbAvailableBytes = nbCompressedBytes - nbFilledBytes;
666
667    effEnd = st->end;
668    if (effEnd > st->mode->effEBands)
669       effEnd = st->mode->effEBands;
670
671    N = M*st->mode->shortMdctSize;
672    N4 = (N-st->overlap)>>1;
673    ALLOC(in, 2*C*N-2*C*N4, celt_sig);
674
675    CELT_COPY(in, st->in_mem, C*st->overlap);
676    for (c=0;c<C;c++)
677    {
678       const celt_word16 * restrict pcmp = pcm+c;
679       celt_sig * restrict inp = in+C*st->overlap+c;
680       for (i=0;i<N;i++)
681       {
682          /* Apply pre-emphasis */
683          celt_sig tmp = MULT16_16(st->mode->preemph[2], SCALEIN(*pcmp));
684          *inp = tmp + st->preemph_memE[c];
685          st->preemph_memE[c] = MULT16_32_Q15(st->mode->preemph[1], *inp)
686                              - MULT16_32_Q15(st->mode->preemph[0], tmp);
687          inp += C;
688          pcmp += C;
689       }
690    }
691    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
692
693    /* Transient handling */
694    transient_time = -1;
695    transient_time_quant = -1;
696    transient_shift = 0;
697    isTransient = 0;
698
699    resynth = optional_resynthesis!=NULL;
700
701    if (M > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift, &st->frame_max, st->overlap))
702    {
703 #ifndef FIXED_POINT
704       float gain_1;
705 #endif
706       /* Apply the inverse shaping window */
707       if (transient_shift)
708       {
709          transient_time_quant = transient_time*(celt_int32)8000/st->mode->Fs;
710          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
711 #ifdef FIXED_POINT
712          for (c=0;c<C;c++)
713             for (i=0;i<16;i++)
714                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
715          for (c=0;c<C;c++)
716             for (i=transient_time;i<N+st->overlap;i++)
717                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
718 #else
719          for (c=0;c<C;c++)
720             for (i=0;i<16;i++)
721                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
722          gain_1 = 1.f/(1<<transient_shift);
723          for (c=0;c<C;c++)
724             for (i=transient_time;i<N+st->overlap;i++)
725                in[C*i+c] *= gain_1;
726 #endif
727       }
728       isTransient = 1;
729       has_fold = 1;
730    }
731
732    if (isTransient)
733       shortBlocks = M;
734    else
735       shortBlocks = 0;
736
737    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
738    ALLOC(bandE,st->mode->nbEBands*C, celt_ener);
739    ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16);
740    /* Compute MDCTs */
741    compute_mdcts(st->mode, shortBlocks, in, freq, C, LM);
742
743    ALLOC(X, C*N, celt_norm);         /**< Interleaved normalised MDCTs */
744
745    compute_band_energies(st->mode, freq, bandE, effEnd, C, M);
746
747    amp2Log2(st->mode, effEnd, st->end, bandE, bandLogE, C);
748
749    /* Band normalisation */
750    normalise_bands(st->mode, freq, X, bandE, effEnd, C, M);
751
752    NN = M*st->mode->eBands[effEnd];
753    if (shortBlocks && !transient_shift)
754    {
755       celt_word32 sum[8]={1,1,1,1,1,1,1,1};
756       int m;
757       for (c=0;c<C;c++)
758       {
759          m=0;
760          do {
761             celt_word32 tmp=0;
762             for (i=m+c*N;i<c*N+NN;i+=M)
763                tmp += ABS32(X[i]);
764             sum[m++] += tmp;
765          } while (m<M);
766       }
767       m=0;
768 #ifdef FIXED_POINT
769       do {
770          if (SHR32(sum[m+1],3) > sum[m])
771          {
772             mdct_weight_shift=2;
773             mdct_weight_pos = m;
774          } else if (SHR32(sum[m+1],1) > sum[m] && mdct_weight_shift < 2)
775          {
776             mdct_weight_shift=1;
777             mdct_weight_pos = m;
778          }
779          m++;
780       } while (m<M-1);
781 #else
782       do {
783          if (sum[m+1] > 8*sum[m])
784          {
785             mdct_weight_shift=2;
786             mdct_weight_pos = m;
787          } else if (sum[m+1] > 2*sum[m] && mdct_weight_shift < 2)
788          {
789             mdct_weight_shift=1;
790             mdct_weight_pos = m;
791          }
792          m++;
793       } while (m<M-1);
794 #endif
795       if (mdct_weight_shift)
796          mdct_shape(st->mode, X, mdct_weight_pos+1, M, N, mdct_weight_shift, effEnd, C, 0, M);
797    }
798
799    ALLOC(tf_res, st->mode->nbEBands, int);
800    /* Needs to be before coarse energy quantization because otherwise the energy gets modified */
801    tf_select = tf_analysis(bandLogE, st->oldBandE, effEnd, C, isTransient, tf_res, nbAvailableBytes);
802    for (i=effEnd;i<st->end;i++)
803       tf_res[i] = tf_res[effEnd-1];
804
805    ALLOC(error, C*st->mode->nbEBands, celt_word16);
806    quant_coarse_energy(st->mode, st->start, st->end, effEnd, bandLogE,
807          st->oldBandE, nbCompressedBytes*8, st->mode->prob,
808          error, enc, C, LM, nbAvailableBytes, st->force_intra, &st->delayedIntra);
809
810    ec_enc_bit_prob(enc, shortBlocks!=0, 8192);
811
812    if (shortBlocks)
813    {
814       if (transient_shift)
815       {
816          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
817          ec_enc_uint(enc, transient_shift, 4);
818          ec_enc_uint(enc, transient_time_quant, max_time);
819       } else {
820          ec_enc_uint(enc, mdct_weight_shift, 4);
821          if (mdct_weight_shift && M!=2)
822             ec_enc_uint(enc, mdct_weight_pos, M-1);
823       }
824    }
825
826    tf_encode(st->start, st->end, isTransient, tf_res, nbAvailableBytes, LM, tf_select, enc);
827
828    if (!shortBlocks && !folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision, effEnd, C, M))
829       has_fold = 0;
830    ec_enc_bit_prob(enc, has_fold>>1, 8192);
831    ec_enc_bit_prob(enc, has_fold&1, (has_fold>>1) ? 32768 : 49152);
832
833    /* Variable bitrate */
834    if (st->vbr_rate_norm>0)
835    {
836      celt_word16 alpha;
837      celt_int32 delta;
838      /* The target rate in 16th bits per frame */
839      celt_int32 vbr_rate;
840      celt_int32 target;
841      celt_int32 vbr_bound, max_allowed;
842
843      vbr_rate = M*st->vbr_rate_norm;
844
845      /* Computes the max bit-rate allowed in VBR more to avoid busting the budget */
846      vbr_bound = vbr_rate;
847      max_allowed = (vbr_rate + vbr_bound - st->vbr_reservoir)>>(BITRES+3);
848      if (max_allowed < 4)
849         max_allowed = 4;
850      if (max_allowed < nbAvailableBytes)
851         nbAvailableBytes = max_allowed;
852      target=vbr_rate;
853
854      /* Shortblocks get a large boost in bitrate, but since they 
855         are uncommon long blocks are not greatly effected */
856      if (shortBlocks)
857        target*=2;
858      else if (M > 1)
859        target-=(target+14)/28;
860
861      /* The average energy is removed from the target and the actual 
862         energy added*/
863      target=target+st->vbr_offset-588+ec_enc_tell(enc, BITRES);
864
865      /* In VBR mode the frame size must not be reduced so much that it would result in the coarse energy busting its budget */
866      target=IMIN(nbAvailableBytes,target);
867      /* Make the adaptation coef (alpha) higher at the beginning */
868      if (st->vbr_count < 990)
869      {
870         st->vbr_count++;
871         alpha = celt_rcp(SHL32(EXTEND32(st->vbr_count+10),16));
872         /*printf ("%d %d\n", st->vbr_count+10, alpha);*/
873      } else
874         alpha = QCONST16(.001f,15);
875
876      /* By how much did we "miss" the target on that frame */
877      delta = (8<<BITRES)*(celt_int32)target - vbr_rate;
878      /* How many bits have we used in excess of what we're allowed */
879      st->vbr_reservoir += delta;
880      /*printf ("%d\n", st->vbr_reservoir);*/
881
882      /* Compute the offset we need to apply in order to reach the target */
883      st->vbr_drift += MULT16_32_Q15(alpha,delta-st->vbr_offset-st->vbr_drift);
884      st->vbr_offset = -st->vbr_drift;
885      /*printf ("%d\n", st->vbr_drift);*/
886
887      /* We could use any multiple of vbr_rate as bound (depending on the delay) */
888      if (st->vbr_reservoir < 0)
889      {
890         /* We're under the min value -- increase rate */
891         int adjust = 1-(st->vbr_reservoir-1)/(8<<BITRES);
892         st->vbr_reservoir += adjust*(8<<BITRES);
893         target += adjust;
894         /*printf ("+%d\n", adjust);*/
895      }
896      if (target < nbAvailableBytes)
897         nbAvailableBytes = target;
898      nbCompressedBytes = nbAvailableBytes + nbFilledBytes;
899
900      /* This moves the raw bits to take into account the new compressed size */
901      ec_byte_shrink(&buf, nbCompressedBytes);
902    }
903
904    /* Bit allocation */
905    ALLOC(fine_quant, st->mode->nbEBands, int);
906    ALLOC(pulses, st->mode->nbEBands, int);
907    ALLOC(offsets, st->mode->nbEBands, int);
908    ALLOC(fine_priority, st->mode->nbEBands, int);
909
910    for (i=0;i<st->mode->nbEBands;i++)
911       offsets[i] = 0;
912    bits = nbCompressedBytes*8 - ec_enc_tell(enc, 0) - 1;
913    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
914
915    quant_fine_energy(st->mode, st->start, st->end, bandE, st->oldBandE, error, fine_quant, enc, C);
916
917 #ifdef MEASURE_NORM_MSE
918    float X0[3000];
919    float bandE0[60];
920    for (c=0;c<C;c++)
921       for (i=0;i<N;i++)
922          X0[i+c*N] = X[i+c*N];
923    for (i=0;i<C*st->mode->nbEBands;i++)
924       bandE0[i] = bandE[i];
925 #endif
926
927    /* Residual quantisation */
928    quant_all_bands(1, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, bandE, pulses, shortBlocks, has_fold, tf_res, resynth, nbCompressedBytes*8, enc, LM);
929
930    quant_energy_finalise(st->mode, st->start, st->end, bandE, st->oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
931
932    /* Re-synthesis of the coded audio if required */
933    if (resynth)
934    {
935       VARDECL(celt_sig, _out_mem);
936       celt_sig *out_mem[2];
937
938       log2Amp(st->mode, st->start, st->end, bandE, st->oldBandE, C);
939
940 #ifdef MEASURE_NORM_MSE
941       measure_norm_mse(st->mode, X, X0, bandE, bandE0, M, N, C);
942 #endif
943
944       if (mdct_weight_shift)
945       {
946          mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
947       }
948
949       /* Synthesis */
950       denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
951
952       for (c=0;c<C;c++)
953          for (i=0;i<M*st->mode->eBands[st->start];i++)
954             freq[c*N+i] = 0;
955       for (c=0;c<C;c++)
956          for (i=M*st->mode->eBands[st->end];i<N;i++)
957             freq[c*N+i] = 0;
958
959       ALLOC(_out_mem, C*N, celt_sig);
960       out_mem[0] = _out_mem;
961       if (C==2)
962          out_mem[1] = _out_mem+N;
963
964       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
965             transient_shift, out_mem, st->overlap_mem, C, LM);
966
967       /* De-emphasis and put everything back at the right place 
968          in the synthesis history */
969       if (optional_resynthesis != NULL) {
970          deemphasis(out_mem, optional_resynthesis, N, C, st->mode->preemph, st->preemph_memD);
971
972       }
973    }
974
975    /* If there's any room left (can only happen for very high rates),
976       fill it with zeros */
977    while (ec_enc_tell(enc,0) + 8 <= nbCompressedBytes*8)
978       ec_enc_bits(enc, 0, 8);
979    ec_enc_done(enc);
980    
981    RESTORE_STACK;
982    if (ec_enc_get_error(enc))
983       return CELT_CORRUPTED_DATA;
984    else
985       return nbCompressedBytes;
986 }
987
988 #ifdef FIXED_POINT
989 #ifndef DISABLE_FLOAT_API
990 int celt_encode_with_ec_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
991 {
992    int j, ret, C, N, LM, M;
993    VARDECL(celt_int16, in);
994    SAVE_STACK;
995
996    if (check_encoder(st) != CELT_OK)
997       return CELT_INVALID_STATE;
998
999    if (check_mode(st->mode) != CELT_OK)
1000       return CELT_INVALID_MODE;
1001
1002    if (pcm==NULL)
1003       return CELT_BAD_ARG;
1004
1005    for (LM=0;LM<4;LM++)
1006       if (st->mode->shortMdctSize<<LM==frame_size)
1007          break;
1008    if (LM>=MAX_CONFIG_SIZES)
1009       return CELT_BAD_ARG;
1010    M=1<<LM;
1011
1012    C = CHANNELS(st->channels);
1013    N = M*st->mode->shortMdctSize;
1014    ALLOC(in, C*N, celt_int16);
1015
1016    for (j=0;j<C*N;j++)
1017      in[j] = FLOAT2INT16(pcm[j]);
1018
1019    if (optional_resynthesis != NULL) {
1020      ret=celt_encode_with_ec(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
1021       for (j=0;j<C*N;j++)
1022          optional_resynthesis[j]=in[j]*(1.f/32768.f);
1023    } else {
1024      ret=celt_encode_with_ec(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
1025    }
1026    RESTORE_STACK;
1027    return ret;
1028
1029 }
1030 #endif /*DISABLE_FLOAT_API*/
1031 #else
1032 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
1033 {
1034    int j, ret, C, N, LM, M;
1035    VARDECL(celt_sig, in);
1036    SAVE_STACK;
1037
1038    if (check_encoder(st) != CELT_OK)
1039       return CELT_INVALID_STATE;
1040
1041    if (check_mode(st->mode) != CELT_OK)
1042       return CELT_INVALID_MODE;
1043
1044    if (pcm==NULL)
1045       return CELT_BAD_ARG;
1046
1047    for (LM=0;LM<4;LM++)
1048       if (st->mode->shortMdctSize<<LM==frame_size)
1049          break;
1050    if (LM>=MAX_CONFIG_SIZES)
1051       return CELT_BAD_ARG;
1052    M=1<<LM;
1053
1054    C=CHANNELS(st->channels);
1055    N=M*st->mode->shortMdctSize;
1056    ALLOC(in, C*N, celt_sig);
1057    for (j=0;j<C*N;j++) {
1058      in[j] = SCALEOUT(pcm[j]);
1059    }
1060
1061    if (optional_resynthesis != NULL) {
1062       ret = celt_encode_with_ec_float(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
1063       for (j=0;j<C*N;j++)
1064          optional_resynthesis[j] = FLOAT2INT16(in[j]);
1065    } else {
1066       ret = celt_encode_with_ec_float(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
1067    }
1068    RESTORE_STACK;
1069    return ret;
1070 }
1071 #endif
1072
1073 int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1074 {
1075    return celt_encode_with_ec(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1076 }
1077
1078 #ifndef DISABLE_FLOAT_API
1079 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1080 {
1081    return celt_encode_with_ec_float(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1082 }
1083 #endif /* DISABLE_FLOAT_API */
1084
1085 int celt_encode_resynthesis(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1086 {
1087    return celt_encode_with_ec(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1088 }
1089
1090 #ifndef DISABLE_FLOAT_API
1091 int celt_encode_resynthesis_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1092 {
1093    return celt_encode_with_ec_float(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1094 }
1095 #endif /* DISABLE_FLOAT_API */
1096
1097
1098 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
1099 {
1100    va_list ap;
1101    
1102    if (check_encoder(st) != CELT_OK)
1103       return CELT_INVALID_STATE;
1104
1105    va_start(ap, request);
1106    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
1107      goto bad_mode;
1108    switch (request)
1109    {
1110       case CELT_GET_MODE_REQUEST:
1111       {
1112          const CELTMode ** value = va_arg(ap, const CELTMode**);
1113          if (value==0)
1114             goto bad_arg;
1115          *value=st->mode;
1116       }
1117       break;
1118       case CELT_SET_COMPLEXITY_REQUEST:
1119       {
1120          int value = va_arg(ap, celt_int32);
1121          if (value<0 || value>10)
1122             goto bad_arg;
1123       }
1124       break;
1125       case CELT_SET_START_BAND_REQUEST:
1126       {
1127          celt_int32 value = va_arg(ap, celt_int32);
1128          if (value<0 || value>=st->mode->nbEBands)
1129             goto bad_arg;
1130          st->start = value;
1131       }
1132       break;
1133       case CELT_SET_END_BAND_REQUEST:
1134       {
1135          celt_int32 value = va_arg(ap, celt_int32);
1136          if (value<0 || value>=st->mode->nbEBands)
1137             goto bad_arg;
1138          st->end = value;
1139       }
1140       break;
1141       case CELT_SET_PREDICTION_REQUEST:
1142       {
1143          int value = va_arg(ap, celt_int32);
1144          if (value<0 || value>2)
1145             goto bad_arg;
1146          if (value==0)
1147          {
1148             st->force_intra   = 1;
1149          } else if (value==1) {
1150             st->force_intra   = 0;
1151          } else {
1152             st->force_intra   = 0;
1153          }   
1154       }
1155       break;
1156       case CELT_SET_VBR_RATE_REQUEST:
1157       {
1158          celt_int32 value = va_arg(ap, celt_int32);
1159          int frame_rate;
1160          int N = st->mode->shortMdctSize;
1161          if (value<0)
1162             goto bad_arg;
1163          if (value>3072000)
1164             value = 3072000;
1165          frame_rate = ((st->mode->Fs<<3)+(N>>1))/N;
1166          st->vbr_rate_norm = ((value<<(BITRES+3))+(frame_rate>>1))/frame_rate;
1167       }
1168       break;
1169       case CELT_RESET_STATE:
1170       {
1171          const CELTMode *mode = st->mode;
1172          int C = st->channels;
1173
1174          CELT_MEMSET(st->in_mem, 0, st->overlap*C);
1175          CELT_MEMSET(st->overlap_mem[0], 0, st->overlap*C);
1176
1177          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
1178
1179          CELT_MEMSET(st->preemph_memE, 0, C);
1180          CELT_MEMSET(st->preemph_memD, 0, C);
1181          st->delayedIntra = 1;
1182
1183          st->fold_decision = 1;
1184          st->tonal_average = QCONST16(1.f,8);
1185          st->gain_prod = 0;
1186          st->vbr_reservoir = 0;
1187          st->vbr_drift = 0;
1188          st->vbr_offset = 0;
1189          st->vbr_count = 0;
1190          st->frame_max = 0;
1191       }
1192       break;
1193       default:
1194          goto bad_request;
1195    }
1196    va_end(ap);
1197    return CELT_OK;
1198 bad_mode:
1199   va_end(ap);
1200   return CELT_INVALID_MODE;
1201 bad_arg:
1202    va_end(ap);
1203    return CELT_BAD_ARG;
1204 bad_request:
1205    va_end(ap);
1206    return CELT_UNIMPLEMENTED;
1207 }
1208
1209 /**********************************************************************/
1210 /*                                                                    */
1211 /*                             DECODER                                */
1212 /*                                                                    */
1213 /**********************************************************************/
1214 #define DECODE_BUFFER_SIZE 2048
1215
1216 #define DECODERVALID   0x4c434454
1217 #define DECODERPARTIAL 0x5444434c
1218 #define DECODERFREED   0x4c004400
1219
1220 /** Decoder state 
1221  @brief Decoder state
1222  */
1223 struct CELTDecoder {
1224    celt_uint32 marker;
1225    const CELTMode *mode;
1226    int overlap;
1227    int channels;
1228
1229    int start, end;
1230
1231    celt_sig preemph_memD[2];
1232
1233    celt_sig *out_mem[2];
1234    celt_sig *decode_mem[2];
1235    celt_sig *overlap_mem[2];
1236
1237    celt_word16 *oldBandE;
1238    
1239    celt_word16 *lpc;
1240
1241    int last_pitch_index;
1242    int loss_count;
1243 };
1244
1245 int check_decoder(const CELTDecoder *st) 
1246 {
1247    if (st==NULL)
1248    {
1249       celt_warning("NULL passed a decoder structure");  
1250       return CELT_INVALID_STATE;
1251    }
1252    if (st->marker == DECODERVALID)
1253       return CELT_OK;
1254    if (st->marker == DECODERFREED)
1255       celt_warning("Referencing a decoder that has already been freed");
1256    else
1257       celt_warning("This is not a valid CELT decoder structure");
1258    return CELT_INVALID_STATE;
1259 }
1260
1261 CELTDecoder *celt_decoder_create(const CELTMode *mode, int channels, int *error)
1262 {
1263    int C, c;
1264    CELTDecoder *st;
1265
1266    if (check_mode(mode) != CELT_OK)
1267    {
1268       if (error)
1269          *error = CELT_INVALID_MODE;
1270       return NULL;
1271    }
1272
1273    if (channels < 0 || channels > 2)
1274    {
1275       celt_warning("Only mono and stereo supported");
1276       if (error)
1277          *error = CELT_BAD_ARG;
1278       return NULL;
1279    }
1280
1281    C = CHANNELS(channels);
1282    st = celt_alloc(sizeof(CELTDecoder));
1283
1284    if (st==NULL)
1285    {
1286       if (error)
1287          *error = CELT_ALLOC_FAIL;
1288       return NULL;
1289    }
1290
1291    st->marker = DECODERPARTIAL;
1292    st->mode = mode;
1293    st->overlap = mode->overlap;
1294    st->channels = channels;
1295
1296    st->start = 0;
1297    st->end = st->mode->effEBands;
1298
1299    st->decode_mem[0] = (celt_sig*)celt_alloc((DECODE_BUFFER_SIZE+st->overlap)*C*sizeof(celt_sig));
1300    if (C==2)
1301       st->decode_mem[1] = st->decode_mem[0] + (DECODE_BUFFER_SIZE+st->overlap);
1302    for (c=0;c<C;c++)
1303    {
1304       st->out_mem[c] = st->decode_mem[c]+DECODE_BUFFER_SIZE-MAX_PERIOD;
1305       st->overlap_mem[c] = st->decode_mem[c]+DECODE_BUFFER_SIZE;
1306    }
1307    st->oldBandE = (celt_word16*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16));
1308    
1309    st->lpc = (celt_word16*)celt_alloc(C*LPC_ORDER*sizeof(celt_word16));
1310
1311    st->loss_count = 0;
1312
1313    if ((st->decode_mem!=NULL) && (st->out_mem[0]!=NULL) && (st->oldBandE!=NULL) &&
1314          (st->lpc!=NULL))
1315    {
1316       if (error)
1317          *error = CELT_OK;
1318       st->marker = DECODERVALID;
1319       return st;
1320    }
1321    /* If the setup fails for some reason deallocate it. */
1322    celt_decoder_destroy(st);
1323    if (error)
1324       *error = CELT_ALLOC_FAIL;
1325    return NULL;
1326 }
1327
1328 void celt_decoder_destroy(CELTDecoder *st)
1329 {
1330    if (st == NULL)
1331    {
1332       celt_warning("NULL passed to celt_decoder_destroy");
1333       return;
1334    }
1335
1336    if (st->marker == DECODERFREED) 
1337    {
1338       celt_warning("Freeing a decoder which has already been freed"); 
1339       return;
1340    }
1341    
1342    if (st->marker != DECODERVALID && st->marker != DECODERPARTIAL)
1343    {
1344       celt_warning("This is not a valid CELT decoder structure");
1345       return;
1346    }
1347    
1348    /*Check_mode is non-fatal here because we can still free
1349      the encoder memory even if the mode is bad, although calling
1350      the free functions in this order is a violation of the API.*/
1351    check_mode(st->mode);
1352    
1353    celt_free(st->decode_mem[0]);
1354    celt_free(st->oldBandE);
1355    celt_free(st->lpc);
1356    
1357    st->marker = DECODERFREED;
1358    
1359    celt_free(st);
1360 }
1361
1362 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16 * restrict pcm, int N, int LM)
1363 {
1364    int c;
1365    int pitch_index;
1366    int overlap = st->mode->overlap;
1367    celt_word16 fade = Q15ONE;
1368    int i, len;
1369    const int C = CHANNELS(st->channels);
1370    int offset;
1371    SAVE_STACK;
1372    
1373    len = N+st->mode->overlap;
1374    
1375    if (st->loss_count == 0)
1376    {
1377       celt_word16 pitch_buf[MAX_PERIOD>>1];
1378       celt_word32 tmp=0;
1379       celt_word32 mem0[2]={0,0};
1380       celt_word16 mem1[2]={0,0};
1381       int len2 = len;
1382       /* FIXME: This is a kludge */
1383       if (len2>MAX_PERIOD>>1)
1384          len2 = MAX_PERIOD>>1;
1385       pitch_downsample(st->out_mem, pitch_buf, MAX_PERIOD, MAX_PERIOD,
1386                        C, mem0, mem1);
1387       pitch_search(st->mode, pitch_buf+((MAX_PERIOD-len2)>>1), pitch_buf, len2,
1388                    MAX_PERIOD-len2-100, &pitch_index, &tmp, 1<<LM);
1389       pitch_index = MAX_PERIOD-len2-pitch_index;
1390       st->last_pitch_index = pitch_index;
1391    } else {
1392       pitch_index = st->last_pitch_index;
1393       if (st->loss_count < 5)
1394          fade = QCONST16(.8f,15);
1395       else
1396          fade = 0;
1397    }
1398
1399    for (c=0;c<C;c++)
1400    {
1401       /* FIXME: This is more memory than necessary */
1402       celt_word32 e[2*MAX_PERIOD];
1403       celt_word16 exc[2*MAX_PERIOD];
1404       celt_word32 ac[LPC_ORDER+1];
1405       celt_word16 decay = 1;
1406       celt_word32 S1=0;
1407       celt_word16 mem[LPC_ORDER]={0};
1408
1409       offset = MAX_PERIOD-pitch_index;
1410       for (i=0;i<MAX_PERIOD;i++)
1411          exc[i] = ROUND16(st->out_mem[c][i], SIG_SHIFT);
1412
1413       if (st->loss_count == 0)
1414       {
1415          _celt_autocorr(exc, ac, st->mode->window, st->mode->overlap,
1416                         LPC_ORDER, MAX_PERIOD);
1417
1418          /* Noise floor -40 dB */
1419 #ifdef FIXED_POINT
1420          ac[0] += SHR32(ac[0],13);
1421 #else
1422          ac[0] *= 1.0001f;
1423 #endif
1424          /* Lag windowing */
1425          for (i=1;i<=LPC_ORDER;i++)
1426          {
1427             /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
1428 #ifdef FIXED_POINT
1429             ac[i] -= MULT16_32_Q15(2*i*i, ac[i]);
1430 #else
1431             ac[i] -= ac[i]*(.008f*i)*(.008f*i);
1432 #endif
1433          }
1434
1435          _celt_lpc(st->lpc+c*LPC_ORDER, ac, LPC_ORDER);
1436       }
1437       fir(exc, st->lpc+c*LPC_ORDER, exc, MAX_PERIOD, LPC_ORDER, mem);
1438       /*for (i=0;i<MAX_PERIOD;i++)printf("%d ", exc[i]); printf("\n");*/
1439       /* Check if the waveform is decaying (and if so how fast) */
1440       {
1441          celt_word32 E1=1, E2=1;
1442          int period;
1443          if (pitch_index <= MAX_PERIOD/2)
1444             period = pitch_index;
1445          else
1446             period = MAX_PERIOD/2;
1447          for (i=0;i<period;i++)
1448          {
1449             E1 += SHR32(MULT16_16(exc[MAX_PERIOD-period+i],exc[MAX_PERIOD-period+i]),8);
1450             E2 += SHR32(MULT16_16(exc[MAX_PERIOD-2*period+i],exc[MAX_PERIOD-2*period+i]),8);
1451          }
1452          if (E1 > E2)
1453             E1 = E2;
1454          decay = celt_sqrt(frac_div32(SHR(E1,1),E2));
1455       }
1456
1457       /* Copy excitation, taking decay into account */
1458       for (i=0;i<len+st->mode->overlap;i++)
1459       {
1460          if (offset+i >= MAX_PERIOD)
1461          {
1462             offset -= pitch_index;
1463             decay = MULT16_16_Q15(decay, decay);
1464          }
1465          e[i] = SHL32(EXTEND32(MULT16_16_Q15(decay, exc[offset+i])), SIG_SHIFT);
1466          S1 += SHR32(MULT16_16(st->out_mem[c][offset+i],st->out_mem[c][offset+i]),8);
1467       }
1468
1469       iir(e, st->lpc+c*LPC_ORDER, e, len+st->mode->overlap, LPC_ORDER, mem);
1470
1471       {
1472          celt_word32 S2=0;
1473          for (i=0;i<len+overlap;i++)
1474             S2 += SHR32(MULT16_16(e[i],e[i]),8);
1475          /* This checks for an "explosion" in the synthesis */
1476 #ifdef FIXED_POINT
1477          if (!(S1 > SHR32(S2,2)))
1478 #else
1479          /* Float test is written this way to catch NaNs at the same time */
1480          if (!(S1 > 0.2f*S2))
1481 #endif
1482          {
1483             for (i=0;i<len+overlap;i++)
1484                e[i] = 0;
1485          } else if (S1 < S2)
1486          {
1487             celt_word16 ratio = celt_sqrt(frac_div32(SHR32(S1,1)+1,S2+1));
1488             for (i=0;i<len+overlap;i++)
1489                e[i] = MULT16_16_Q15(ratio, e[i]);
1490          }
1491       }
1492
1493       for (i=0;i<MAX_PERIOD+st->mode->overlap-N;i++)
1494          st->out_mem[c][i] = st->out_mem[c][N+i];
1495
1496       /* Apply TDAC to the concealed audio so that it blends with the
1497          previous and next frames */
1498       for (i=0;i<overlap/2;i++)
1499       {
1500          celt_word32 tmp1, tmp2;
1501          tmp1 = MULT16_32_Q15(st->mode->window[i          ], e[i          ]) -
1502                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[overlap-i-1]);
1503          tmp2 = MULT16_32_Q15(st->mode->window[i],           e[N+overlap-1-i]) +
1504                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[N+i          ]);
1505          tmp1 = MULT16_32_Q15(fade, tmp1);
1506          tmp2 = MULT16_32_Q15(fade, tmp2);
1507          st->out_mem[c][MAX_PERIOD+i] = MULT16_32_Q15(st->mode->window[overlap-i-1], tmp2);
1508          st->out_mem[c][MAX_PERIOD+overlap-i-1] = MULT16_32_Q15(st->mode->window[i], tmp2);
1509          st->out_mem[c][MAX_PERIOD-N+i] += MULT16_32_Q15(st->mode->window[i], tmp1);
1510          st->out_mem[c][MAX_PERIOD-N+overlap-i-1] -= MULT16_32_Q15(st->mode->window[overlap-i-1], tmp1);
1511       }
1512       for (i=0;i<N-overlap;i++)
1513          st->out_mem[c][MAX_PERIOD-N+overlap+i] = MULT16_32_Q15(fade, e[overlap+i]);
1514    }
1515
1516    deemphasis(st->out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1517    
1518    st->loss_count++;
1519
1520    RESTORE_STACK;
1521 }
1522
1523 #ifdef FIXED_POINT
1524 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1525 {
1526 #else
1527 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig * restrict pcm, int frame_size, ec_dec *dec)
1528 {
1529 #endif
1530    int c, i, N, N4;
1531    int has_fold;
1532    int bits;
1533    ec_dec _dec;
1534    ec_byte_buffer buf;
1535    VARDECL(celt_sig, freq);
1536    VARDECL(celt_norm, X);
1537    VARDECL(celt_ener, bandE);
1538    VARDECL(int, fine_quant);
1539    VARDECL(int, pulses);
1540    VARDECL(int, offsets);
1541    VARDECL(int, fine_priority);
1542    VARDECL(int, tf_res);
1543
1544    int shortBlocks;
1545    int isTransient;
1546    int intra_ener;
1547    int transient_time;
1548    int transient_shift;
1549    int mdct_weight_shift=0;
1550    const int C = CHANNELS(st->channels);
1551    int mdct_weight_pos=0;
1552    int LM, M;
1553    int nbFilledBytes, nbAvailableBytes;
1554    int effEnd;
1555    SAVE_STACK;
1556
1557    if (check_decoder(st) != CELT_OK)
1558       return CELT_INVALID_STATE;
1559
1560    if (check_mode(st->mode) != CELT_OK)
1561       return CELT_INVALID_MODE;
1562
1563    if (pcm==NULL)
1564       return CELT_BAD_ARG;
1565
1566    for (LM=0;LM<4;LM++)
1567       if (st->mode->shortMdctSize<<LM==frame_size)
1568          break;
1569    if (LM>=MAX_CONFIG_SIZES)
1570       return CELT_BAD_ARG;
1571    M=1<<LM;
1572
1573    N = M*st->mode->shortMdctSize;
1574    N4 = (N-st->overlap)>>1;
1575
1576    effEnd = st->end;
1577    if (effEnd > st->mode->effEBands)
1578       effEnd = st->mode->effEBands;
1579
1580    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
1581    ALLOC(X, C*N, celt_norm);   /**< Interleaved normalised MDCTs */
1582    ALLOC(bandE, st->mode->nbEBands*C, celt_ener);
1583    for (c=0;c<C;c++)
1584       for (i=0;i<M*st->mode->eBands[st->start];i++)
1585          X[c*N+i] = 0;
1586    for (c=0;c<C;c++)
1587       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1588          X[c*N+i] = 0;
1589
1590    if (data == NULL)
1591    {
1592       celt_decode_lost(st, pcm, N, LM);
1593       RESTORE_STACK;
1594       return CELT_OK;
1595    }
1596    if (len<0) {
1597      RESTORE_STACK;
1598      return CELT_BAD_ARG;
1599    }
1600    
1601    if (dec == NULL)
1602    {
1603       ec_byte_readinit(&buf,(unsigned char*)data,len);
1604       ec_dec_init(&_dec,&buf);
1605       dec = &_dec;
1606       nbFilledBytes = 0;
1607    } else {
1608       nbFilledBytes = (ec_dec_tell(dec, 0)+4)>>3;
1609    }
1610    nbAvailableBytes = len-nbFilledBytes;
1611
1612    /* Decode the global flags (first symbols in the stream) */
1613    intra_ener = ec_dec_bit_prob(dec, 8192);
1614    /* Get band energies */
1615    unquant_coarse_energy(st->mode, st->start, st->end, bandE, st->oldBandE, intra_ener, st->mode->prob, dec, C, LM);
1616
1617    isTransient = ec_dec_bit_prob(dec, 8192);
1618
1619    if (isTransient)
1620       shortBlocks = M;
1621    else
1622       shortBlocks = 0;
1623
1624    if (isTransient)
1625    {
1626       transient_shift = ec_dec_uint(dec, 4);
1627       if (transient_shift == 3)
1628       {
1629          int transient_time_quant;
1630          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
1631          transient_time_quant = ec_dec_uint(dec, max_time);
1632          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
1633       } else {
1634          mdct_weight_shift = transient_shift;
1635          if (mdct_weight_shift && M>2)
1636             mdct_weight_pos = ec_dec_uint(dec, M-1);
1637          transient_shift = 0;
1638          transient_time = 0;
1639       }
1640    } else {
1641       transient_time = -1;
1642       transient_shift = 0;
1643    }
1644
1645    ALLOC(tf_res, st->mode->nbEBands, int);
1646    tf_decode(st->start, st->end, C, isTransient, tf_res, nbAvailableBytes, LM, dec);
1647
1648    has_fold = ec_dec_bit_prob(dec, 8192)<<1;
1649    has_fold |= ec_dec_bit_prob(dec, (has_fold>>1) ? 32768 : 49152);
1650
1651    ALLOC(pulses, st->mode->nbEBands, int);
1652    ALLOC(offsets, st->mode->nbEBands, int);
1653    ALLOC(fine_priority, st->mode->nbEBands, int);
1654
1655    for (i=0;i<st->mode->nbEBands;i++)
1656       offsets[i] = 0;
1657
1658    bits = len*8 - ec_dec_tell(dec, 0) - 1;
1659    ALLOC(fine_quant, st->mode->nbEBands, int);
1660    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
1661    /*bits = ec_dec_tell(dec, 0);
1662    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(dec, 0)-bits))/C);*/
1663    
1664    unquant_fine_energy(st->mode, st->start, st->end, bandE, st->oldBandE, fine_quant, dec, C);
1665
1666    /* Decode fixed codebook */
1667    quant_all_bands(0, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, NULL, pulses, shortBlocks, has_fold, tf_res, 1, len*8, dec, LM);
1668
1669    unquant_energy_finalise(st->mode, st->start, st->end, bandE, st->oldBandE, fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
1670
1671    log2Amp(st->mode, st->start, st->end, bandE, st->oldBandE, C);
1672
1673    if (mdct_weight_shift)
1674    {
1675       mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
1676    }
1677
1678    /* Synthesis */
1679    denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
1680
1681    CELT_MOVE(st->decode_mem[0], st->decode_mem[0]+N, DECODE_BUFFER_SIZE-N);
1682    if (C==2)
1683       CELT_MOVE(st->decode_mem[1], st->decode_mem[1]+N, DECODE_BUFFER_SIZE-N);
1684
1685    for (c=0;c<C;c++)
1686       for (i=0;i<M*st->mode->eBands[st->start];i++)
1687          freq[c*N+i] = 0;
1688    for (c=0;c<C;c++)
1689       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1690          freq[c*N+i] = 0;
1691
1692    /* Compute inverse MDCTs */
1693    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
1694          transient_shift, st->out_mem, st->overlap_mem, C, LM);
1695
1696    deemphasis(st->out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1697    st->loss_count = 0;
1698    RESTORE_STACK;
1699    if (ec_dec_get_error(dec))
1700       return CELT_CORRUPTED_DATA;
1701    else
1702       return CELT_OK;
1703 }
1704
1705 #ifdef FIXED_POINT
1706 #ifndef DISABLE_FLOAT_API
1707 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size, ec_dec *dec)
1708 {
1709    int j, ret, C, N, LM, M;
1710    VARDECL(celt_int16, out);
1711    SAVE_STACK;
1712
1713    if (check_decoder(st) != CELT_OK)
1714       return CELT_INVALID_STATE;
1715
1716    if (check_mode(st->mode) != CELT_OK)
1717       return CELT_INVALID_MODE;
1718
1719    if (pcm==NULL)
1720       return CELT_BAD_ARG;
1721
1722    for (LM=0;LM<4;LM++)
1723       if (st->mode->shortMdctSize<<LM==frame_size)
1724          break;
1725    if (LM>=MAX_CONFIG_SIZES)
1726       return CELT_BAD_ARG;
1727    M=1<<LM;
1728
1729    C = CHANNELS(st->channels);
1730    N = M*st->mode->shortMdctSize;
1731    
1732    ALLOC(out, C*N, celt_int16);
1733    ret=celt_decode_with_ec(st, data, len, out, frame_size, dec);
1734    if (ret==0)
1735       for (j=0;j<C*N;j++)
1736          pcm[j]=out[j]*(1.f/32768.f);
1737      
1738    RESTORE_STACK;
1739    return ret;
1740 }
1741 #endif /*DISABLE_FLOAT_API*/
1742 #else
1743 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1744 {
1745    int j, ret, C, N, LM, M;
1746    VARDECL(celt_sig, out);
1747    SAVE_STACK;
1748
1749    if (check_decoder(st) != CELT_OK)
1750       return CELT_INVALID_STATE;
1751
1752    if (check_mode(st->mode) != CELT_OK)
1753       return CELT_INVALID_MODE;
1754
1755    if (pcm==NULL)
1756       return CELT_BAD_ARG;
1757
1758    for (LM=0;LM<4;LM++)
1759       if (st->mode->shortMdctSize<<LM==frame_size)
1760          break;
1761    if (LM>=MAX_CONFIG_SIZES)
1762       return CELT_BAD_ARG;
1763    M=1<<LM;
1764
1765    C = CHANNELS(st->channels);
1766    N = M*st->mode->shortMdctSize;
1767    ALLOC(out, C*N, celt_sig);
1768
1769    ret=celt_decode_with_ec_float(st, data, len, out, frame_size, dec);
1770
1771    if (ret==0)
1772       for (j=0;j<C*N;j++)
1773          pcm[j] = FLOAT2INT16 (out[j]);
1774    
1775    RESTORE_STACK;
1776    return ret;
1777 }
1778 #endif
1779
1780 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size)
1781 {
1782    return celt_decode_with_ec(st, data, len, pcm, frame_size, NULL);
1783 }
1784
1785 #ifndef DISABLE_FLOAT_API
1786 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size)
1787 {
1788    return celt_decode_with_ec_float(st, data, len, pcm, frame_size, NULL);
1789 }
1790 #endif /* DISABLE_FLOAT_API */
1791
1792 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
1793 {
1794    va_list ap;
1795
1796    if (check_decoder(st) != CELT_OK)
1797       return CELT_INVALID_STATE;
1798
1799    va_start(ap, request);
1800    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
1801      goto bad_mode;
1802    switch (request)
1803    {
1804       case CELT_GET_MODE_REQUEST:
1805       {
1806          const CELTMode ** value = va_arg(ap, const CELTMode**);
1807          if (value==0)
1808             goto bad_arg;
1809          *value=st->mode;
1810       }
1811       break;
1812       case CELT_SET_START_BAND_REQUEST:
1813       {
1814          celt_int32 value = va_arg(ap, celt_int32);
1815          if (value<0 || value>=st->mode->nbEBands)
1816             goto bad_arg;
1817          st->start = value;
1818       }
1819       break;
1820       case CELT_SET_END_BAND_REQUEST:
1821       {
1822          celt_int32 value = va_arg(ap, celt_int32);
1823          if (value<0 || value>=st->mode->nbEBands)
1824             goto bad_arg;
1825          st->end = value;
1826       }
1827       break;
1828       case CELT_RESET_STATE:
1829       {
1830          const CELTMode *mode = st->mode;
1831          int C = st->channels;
1832
1833          CELT_MEMSET(st->decode_mem[0], 0, (DECODE_BUFFER_SIZE+st->overlap)*C);
1834          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
1835
1836          CELT_MEMSET(st->preemph_memD, 0, C);
1837
1838          st->loss_count = 0;
1839
1840          CELT_MEMSET(st->lpc, 0, C*LPC_ORDER);
1841       }
1842       break;
1843       default:
1844          goto bad_request;
1845    }
1846    va_end(ap);
1847    return CELT_OK;
1848 bad_mode:
1849   va_end(ap);
1850   return CELT_INVALID_MODE;
1851 bad_arg:
1852    va_end(ap);
1853    return CELT_BAD_ARG;
1854 bad_request:
1855       va_end(ap);
1856   return CELT_UNIMPLEMENTED;
1857 }
1858
1859 const char *celt_strerror(int error)
1860 {
1861    static const char *error_strings[8] = {
1862       "success",
1863       "invalid argument",
1864       "invalid mode",
1865       "internal error",
1866       "corrupted stream",
1867       "request not implemented",
1868       "invalid state",
1869       "memory allocation failed"
1870    };
1871    if (error > 0 || error < -7)
1872       return "unknown error";
1873    else 
1874       return error_strings[-error];
1875 }
1876