Reorganizing the VBR code
[opus.git] / libcelt / celt.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #define CELT_C
39
40 #include "os_support.h"
41 #include "mdct.h"
42 #include <math.h>
43 #include "celt.h"
44 #include "pitch.h"
45 #include "bands.h"
46 #include "modes.h"
47 #include "entcode.h"
48 #include "quant_bands.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54 #include "plc.h"
55
56 #ifdef FIXED_POINT
57 static const celt_word16 transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135f, 0.0337639f, 0.0748914f, 0.1304955f,
63    0.1986827f, 0.2771308f, 0.3631685f, 0.4538658f,
64    0.5461342f, 0.6368315f, 0.7228692f, 0.8013173f,
65    0.8695045f, 0.9251086f, 0.9662361f, 0.9914865f};
66 #endif
67
68 #define ENCODERVALID   0x4c434554
69 #define ENCODERPARTIAL 0x5445434c
70 #define ENCODERFREED   0x4c004500
71    
72 /** Encoder state 
73  @brief Encoder state
74  */
75 struct CELTEncoder {
76    celt_uint32 marker;
77    const CELTMode *mode;     /**< Mode used by the encoder */
78    int overlap;
79    int channels;
80    
81    int force_intra;
82    int delayedIntra;
83    celt_word16 tonal_average;
84    int fold_decision;
85    celt_word16 gain_prod;
86    celt_word32 frame_max;
87    int start, end;
88
89    /* VBR-related parameters */
90    celt_int32 vbr_reservoir;
91    celt_int32 vbr_drift;
92    celt_int32 vbr_offset;
93    celt_int32 vbr_count;
94
95    celt_int32 vbr_rate_norm; /* Target number of 16th bits per frame */
96    celt_word32 preemph_memE[2];
97    celt_word32 preemph_memD[2];
98
99    celt_sig *in_mem;
100    celt_sig *out_mem;
101
102    celt_word16 *oldBandE;
103 };
104
105 static int check_encoder(const CELTEncoder *st) 
106 {
107    if (st==NULL)
108    {
109       celt_warning("NULL passed as an encoder structure");  
110       return CELT_INVALID_STATE;
111    }
112    if (st->marker == ENCODERVALID)
113       return CELT_OK;
114    if (st->marker == ENCODERFREED)
115       celt_warning("Referencing an encoder that has already been freed");
116    else
117       celt_warning("This is not a valid CELT encoder structure");
118    return CELT_INVALID_STATE;
119 }
120
121 CELTEncoder *celt_encoder_create(const CELTMode *mode, int channels, int *error)
122 {
123    int C;
124    CELTEncoder *st;
125
126    if (check_mode(mode) != CELT_OK)
127    {
128       if (error)
129          *error = CELT_INVALID_MODE;
130       return NULL;
131    }
132
133    if (channels < 0 || channels > 2)
134    {
135       celt_warning("Only mono and stereo supported");
136       if (error)
137          *error = CELT_BAD_ARG;
138       return NULL;
139    }
140
141    C = channels;
142    st = celt_alloc(sizeof(CELTEncoder));
143    
144    if (st==NULL)
145    {
146       if (error)
147          *error = CELT_ALLOC_FAIL;
148       return NULL;
149    }
150    st->marker = ENCODERPARTIAL;
151    st->mode = mode;
152    st->overlap = mode->overlap;
153    st->channels = channels;
154
155    st->start = 0;
156    st->end = st->mode->effEBands;
157
158    st->vbr_rate_norm = 0;
159    st->force_intra  = 0;
160    st->delayedIntra = 1;
161    st->tonal_average = QCONST16(1.f,8);
162    st->fold_decision = 1;
163
164    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig));
165    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig));
166
167    st->oldBandE = (celt_word16*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16));
168
169    if ((st->in_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL))
170    {
171       if (error)
172          *error = CELT_OK;
173       st->marker   = ENCODERVALID;
174       return st;
175    }
176    /* If the setup fails for some reason deallocate it. */
177    celt_encoder_destroy(st);  
178    if (error)
179       *error = CELT_ALLOC_FAIL;
180    return NULL;
181 }
182
183 void celt_encoder_destroy(CELTEncoder *st)
184 {
185    if (st == NULL)
186    {
187       celt_warning("NULL passed to celt_encoder_destroy");
188       return;
189    }
190
191    if (st->marker == ENCODERFREED)
192    {
193       celt_warning("Freeing an encoder which has already been freed"); 
194       return;
195    }
196
197    if (st->marker != ENCODERVALID && st->marker != ENCODERPARTIAL)
198    {
199       celt_warning("This is not a valid CELT encoder structure");
200       return;
201    }
202    /*Check_mode is non-fatal here because we can still free
203     the encoder memory even if the mode is bad, although calling
204     the free functions in this order is a violation of the API.*/
205    check_mode(st->mode);
206    
207    celt_free(st->in_mem);
208    celt_free(st->out_mem);
209    celt_free(st->oldBandE);
210    
211    st->marker = ENCODERFREED;
212    
213    celt_free(st);
214 }
215
216 static inline celt_int16 FLOAT2INT16(float x)
217 {
218    x = x*CELT_SIG_SCALE;
219    x = MAX32(x, -32768);
220    x = MIN32(x, 32767);
221    return (celt_int16)float2int(x);
222 }
223
224 static inline celt_word16 SIG2WORD16(celt_sig x)
225 {
226 #ifdef FIXED_POINT
227    x = PSHR32(x, SIG_SHIFT);
228    x = MAX32(x, -32768);
229    x = MIN32(x, 32767);
230    return EXTRACT16(x);
231 #else
232    return (celt_word16)x;
233 #endif
234 }
235
236 static int transient_analysis(const celt_word32 * restrict in, int len, int C,
237                               int *transient_time, int *transient_shift,
238                               celt_word32 *frame_max, int overlap)
239 {
240    int i, n;
241    celt_word32 ratio;
242    celt_word32 threshold;
243    VARDECL(celt_word32, begin);
244    SAVE_STACK;
245    ALLOC(begin, len+1, celt_word32);
246    begin[0] = 0;
247    if (C==1)
248    {
249       for (i=0;i<len;i++)
250          begin[i+1] = MAX32(begin[i], ABS32(in[i]));
251    } else {
252       for (i=0;i<len;i++)
253          begin[i+1] = MAX32(begin[i], MAX32(ABS32(in[C*i]),
254                                             ABS32(in[C*i+1])));
255    }
256    n = -1;
257
258    threshold = MULT16_32_Q15(QCONST16(.4f,15),begin[len]);
259    /* If the following condition isn't met, there's just no way
260       we'll have a transient*/
261    if (*frame_max < threshold)
262    {
263       /* It's likely we have a transient, now find it */
264       for (i=8;i<len-8;i++)
265       {
266          if (begin[i+1] < threshold)
267             n=i;
268       }
269    }
270    if (n<32)
271    {
272       n = -1;
273       ratio = 0;
274    } else {
275       ratio = DIV32(begin[len],1+MAX32(*frame_max, begin[n-16]));
276    }
277
278    if (ratio > 45)
279       *transient_shift = 3;
280    else
281       *transient_shift = 0;
282    
283    *transient_time = n;
284    *frame_max = begin[len-overlap];
285
286    RESTORE_STACK;
287    return ratio > 0;
288 }
289
290 /** Apply window and compute the MDCT for all sub-frames and 
291     all channels in a frame */
292 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * restrict in, celt_sig * restrict out, int _C, int LM)
293 {
294    const int C = CHANNELS(_C);
295    if (C==1 && !shortBlocks)
296    {
297       const int overlap = OVERLAP(mode);
298       clt_mdct_forward(&mode->mdct, in, out, mode->window, overlap, mode->maxLM-LM);
299    } else {
300       const int overlap = OVERLAP(mode);
301       int N = mode->shortMdctSize<<LM;
302       int B = 1;
303       int b, c;
304       VARDECL(celt_word32, x);
305       VARDECL(celt_word32, tmp);
306       SAVE_STACK;
307       if (shortBlocks)
308       {
309          /*lookup = &mode->mdct[0];*/
310          N = mode->shortMdctSize;
311          B = shortBlocks;
312       }
313       ALLOC(x, N+overlap, celt_word32);
314       ALLOC(tmp, N, celt_word32);
315       for (c=0;c<C;c++)
316       {
317          for (b=0;b<B;b++)
318          {
319             int j;
320             for (j=0;j<N+overlap;j++)
321                x[j] = in[C*(b*N+j)+c];
322             clt_mdct_forward(&mode->mdct, x, tmp, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
323             /* Interleaving the sub-frames */
324             for (j=0;j<N;j++)
325                out[(j*B+b)+c*N*B] = tmp[j];
326          }
327       }
328       RESTORE_STACK;
329    }
330 }
331
332 /** Compute the IMDCT and apply window for all sub-frames and 
333     all channels in a frame */
334 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X, int transient_time, int transient_shift, celt_sig * restrict out_mem, int _C, int LM)
335 {
336    int c, N4;
337    const int C = CHANNELS(_C);
338    const int N = mode->shortMdctSize<<LM;
339    const int overlap = OVERLAP(mode);
340    N4 = (N-overlap)>>1;
341    for (c=0;c<C;c++)
342    {
343       int j;
344       if (transient_shift==0 && C==1 && !shortBlocks) {
345          clt_mdct_backward(&mode->mdct, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap, mode->maxLM-LM);
346       } else {
347          VARDECL(celt_word32, x);
348          VARDECL(celt_word32, tmp);
349          int b;
350          int N2 = N;
351          int B = 1;
352          int n4offset=0;
353          SAVE_STACK;
354          
355          ALLOC(x, 2*N, celt_word32);
356          ALLOC(tmp, N, celt_word32);
357
358          if (shortBlocks)
359          {
360             /*lookup = &mode->mdct[0];*/
361             N2 = mode->shortMdctSize;
362             B = shortBlocks;
363             n4offset = N4;
364          }
365          /* Prevents problems from the imdct doing the overlap-add */
366          CELT_MEMSET(x+N4, 0, N2);
367
368          for (b=0;b<B;b++)
369          {
370             /* De-interleaving the sub-frames */
371             for (j=0;j<N2;j++)
372                tmp[j] = X[(j*B+b)+c*N2*B];
373             clt_mdct_backward(&mode->mdct, tmp, x+n4offset+N2*b, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
374          }
375
376          if (transient_shift > 0)
377          {
378 #ifdef FIXED_POINT
379             for (j=0;j<16;j++)
380                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
381             for (j=transient_time;j<N+overlap;j++)
382                x[N4+j] = SHL32(x[N4+j], transient_shift);
383 #else
384             for (j=0;j<16;j++)
385                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
386             for (j=transient_time;j<N+overlap;j++)
387                x[N4+j] *= 1<<transient_shift;
388 #endif
389          }
390          /* The first and last part would need to be set to zero 
391             if we actually wanted to use them. */
392          for (j=0;j<overlap;j++)
393             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
394          for (j=0;j<overlap;j++)
395             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
396          for (j=0;j<2*N4;j++)
397             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
398          RESTORE_STACK;
399       }
400    }
401 }
402
403 static void deemphasis(celt_sig *in, celt_word16 *pcm, int N, int _C, const celt_word16 *coef, celt_sig *mem)
404 {
405    const int C = CHANNELS(_C);
406    int c;
407    for (c=0;c<C;c++)
408    {
409       int j;
410       celt_sig * restrict x;
411       celt_word16  * restrict y;
412       celt_sig m = mem[c];
413       x = &in[C*(MAX_PERIOD-N)+c];
414       y = pcm+c;
415       for (j=0;j<N;j++)
416       {
417          celt_sig tmp = *x + m;
418          m = MULT16_32_Q15(coef[0], tmp)
419            - MULT16_32_Q15(coef[1], *x);
420          tmp = SHL32(MULT16_32_Q15(coef[3], tmp), 2);
421          *y = SCALEOUT(SIG2WORD16(tmp));
422          x+=C;
423          y+=C;
424       }
425       mem[c] = m;
426    }
427 }
428
429 static void mdct_shape(const CELTMode *mode, celt_norm *X, int start,
430                        int end, int N,
431                        int mdct_weight_shift, int end_band, int _C, int renorm, int M)
432 {
433    int m, i, c;
434    const int C = CHANNELS(_C);
435    for (c=0;c<C;c++)
436       for (m=start;m<end;m++)
437          for (i=m+c*N;i<(c+1)*N;i+=M)
438 #ifdef FIXED_POINT
439             X[i] = SHR16(X[i], mdct_weight_shift);
440 #else
441             X[i] = (1.f/(1<<mdct_weight_shift))*X[i];
442 #endif
443    if (renorm)
444       renormalise_bands(mode, X, end_band, C, M);
445 }
446
447 static signed char tf_select_table[4][8] = {
448       {0, -1, 0, -1,    0,-1, 0,-1},
449       {0, -1, 0, -2,    1, 0, 1 -1},
450       {0, -2, 0, -3,    2, 0, 1 -1},
451       {0, -2, 0, -3,    2, 0, 1 -1},
452 };
453
454 static int tf_analysis(celt_word16 *bandLogE, celt_word16 *oldBandE, int len, int C, int isTransient, int *tf_res, int nbCompressedBytes)
455 {
456    int i;
457    celt_word16 threshold;
458    VARDECL(celt_word16, metric);
459    celt_word32 average=0;
460    celt_word32 cost0;
461    celt_word32 cost1;
462    VARDECL(int, path0);
463    VARDECL(int, path1);
464    celt_word16 lambda;
465    int tf_select=0;
466    SAVE_STACK;
467
468    /* FIXME: Should check number of bytes *left* */
469    if (nbCompressedBytes<15*C)
470    {
471       for (i=0;i<len;i++)
472          tf_res[i] = 0;
473       return 0;
474    }
475    if (nbCompressedBytes<40)
476       lambda = QCONST16(5.f, DB_SHIFT);
477    else if (nbCompressedBytes<60)
478       lambda = QCONST16(2.f, DB_SHIFT);
479    else if (nbCompressedBytes<100)
480       lambda = QCONST16(1.f, DB_SHIFT);
481    else
482       lambda = QCONST16(.5f, DB_SHIFT);
483
484    ALLOC(metric, len, celt_word16);
485    ALLOC(path0, len, int);
486    ALLOC(path1, len, int);
487    for (i=0;i<len;i++)
488    {
489       metric[i] = SUB16(bandLogE[i], oldBandE[i]);
490       average += metric[i];
491    }
492    if (C==2)
493    {
494       average = 0;
495       for (i=0;i<len;i++)
496       {
497          metric[i] = HALF32(metric[i]) + HALF32(SUB16(bandLogE[i+len], oldBandE[i+len]));
498          average += metric[i];
499       }
500    }
501    average = DIV32(average, len);
502    /*if (!isTransient)
503       printf ("%f\n", average);*/
504    if (isTransient)
505    {
506       threshold = QCONST16(1.f,DB_SHIFT);
507       tf_select = average > QCONST16(3.f,DB_SHIFT);
508    } else {
509       threshold = QCONST16(.5f,DB_SHIFT);
510       tf_select = average > QCONST16(1.f,DB_SHIFT);
511    }
512    cost0 = 0;
513    cost1 = lambda;
514    /* Viterbi forward pass */
515    for (i=1;i<len;i++)
516    {
517       celt_word32 curr0, curr1;
518       celt_word32 from0, from1;
519
520       from0 = cost0;
521       from1 = cost1 + lambda;
522       if (from0 < from1)
523       {
524          curr0 = from0;
525          path0[i]= 0;
526       } else {
527          curr0 = from1;
528          path0[i]= 1;
529       }
530
531       from0 = cost0 + lambda;
532       from1 = cost1;
533       if (from0 < from1)
534       {
535          curr1 = from0;
536          path1[i]= 0;
537       } else {
538          curr1 = from1;
539          path1[i]= 1;
540       }
541       cost0 = curr0 + (metric[i]-threshold);
542       cost1 = curr1;
543    }
544    tf_res[len-1] = cost0 < cost1 ? 0 : 1;
545    /* Viterbi backward pass to check the decisions */
546    for (i=len-2;i>=0;i--)
547    {
548       if (tf_res[i+1] == 1)
549          tf_res[i] = path1[i+1];
550       else
551          tf_res[i] = path0[i+1];
552    }
553    RESTORE_STACK;
554    return tf_select;
555 }
556
557 static void tf_encode(int start, int end, int isTransient, int *tf_res, int nbCompressedBytes, int LM, int tf_select, ec_enc *enc)
558 {
559    int curr, i;
560    if (8*nbCompressedBytes - ec_enc_tell(enc, 0) < 100)
561    {
562       for (i=start;i<end;i++)
563          tf_res[i] = isTransient;
564    } else {
565       ec_enc_bit_prob(enc, tf_res[start], isTransient ? 16384 : 4096);
566       curr = tf_res[start];
567       for (i=start+1;i<end;i++)
568       {
569          ec_enc_bit_prob(enc, tf_res[i] ^ curr, isTransient ? 4096 : 2048);
570          curr = tf_res[i];
571       }
572    }
573    ec_enc_bits(enc, tf_select, 1);
574    for (i=start;i<end;i++)
575       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
576 }
577
578 static void tf_decode(int start, int end, int C, int isTransient, int *tf_res, int nbCompressedBytes, int LM, ec_dec *dec)
579 {
580    int i, curr, tf_select;
581    if (8*nbCompressedBytes - ec_dec_tell(dec, 0) < 100)
582    {
583       for (i=start;i<end;i++)
584          tf_res[i] = isTransient;
585    } else {
586       tf_res[start] = ec_dec_bit_prob(dec, isTransient ? 16384 : 4096);
587       curr = tf_res[start];
588       for (i=start+1;i<end;i++)
589       {
590          tf_res[i] = ec_dec_bit_prob(dec, isTransient ? 4096 : 2048) ^ curr;
591          curr = tf_res[i];
592       }
593    }
594    tf_select = ec_dec_bits(dec, 1);
595    for (i=start;i<end;i++)
596       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
597 }
598
599 #ifdef FIXED_POINT
600 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
601 {
602 #else
603 int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
604 {
605 #endif
606    int i, c, N, NN, N4;
607    int bits;
608    int has_fold=1;
609    ec_byte_buffer buf;
610    ec_enc         _enc;
611    VARDECL(celt_sig, in);
612    VARDECL(celt_sig, freq);
613    VARDECL(celt_norm, X);
614    VARDECL(celt_ener, bandE);
615    VARDECL(celt_word16, bandLogE);
616    VARDECL(int, fine_quant);
617    VARDECL(celt_word16, error);
618    VARDECL(int, pulses);
619    VARDECL(int, offsets);
620    VARDECL(int, fine_priority);
621    VARDECL(int, tf_res);
622    int intra_ener = 0;
623    int shortBlocks=0;
624    int isTransient=0;
625    int transient_time, transient_time_quant;
626    int transient_shift;
627    int resynth;
628    const int C = CHANNELS(st->channels);
629    int mdct_weight_shift = 0;
630    int mdct_weight_pos=0;
631    int LM, M;
632    int tf_select;
633    celt_word16 max_decay;
634    int nbFilledBytes, nbAvailableBytes;
635    int effEnd;
636    SAVE_STACK;
637
638    if (check_encoder(st) != CELT_OK)
639       return CELT_INVALID_STATE;
640
641    if (check_mode(st->mode) != CELT_OK)
642       return CELT_INVALID_MODE;
643
644    if (nbCompressedBytes<0 || pcm==NULL)
645      return CELT_BAD_ARG;
646
647    for (LM=0;LM<4;LM++)
648       if (st->mode->shortMdctSize<<LM==frame_size)
649          break;
650    if (LM>=MAX_CONFIG_SIZES)
651       return CELT_BAD_ARG;
652    M=1<<LM;
653
654    if (enc==NULL)
655    {
656       ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
657       ec_enc_init(&_enc,&buf);
658       enc = &_enc;
659       nbFilledBytes=0;
660    } else {
661       nbFilledBytes=(ec_enc_tell(enc, 0)+4)>>3;
662    }
663    nbAvailableBytes = nbCompressedBytes - nbFilledBytes;
664
665    effEnd = st->end;
666    if (effEnd > st->mode->effEBands)
667       effEnd = st->mode->effEBands;
668
669    N = M*st->mode->shortMdctSize;
670    N4 = (N-st->overlap)>>1;
671    ALLOC(in, 2*C*N-2*C*N4, celt_sig);
672
673    CELT_COPY(in, st->in_mem, C*st->overlap);
674    for (c=0;c<C;c++)
675    {
676       const celt_word16 * restrict pcmp = pcm+c;
677       celt_sig * restrict inp = in+C*st->overlap+c;
678       for (i=0;i<N;i++)
679       {
680          /* Apply pre-emphasis */
681          celt_sig tmp = MULT16_16(st->mode->preemph[2], SCALEIN(*pcmp));
682          *inp = tmp + st->preemph_memE[c];
683          st->preemph_memE[c] = MULT16_32_Q15(st->mode->preemph[1], *inp)
684                              - MULT16_32_Q15(st->mode->preemph[0], tmp);
685          inp += C;
686          pcmp += C;
687       }
688    }
689    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
690
691    /* Transient handling */
692    transient_time = -1;
693    transient_time_quant = -1;
694    transient_shift = 0;
695    isTransient = 0;
696
697    resynth = optional_resynthesis!=NULL;
698
699    if (M > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift, &st->frame_max, st->overlap))
700    {
701 #ifndef FIXED_POINT
702       float gain_1;
703 #endif
704       /* Apply the inverse shaping window */
705       if (transient_shift)
706       {
707          transient_time_quant = transient_time*(celt_int32)8000/st->mode->Fs;
708          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
709 #ifdef FIXED_POINT
710          for (c=0;c<C;c++)
711             for (i=0;i<16;i++)
712                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
713          for (c=0;c<C;c++)
714             for (i=transient_time;i<N+st->overlap;i++)
715                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
716 #else
717          for (c=0;c<C;c++)
718             for (i=0;i<16;i++)
719                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
720          gain_1 = 1.f/(1<<transient_shift);
721          for (c=0;c<C;c++)
722             for (i=transient_time;i<N+st->overlap;i++)
723                in[C*i+c] *= gain_1;
724 #endif
725       }
726       isTransient = 1;
727       has_fold = 1;
728    }
729
730    if (isTransient)
731       shortBlocks = M;
732    else
733       shortBlocks = 0;
734
735    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
736    ALLOC(bandE,st->mode->nbEBands*C, celt_ener);
737    ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16);
738    /* Compute MDCTs */
739    compute_mdcts(st->mode, shortBlocks, in, freq, C, LM);
740
741    ALLOC(X, C*N, celt_norm);         /**< Interleaved normalised MDCTs */
742
743    compute_band_energies(st->mode, freq, bandE, effEnd, C, M);
744
745    amp2Log2(st->mode, effEnd, st->end, bandE, bandLogE, C);
746
747    /* Band normalisation */
748    normalise_bands(st->mode, freq, X, bandE, effEnd, C, M);
749    if (!shortBlocks && !folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision, effEnd, C, M))
750       has_fold = 0;
751
752    /* Don't use intra energy when we're operating at low bit-rate */
753    intra_ener = st->force_intra || (st->delayedIntra && nbAvailableBytes > st->end);
754    if (shortBlocks || intra_decision(bandLogE, st->oldBandE, st->start, effEnd, st->mode->nbEBands, C))
755       st->delayedIntra = 1;
756    else
757       st->delayedIntra = 0;
758
759    NN = M*st->mode->eBands[effEnd];
760    if (shortBlocks && !transient_shift)
761    {
762       celt_word32 sum[8]={1,1,1,1,1,1,1,1};
763       int m;
764       for (c=0;c<C;c++)
765       {
766          m=0;
767          do {
768             celt_word32 tmp=0;
769             for (i=m+c*N;i<c*N+NN;i+=M)
770                tmp += ABS32(X[i]);
771             sum[m++] += tmp;
772          } while (m<M);
773       }
774       m=0;
775 #ifdef FIXED_POINT
776       do {
777          if (SHR32(sum[m+1],3) > sum[m])
778          {
779             mdct_weight_shift=2;
780             mdct_weight_pos = m;
781          } else if (SHR32(sum[m+1],1) > sum[m] && mdct_weight_shift < 2)
782          {
783             mdct_weight_shift=1;
784             mdct_weight_pos = m;
785          }
786          m++;
787       } while (m<M-1);
788 #else
789       do {
790          if (sum[m+1] > 8*sum[m])
791          {
792             mdct_weight_shift=2;
793             mdct_weight_pos = m;
794          } else if (sum[m+1] > 2*sum[m] && mdct_weight_shift < 2)
795          {
796             mdct_weight_shift=1;
797             mdct_weight_pos = m;
798          }
799          m++;
800       } while (m<M-1);
801 #endif
802       if (mdct_weight_shift)
803          mdct_shape(st->mode, X, mdct_weight_pos+1, M, N, mdct_weight_shift, effEnd, C, 0, M);
804    }
805
806    /* Encode the global flags using a simple probability model
807       (first symbols in the stream) */
808    ec_enc_bit_prob(enc, intra_ener, 8192);
809    ec_enc_bit_prob(enc, shortBlocks!=0, 8192);
810    ec_enc_bit_prob(enc, has_fold>>1, 8192);
811    ec_enc_bit_prob(enc, has_fold&1, (has_fold>>1) ? 32768 : 49152);
812
813    if (shortBlocks)
814    {
815       if (transient_shift)
816       {
817          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
818          ec_enc_uint(enc, transient_shift, 4);
819          ec_enc_uint(enc, transient_time_quant, max_time);
820       } else {
821          ec_enc_uint(enc, mdct_weight_shift, 4);
822          if (mdct_weight_shift && M!=2)
823             ec_enc_uint(enc, mdct_weight_pos, M-1);
824       }
825    }
826
827    ALLOC(tf_res, st->mode->nbEBands, int);
828    tf_select = tf_analysis(bandLogE, st->oldBandE, effEnd, C, isTransient, tf_res, nbAvailableBytes);
829    for (i=effEnd;i<st->end;i++)
830       tf_res[i] = tf_res[effEnd-1];
831
832    ALLOC(error, C*st->mode->nbEBands, celt_word16);
833
834 #ifdef FIXED_POINT
835       max_decay = MIN32(QCONST16(16,DB_SHIFT), SHL32(EXTEND32(nbAvailableBytes),DB_SHIFT-3));
836 #else
837    max_decay = MIN32(16.f, .125f*nbAvailableBytes);
838 #endif
839    quant_coarse_energy(st->mode, st->start, st->end, bandLogE, st->oldBandE, nbCompressedBytes*8, intra_ener, st->mode->prob, error, enc, C, LM, max_decay);
840
841    tf_encode(st->start, st->end, isTransient, tf_res, nbAvailableBytes, LM, tf_select, enc);
842
843    /* Variable bitrate */
844    if (st->vbr_rate_norm>0)
845    {
846      celt_word16 alpha;
847      celt_int32 delta;
848      /* The target rate in 16th bits per frame */
849      celt_int32 vbr_rate;
850      celt_int32 target;
851      celt_int32 vbr_bound, max_allowed;
852
853      vbr_rate = M*st->vbr_rate_norm;
854
855      /* Computes the max bit-rate allowed in VBR more to avoid busting the budget */
856      vbr_bound = vbr_rate;
857      max_allowed = (vbr_rate + vbr_bound - st->vbr_reservoir)>>(BITRES+3);
858      if (max_allowed < 4)
859         max_allowed = 4;
860      if (max_allowed < nbAvailableBytes)
861         nbAvailableBytes = max_allowed;
862      target=vbr_rate;
863
864      /* Shortblocks get a large boost in bitrate, but since they 
865         are uncommon long blocks are not greatly effected */
866      if (shortBlocks)
867        target*=2;
868      else if (M > 1)
869        target-=(target+14)/28;
870
871      /* The average energy is removed from the target and the actual 
872         energy added*/
873      target=target+st->vbr_offset-588+ec_enc_tell(enc, BITRES);
874
875      /* In VBR mode the frame size must not be reduced so much that it would result in the coarse energy busting its budget */
876      target=IMIN(nbAvailableBytes,target);
877      /* Make the adaptation coef (alpha) higher at the beginning */
878      if (st->vbr_count < 990)
879      {
880         st->vbr_count++;
881         alpha = celt_rcp(SHL32(EXTEND32(st->vbr_count+10),16));
882         /*printf ("%d %d\n", st->vbr_count+10, alpha);*/
883      } else
884         alpha = QCONST16(.001f,15);
885
886      /* By how much did we "miss" the target on that frame */
887      delta = (8<<BITRES)*(celt_int32)target - vbr_rate;
888      /* How many bits have we used in excess of what we're allowed */
889      st->vbr_reservoir += delta;
890      /*printf ("%d\n", st->vbr_reservoir);*/
891
892      /* Compute the offset we need to apply in order to reach the target */
893      st->vbr_drift += MULT16_32_Q15(alpha,delta-st->vbr_offset-st->vbr_drift);
894      st->vbr_offset = -st->vbr_drift;
895      /*printf ("%d\n", st->vbr_drift);*/
896
897      /* We could use any multiple of vbr_rate as bound (depending on the delay) */
898      if (st->vbr_reservoir < 0)
899      {
900         /* We're under the min value -- increase rate */
901         int adjust = 1-(st->vbr_reservoir-1)/(8<<BITRES);
902         st->vbr_reservoir += adjust*(8<<BITRES);
903         target += adjust;
904         /*printf ("+%d\n", adjust);*/
905      }
906      if (target < nbAvailableBytes)
907         nbAvailableBytes = target;
908      nbCompressedBytes = nbAvailableBytes + nbFilledBytes;
909
910      /* This moves the raw bits to take into account the new compressed size */
911      ec_byte_shrink(&buf, nbCompressedBytes);
912    }
913
914    /* Bit allocation */
915    ALLOC(fine_quant, st->mode->nbEBands, int);
916    ALLOC(pulses, st->mode->nbEBands, int);
917    ALLOC(offsets, st->mode->nbEBands, int);
918    ALLOC(fine_priority, st->mode->nbEBands, int);
919
920    for (i=0;i<st->mode->nbEBands;i++)
921       offsets[i] = 0;
922    bits = nbCompressedBytes*8 - ec_enc_tell(enc, 0) - 1;
923    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
924
925    quant_fine_energy(st->mode, st->start, st->end, bandE, st->oldBandE, error, fine_quant, enc, C);
926
927 #ifdef MEASURE_NORM_MSE
928    float X0[3000];
929    float bandE0[60];
930    for (c=0;c<C;c++)
931       for (i=0;i<N;i++)
932          X0[i+c*N] = X[i+c*N];
933    for (i=0;i<C*st->mode->nbEBands;i++)
934       bandE0[i] = bandE[i];
935 #endif
936
937    /* Residual quantisation */
938    quant_all_bands(1, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, bandE, pulses, shortBlocks, has_fold, tf_res, resynth, nbCompressedBytes*8, enc, LM);
939
940    quant_energy_finalise(st->mode, st->start, st->end, bandE, st->oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
941
942    /* Re-synthesis of the coded audio if required */
943    if (resynth)
944    {
945       log2Amp(st->mode, st->start, st->end, bandE, st->oldBandE, C);
946
947 #ifdef MEASURE_NORM_MSE
948       measure_norm_mse(st->mode, X, X0, bandE, bandE0, M, N, C);
949 #endif
950
951       if (mdct_weight_shift)
952       {
953          mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
954       }
955
956       /* Synthesis */
957       denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
958
959       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
960
961       for (c=0;c<C;c++)
962          for (i=0;i<M*st->mode->eBands[st->start];i++)
963             freq[c*N+i] = 0;
964       for (c=0;c<C;c++)
965          for (i=M*st->mode->eBands[st->end];i<N;i++)
966             freq[c*N+i] = 0;
967
968       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem, C, LM);
969
970       /* De-emphasis and put everything back at the right place 
971          in the synthesis history */
972       if (optional_resynthesis != NULL) {
973          deemphasis(st->out_mem, optional_resynthesis, N, C, st->mode->preemph, st->preemph_memD);
974
975       }
976    }
977
978    /* If there's any room left (can only happen for very high rates),
979       fill it with zeros */
980    while (ec_enc_tell(enc,0) + 8 <= nbCompressedBytes*8)
981       ec_enc_bits(enc, 0, 8);
982    ec_enc_done(enc);
983    
984    RESTORE_STACK;
985    if (ec_enc_get_error(enc))
986       return CELT_CORRUPTED_DATA;
987    else
988       return nbCompressedBytes;
989 }
990
991 #ifdef FIXED_POINT
992 #ifndef DISABLE_FLOAT_API
993 int celt_encode_with_ec_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
994 {
995    int j, ret, C, N, LM, M;
996    VARDECL(celt_int16, in);
997    SAVE_STACK;
998
999    if (check_encoder(st) != CELT_OK)
1000       return CELT_INVALID_STATE;
1001
1002    if (check_mode(st->mode) != CELT_OK)
1003       return CELT_INVALID_MODE;
1004
1005    if (pcm==NULL)
1006       return CELT_BAD_ARG;
1007
1008    for (LM=0;LM<4;LM++)
1009       if (st->mode->shortMdctSize<<LM==frame_size)
1010          break;
1011    if (LM>=MAX_CONFIG_SIZES)
1012       return CELT_BAD_ARG;
1013    M=1<<LM;
1014
1015    C = CHANNELS(st->channels);
1016    N = M*st->mode->shortMdctSize;
1017    ALLOC(in, C*N, celt_int16);
1018
1019    for (j=0;j<C*N;j++)
1020      in[j] = FLOAT2INT16(pcm[j]);
1021
1022    if (optional_resynthesis != NULL) {
1023      ret=celt_encode_with_ec(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
1024       for (j=0;j<C*N;j++)
1025          optional_resynthesis[j]=in[j]*(1.f/32768.f);
1026    } else {
1027      ret=celt_encode_with_ec(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
1028    }
1029    RESTORE_STACK;
1030    return ret;
1031
1032 }
1033 #endif /*DISABLE_FLOAT_API*/
1034 #else
1035 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
1036 {
1037    int j, ret, C, N, LM, M;
1038    VARDECL(celt_sig, in);
1039    SAVE_STACK;
1040
1041    if (check_encoder(st) != CELT_OK)
1042       return CELT_INVALID_STATE;
1043
1044    if (check_mode(st->mode) != CELT_OK)
1045       return CELT_INVALID_MODE;
1046
1047    if (pcm==NULL)
1048       return CELT_BAD_ARG;
1049
1050    for (LM=0;LM<4;LM++)
1051       if (st->mode->shortMdctSize<<LM==frame_size)
1052          break;
1053    if (LM>=MAX_CONFIG_SIZES)
1054       return CELT_BAD_ARG;
1055    M=1<<LM;
1056
1057    C=CHANNELS(st->channels);
1058    N=M*st->mode->shortMdctSize;
1059    ALLOC(in, C*N, celt_sig);
1060    for (j=0;j<C*N;j++) {
1061      in[j] = SCALEOUT(pcm[j]);
1062    }
1063
1064    if (optional_resynthesis != NULL) {
1065       ret = celt_encode_with_ec_float(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
1066       for (j=0;j<C*N;j++)
1067          optional_resynthesis[j] = FLOAT2INT16(in[j]);
1068    } else {
1069       ret = celt_encode_with_ec_float(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
1070    }
1071    RESTORE_STACK;
1072    return ret;
1073 }
1074 #endif
1075
1076 int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1077 {
1078    return celt_encode_with_ec(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1079 }
1080
1081 #ifndef DISABLE_FLOAT_API
1082 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1083 {
1084    return celt_encode_with_ec_float(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1085 }
1086 #endif /* DISABLE_FLOAT_API */
1087
1088 int celt_encode_resynthesis(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1089 {
1090    return celt_encode_with_ec(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1091 }
1092
1093 #ifndef DISABLE_FLOAT_API
1094 int celt_encode_resynthesis_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1095 {
1096    return celt_encode_with_ec_float(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1097 }
1098 #endif /* DISABLE_FLOAT_API */
1099
1100
1101 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
1102 {
1103    va_list ap;
1104    
1105    if (check_encoder(st) != CELT_OK)
1106       return CELT_INVALID_STATE;
1107
1108    va_start(ap, request);
1109    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
1110      goto bad_mode;
1111    switch (request)
1112    {
1113       case CELT_GET_MODE_REQUEST:
1114       {
1115          const CELTMode ** value = va_arg(ap, const CELTMode**);
1116          if (value==0)
1117             goto bad_arg;
1118          *value=st->mode;
1119       }
1120       break;
1121       case CELT_SET_COMPLEXITY_REQUEST:
1122       {
1123          int value = va_arg(ap, celt_int32);
1124          if (value<0 || value>10)
1125             goto bad_arg;
1126       }
1127       break;
1128       case CELT_SET_START_BAND_REQUEST:
1129       {
1130          celt_int32 value = va_arg(ap, celt_int32);
1131          if (value<0 || value>=st->mode->nbEBands)
1132             goto bad_arg;
1133          st->start = value;
1134       }
1135       break;
1136       case CELT_SET_END_BAND_REQUEST:
1137       {
1138          celt_int32 value = va_arg(ap, celt_int32);
1139          if (value<0 || value>=st->mode->nbEBands)
1140             goto bad_arg;
1141          st->end = value;
1142       }
1143       break;
1144       case CELT_SET_PREDICTION_REQUEST:
1145       {
1146          int value = va_arg(ap, celt_int32);
1147          if (value<0 || value>2)
1148             goto bad_arg;
1149          if (value==0)
1150          {
1151             st->force_intra   = 1;
1152          } else if (value==1) {
1153             st->force_intra   = 0;
1154          } else {
1155             st->force_intra   = 0;
1156          }   
1157       }
1158       break;
1159       case CELT_SET_VBR_RATE_REQUEST:
1160       {
1161          celt_int32 value = va_arg(ap, celt_int32);
1162          int frame_rate;
1163          int N = st->mode->shortMdctSize;
1164          if (value<0)
1165             goto bad_arg;
1166          if (value>3072000)
1167             value = 3072000;
1168          frame_rate = ((st->mode->Fs<<3)+(N>>1))/N;
1169          st->vbr_rate_norm = ((value<<(BITRES+3))+(frame_rate>>1))/frame_rate;
1170       }
1171       break;
1172       case CELT_RESET_STATE:
1173       {
1174          const CELTMode *mode = st->mode;
1175          int C = st->channels;
1176
1177          CELT_MEMSET(st->in_mem, 0, st->overlap*C);
1178          CELT_MEMSET(st->out_mem, 0, (MAX_PERIOD+st->overlap)*C);
1179
1180          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
1181
1182          CELT_MEMSET(st->preemph_memE, 0, C);
1183          CELT_MEMSET(st->preemph_memD, 0, C);
1184          st->delayedIntra = 1;
1185
1186          st->fold_decision = 1;
1187          st->tonal_average = QCONST16(1.f,8);
1188          st->gain_prod = 0;
1189          st->vbr_reservoir = 0;
1190          st->vbr_drift = 0;
1191          st->vbr_offset = 0;
1192          st->vbr_count = 0;
1193          st->frame_max = 0;
1194       }
1195       break;
1196       default:
1197          goto bad_request;
1198    }
1199    va_end(ap);
1200    return CELT_OK;
1201 bad_mode:
1202   va_end(ap);
1203   return CELT_INVALID_MODE;
1204 bad_arg:
1205    va_end(ap);
1206    return CELT_BAD_ARG;
1207 bad_request:
1208    va_end(ap);
1209    return CELT_UNIMPLEMENTED;
1210 }
1211
1212 /**********************************************************************/
1213 /*                                                                    */
1214 /*                             DECODER                                */
1215 /*                                                                    */
1216 /**********************************************************************/
1217 #define DECODE_BUFFER_SIZE 2048
1218
1219 #define DECODERVALID   0x4c434454
1220 #define DECODERPARTIAL 0x5444434c
1221 #define DECODERFREED   0x4c004400
1222
1223 /** Decoder state 
1224  @brief Decoder state
1225  */
1226 struct CELTDecoder {
1227    celt_uint32 marker;
1228    const CELTMode *mode;
1229    int overlap;
1230    int channels;
1231
1232    int start, end;
1233
1234    celt_sig preemph_memD[2];
1235
1236    celt_sig *out_mem;
1237    celt_word32 *decode_mem;
1238
1239    celt_word16 *oldBandE;
1240    
1241    celt_word16 *lpc;
1242
1243    int last_pitch_index;
1244    int loss_count;
1245 };
1246
1247 int check_decoder(const CELTDecoder *st) 
1248 {
1249    if (st==NULL)
1250    {
1251       celt_warning("NULL passed a decoder structure");  
1252       return CELT_INVALID_STATE;
1253    }
1254    if (st->marker == DECODERVALID)
1255       return CELT_OK;
1256    if (st->marker == DECODERFREED)
1257       celt_warning("Referencing a decoder that has already been freed");
1258    else
1259       celt_warning("This is not a valid CELT decoder structure");
1260    return CELT_INVALID_STATE;
1261 }
1262
1263 CELTDecoder *celt_decoder_create(const CELTMode *mode, int channels, int *error)
1264 {
1265    int C;
1266    CELTDecoder *st;
1267
1268    if (check_mode(mode) != CELT_OK)
1269    {
1270       if (error)
1271          *error = CELT_INVALID_MODE;
1272       return NULL;
1273    }
1274
1275    if (channels < 0 || channels > 2)
1276    {
1277       celt_warning("Only mono and stereo supported");
1278       if (error)
1279          *error = CELT_BAD_ARG;
1280       return NULL;
1281    }
1282
1283    C = CHANNELS(channels);
1284    st = celt_alloc(sizeof(CELTDecoder));
1285
1286    if (st==NULL)
1287    {
1288       if (error)
1289          *error = CELT_ALLOC_FAIL;
1290       return NULL;
1291    }
1292
1293    st->marker = DECODERPARTIAL;
1294    st->mode = mode;
1295    st->overlap = mode->overlap;
1296    st->channels = channels;
1297
1298    st->start = 0;
1299    st->end = st->mode->effEBands;
1300
1301    st->decode_mem = (celt_sig*)celt_alloc((DECODE_BUFFER_SIZE+st->overlap)*C*sizeof(celt_sig));
1302    st->out_mem = st->decode_mem+DECODE_BUFFER_SIZE-MAX_PERIOD;
1303    
1304    st->oldBandE = (celt_word16*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16));
1305    
1306    st->lpc = (celt_word16*)celt_alloc(C*LPC_ORDER*sizeof(celt_word16));
1307
1308    st->loss_count = 0;
1309
1310    if ((st->decode_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL) &&
1311          (st->lpc!=NULL))
1312    {
1313       if (error)
1314          *error = CELT_OK;
1315       st->marker = DECODERVALID;
1316       return st;
1317    }
1318    /* If the setup fails for some reason deallocate it. */
1319    celt_decoder_destroy(st);
1320    if (error)
1321       *error = CELT_ALLOC_FAIL;
1322    return NULL;
1323 }
1324
1325 void celt_decoder_destroy(CELTDecoder *st)
1326 {
1327    if (st == NULL)
1328    {
1329       celt_warning("NULL passed to celt_decoder_destroy");
1330       return;
1331    }
1332
1333    if (st->marker == DECODERFREED) 
1334    {
1335       celt_warning("Freeing a decoder which has already been freed"); 
1336       return;
1337    }
1338    
1339    if (st->marker != DECODERVALID && st->marker != DECODERPARTIAL)
1340    {
1341       celt_warning("This is not a valid CELT decoder structure");
1342       return;
1343    }
1344    
1345    /*Check_mode is non-fatal here because we can still free
1346      the encoder memory even if the mode is bad, although calling
1347      the free functions in this order is a violation of the API.*/
1348    check_mode(st->mode);
1349    
1350    celt_free(st->decode_mem);
1351    celt_free(st->oldBandE);
1352    celt_free(st->lpc);
1353    
1354    st->marker = DECODERFREED;
1355    
1356    celt_free(st);
1357 }
1358
1359 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16 * restrict pcm, int N, int LM)
1360 {
1361    int c;
1362    int pitch_index;
1363    int overlap = st->mode->overlap;
1364    celt_word16 fade = Q15ONE;
1365    int i, len;
1366    const int C = CHANNELS(st->channels);
1367    int offset;
1368    SAVE_STACK;
1369    
1370    len = N+st->mode->overlap;
1371    
1372    if (st->loss_count == 0)
1373    {
1374       celt_word16 pitch_buf[MAX_PERIOD>>1];
1375       celt_word32 tmp=0;
1376       celt_word32 mem0[2]={0,0};
1377       celt_word16 mem1[2]={0,0};
1378       int len2 = len;
1379       /* FIXME: This is a kludge */
1380       if (len2>MAX_PERIOD>>1)
1381          len2 = MAX_PERIOD>>1;
1382       pitch_downsample(st->out_mem, pitch_buf, MAX_PERIOD, MAX_PERIOD,
1383                        C, mem0, mem1);
1384       pitch_search(st->mode, pitch_buf+((MAX_PERIOD-len2)>>1), pitch_buf, len2,
1385                    MAX_PERIOD-len2-100, &pitch_index, &tmp, 1<<LM);
1386       pitch_index = MAX_PERIOD-len2-pitch_index;
1387       st->last_pitch_index = pitch_index;
1388    } else {
1389       pitch_index = st->last_pitch_index;
1390       if (st->loss_count < 5)
1391          fade = QCONST16(.8f,15);
1392       else
1393          fade = 0;
1394    }
1395
1396    for (c=0;c<C;c++)
1397    {
1398       /* FIXME: This is more memory than necessary */
1399       celt_word32 e[2*MAX_PERIOD];
1400       celt_word16 exc[2*MAX_PERIOD];
1401       celt_word32 ac[LPC_ORDER+1];
1402       celt_word16 decay = 1;
1403       celt_word32 S1=0;
1404       celt_word16 mem[LPC_ORDER]={0};
1405
1406       offset = MAX_PERIOD-pitch_index;
1407       for (i=0;i<MAX_PERIOD;i++)
1408          exc[i] = ROUND16(st->out_mem[i*C+c], SIG_SHIFT);
1409
1410       if (st->loss_count == 0)
1411       {
1412          _celt_autocorr(exc, ac, st->mode->window, st->mode->overlap,
1413                         LPC_ORDER, MAX_PERIOD);
1414
1415          /* Noise floor -40 dB */
1416 #ifdef FIXED_POINT
1417          ac[0] += SHR32(ac[0],13);
1418 #else
1419          ac[0] *= 1.0001f;
1420 #endif
1421          /* Lag windowing */
1422          for (i=1;i<=LPC_ORDER;i++)
1423          {
1424             /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
1425 #ifdef FIXED_POINT
1426             ac[i] -= MULT16_32_Q15(2*i*i, ac[i]);
1427 #else
1428             ac[i] -= ac[i]*(.008f*i)*(.008f*i);
1429 #endif
1430          }
1431
1432          _celt_lpc(st->lpc+c*LPC_ORDER, ac, LPC_ORDER);
1433       }
1434       fir(exc, st->lpc+c*LPC_ORDER, exc, MAX_PERIOD, LPC_ORDER, mem);
1435       /*for (i=0;i<MAX_PERIOD;i++)printf("%d ", exc[i]); printf("\n");*/
1436       /* Check if the waveform is decaying (and if so how fast) */
1437       {
1438          celt_word32 E1=1, E2=1;
1439          int period;
1440          if (pitch_index <= MAX_PERIOD/2)
1441             period = pitch_index;
1442          else
1443             period = MAX_PERIOD/2;
1444          for (i=0;i<period;i++)
1445          {
1446             E1 += SHR32(MULT16_16(exc[MAX_PERIOD-period+i],exc[MAX_PERIOD-period+i]),8);
1447             E2 += SHR32(MULT16_16(exc[MAX_PERIOD-2*period+i],exc[MAX_PERIOD-2*period+i]),8);
1448          }
1449          if (E1 > E2)
1450             E1 = E2;
1451          decay = celt_sqrt(frac_div32(SHR(E1,1),E2));
1452       }
1453
1454       /* Copy excitation, taking decay into account */
1455       for (i=0;i<len+st->mode->overlap;i++)
1456       {
1457          if (offset+i >= MAX_PERIOD)
1458          {
1459             offset -= pitch_index;
1460             decay = MULT16_16_Q15(decay, decay);
1461          }
1462          e[i] = SHL32(EXTEND32(MULT16_16_Q15(decay, exc[offset+i])), SIG_SHIFT);
1463          S1 += SHR32(MULT16_16(st->out_mem[offset+i],st->out_mem[offset+i]),8);
1464       }
1465
1466       iir(e, st->lpc+c*LPC_ORDER, e, len+st->mode->overlap, LPC_ORDER, mem);
1467
1468       {
1469          celt_word32 S2=0;
1470          for (i=0;i<len+overlap;i++)
1471             S2 += SHR32(MULT16_16(e[i],e[i]),8);
1472          /* This checks for an "explosion" in the synthesis */
1473 #ifdef FIXED_POINT
1474          if (!(S1 > SHR32(S2,2)))
1475 #else
1476          /* Float test is written this way to catch NaNs at the same time */
1477          if (!(S1 > 0.2f*S2))
1478 #endif
1479          {
1480             for (i=0;i<len+overlap;i++)
1481                e[i] = 0;
1482          } else if (S1 < S2)
1483          {
1484             celt_word16 ratio = celt_sqrt(frac_div32(SHR32(S1,1)+1,S2+1));
1485             for (i=0;i<len+overlap;i++)
1486                e[i] = MULT16_16_Q15(ratio, e[i]);
1487          }
1488       }
1489
1490       for (i=0;i<MAX_PERIOD+st->mode->overlap-N;i++)
1491          st->out_mem[C*i+c] = st->out_mem[C*(N+i)+c];
1492
1493       /* Apply TDAC to the concealed audio so that it blends with the
1494          previous and next frames */
1495       for (i=0;i<overlap/2;i++)
1496       {
1497          celt_word32 tmp1, tmp2;
1498          tmp1 = MULT16_32_Q15(st->mode->window[i          ], e[i          ]) -
1499                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[overlap-i-1]);
1500          tmp2 = MULT16_32_Q15(st->mode->window[i],           e[N+overlap-1-i]) +
1501                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[N+i          ]);
1502          tmp1 = MULT16_32_Q15(fade, tmp1);
1503          tmp2 = MULT16_32_Q15(fade, tmp2);
1504          st->out_mem[C*(MAX_PERIOD+i)+c] = MULT16_32_Q15(st->mode->window[overlap-i-1], tmp2);
1505          st->out_mem[C*(MAX_PERIOD+overlap-i-1)+c] = MULT16_32_Q15(st->mode->window[i], tmp2);
1506          st->out_mem[C*(MAX_PERIOD-N+i)+c] += MULT16_32_Q15(st->mode->window[i], tmp1);
1507          st->out_mem[C*(MAX_PERIOD-N+overlap-i-1)+c] -= MULT16_32_Q15(st->mode->window[overlap-i-1], tmp1);
1508       }
1509       for (i=0;i<N-overlap;i++)
1510          st->out_mem[C*(MAX_PERIOD-N+overlap+i)+c] = MULT16_32_Q15(fade, e[overlap+i]);
1511    }
1512
1513    deemphasis(st->out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1514    
1515    st->loss_count++;
1516
1517    RESTORE_STACK;
1518 }
1519
1520 #ifdef FIXED_POINT
1521 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1522 {
1523 #else
1524 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig * restrict pcm, int frame_size, ec_dec *dec)
1525 {
1526 #endif
1527    int c, i, N, N4;
1528    int has_fold;
1529    int bits;
1530    ec_dec _dec;
1531    ec_byte_buffer buf;
1532    VARDECL(celt_sig, freq);
1533    VARDECL(celt_norm, X);
1534    VARDECL(celt_ener, bandE);
1535    VARDECL(int, fine_quant);
1536    VARDECL(int, pulses);
1537    VARDECL(int, offsets);
1538    VARDECL(int, fine_priority);
1539    VARDECL(int, tf_res);
1540
1541    int shortBlocks;
1542    int isTransient;
1543    int intra_ener;
1544    int transient_time;
1545    int transient_shift;
1546    int mdct_weight_shift=0;
1547    const int C = CHANNELS(st->channels);
1548    int mdct_weight_pos=0;
1549    int LM, M;
1550    int nbFilledBytes, nbAvailableBytes;
1551    int effEnd;
1552    SAVE_STACK;
1553
1554    if (check_decoder(st) != CELT_OK)
1555       return CELT_INVALID_STATE;
1556
1557    if (check_mode(st->mode) != CELT_OK)
1558       return CELT_INVALID_MODE;
1559
1560    if (pcm==NULL)
1561       return CELT_BAD_ARG;
1562
1563    for (LM=0;LM<4;LM++)
1564       if (st->mode->shortMdctSize<<LM==frame_size)
1565          break;
1566    if (LM>=MAX_CONFIG_SIZES)
1567       return CELT_BAD_ARG;
1568    M=1<<LM;
1569
1570    N = M*st->mode->shortMdctSize;
1571    N4 = (N-st->overlap)>>1;
1572
1573    effEnd = st->end;
1574    if (effEnd > st->mode->effEBands)
1575       effEnd = st->mode->effEBands;
1576
1577    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
1578    ALLOC(X, C*N, celt_norm);   /**< Interleaved normalised MDCTs */
1579    ALLOC(bandE, st->mode->nbEBands*C, celt_ener);
1580    for (c=0;c<C;c++)
1581       for (i=0;i<M*st->mode->eBands[st->start];i++)
1582          X[c*N+i] = 0;
1583    for (c=0;c<C;c++)
1584       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1585          X[c*N+i] = 0;
1586
1587    if (data == NULL)
1588    {
1589       celt_decode_lost(st, pcm, N, LM);
1590       RESTORE_STACK;
1591       return CELT_OK;
1592    }
1593    if (len<0) {
1594      RESTORE_STACK;
1595      return CELT_BAD_ARG;
1596    }
1597    
1598    if (dec == NULL)
1599    {
1600       ec_byte_readinit(&buf,(unsigned char*)data,len);
1601       ec_dec_init(&_dec,&buf);
1602       dec = &_dec;
1603       nbFilledBytes = 0;
1604    } else {
1605       nbFilledBytes = (ec_dec_tell(dec, 0)+4)>>3;
1606    }
1607    nbAvailableBytes = len-nbFilledBytes;
1608
1609    /* Decode the global flags (first symbols in the stream) */
1610    intra_ener = ec_dec_bit_prob(dec, 8192);
1611    isTransient = ec_dec_bit_prob(dec, 8192);
1612    has_fold = ec_dec_bit_prob(dec, 8192)<<1;
1613    has_fold |= ec_dec_bit_prob(dec, (has_fold>>1) ? 32768 : 49152);
1614
1615    if (isTransient)
1616       shortBlocks = M;
1617    else
1618       shortBlocks = 0;
1619
1620    if (isTransient)
1621    {
1622       transient_shift = ec_dec_uint(dec, 4);
1623       if (transient_shift == 3)
1624       {
1625          int transient_time_quant;
1626          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
1627          transient_time_quant = ec_dec_uint(dec, max_time);
1628          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
1629       } else {
1630          mdct_weight_shift = transient_shift;
1631          if (mdct_weight_shift && M>2)
1632             mdct_weight_pos = ec_dec_uint(dec, M-1);
1633          transient_shift = 0;
1634          transient_time = 0;
1635       }
1636    } else {
1637       transient_time = -1;
1638       transient_shift = 0;
1639    }
1640    
1641    ALLOC(fine_quant, st->mode->nbEBands, int);
1642    /* Get band energies */
1643    unquant_coarse_energy(st->mode, st->start, st->end, bandE, st->oldBandE, intra_ener, st->mode->prob, dec, C, LM);
1644
1645    ALLOC(tf_res, st->mode->nbEBands, int);
1646    tf_decode(st->start, st->end, C, isTransient, tf_res, nbAvailableBytes, LM, dec);
1647
1648    ALLOC(pulses, st->mode->nbEBands, int);
1649    ALLOC(offsets, st->mode->nbEBands, int);
1650    ALLOC(fine_priority, st->mode->nbEBands, int);
1651
1652    for (i=0;i<st->mode->nbEBands;i++)
1653       offsets[i] = 0;
1654
1655    bits = len*8 - ec_dec_tell(dec, 0) - 1;
1656    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
1657    /*bits = ec_dec_tell(dec, 0);
1658    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(dec, 0)-bits))/C);*/
1659    
1660    unquant_fine_energy(st->mode, st->start, st->end, bandE, st->oldBandE, fine_quant, dec, C);
1661
1662    /* Decode fixed codebook */
1663    quant_all_bands(0, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, NULL, pulses, shortBlocks, has_fold, tf_res, 1, len*8, dec, LM);
1664
1665    unquant_energy_finalise(st->mode, st->start, st->end, bandE, st->oldBandE, fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
1666
1667    log2Amp(st->mode, st->start, st->end, bandE, st->oldBandE, C);
1668
1669    if (mdct_weight_shift)
1670    {
1671       mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
1672    }
1673
1674    /* Synthesis */
1675    denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
1676
1677
1678    CELT_MOVE(st->decode_mem, st->decode_mem+C*N, C*(DECODE_BUFFER_SIZE+st->overlap-N));
1679
1680    for (c=0;c<C;c++)
1681       for (i=0;i<M*st->mode->eBands[st->start];i++)
1682          freq[c*N+i] = 0;
1683    for (c=0;c<C;c++)
1684       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1685          freq[c*N+i] = 0;
1686
1687    /* Compute inverse MDCTs */
1688    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem, C, LM);
1689
1690    deemphasis(st->out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1691    st->loss_count = 0;
1692    RESTORE_STACK;
1693    if (ec_dec_get_error(dec))
1694       return CELT_CORRUPTED_DATA;
1695    else
1696       return CELT_OK;
1697 }
1698
1699 #ifdef FIXED_POINT
1700 #ifndef DISABLE_FLOAT_API
1701 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size, ec_dec *dec)
1702 {
1703    int j, ret, C, N, LM, M;
1704    VARDECL(celt_int16, out);
1705    SAVE_STACK;
1706
1707    if (check_decoder(st) != CELT_OK)
1708       return CELT_INVALID_STATE;
1709
1710    if (check_mode(st->mode) != CELT_OK)
1711       return CELT_INVALID_MODE;
1712
1713    if (pcm==NULL)
1714       return CELT_BAD_ARG;
1715
1716    for (LM=0;LM<4;LM++)
1717       if (st->mode->shortMdctSize<<LM==frame_size)
1718          break;
1719    if (LM>=MAX_CONFIG_SIZES)
1720       return CELT_BAD_ARG;
1721    M=1<<LM;
1722
1723    C = CHANNELS(st->channels);
1724    N = M*st->mode->shortMdctSize;
1725    
1726    ALLOC(out, C*N, celt_int16);
1727    ret=celt_decode_with_ec(st, data, len, out, frame_size, dec);
1728    if (ret==0)
1729       for (j=0;j<C*N;j++)
1730          pcm[j]=out[j]*(1.f/32768.f);
1731      
1732    RESTORE_STACK;
1733    return ret;
1734 }
1735 #endif /*DISABLE_FLOAT_API*/
1736 #else
1737 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1738 {
1739    int j, ret, C, N, LM, M;
1740    VARDECL(celt_sig, out);
1741    SAVE_STACK;
1742
1743    if (check_decoder(st) != CELT_OK)
1744       return CELT_INVALID_STATE;
1745
1746    if (check_mode(st->mode) != CELT_OK)
1747       return CELT_INVALID_MODE;
1748
1749    if (pcm==NULL)
1750       return CELT_BAD_ARG;
1751
1752    for (LM=0;LM<4;LM++)
1753       if (st->mode->shortMdctSize<<LM==frame_size)
1754          break;
1755    if (LM>=MAX_CONFIG_SIZES)
1756       return CELT_BAD_ARG;
1757    M=1<<LM;
1758
1759    C = CHANNELS(st->channels);
1760    N = M*st->mode->shortMdctSize;
1761    ALLOC(out, C*N, celt_sig);
1762
1763    ret=celt_decode_with_ec_float(st, data, len, out, frame_size, dec);
1764
1765    if (ret==0)
1766       for (j=0;j<C*N;j++)
1767          pcm[j] = FLOAT2INT16 (out[j]);
1768    
1769    RESTORE_STACK;
1770    return ret;
1771 }
1772 #endif
1773
1774 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size)
1775 {
1776    return celt_decode_with_ec(st, data, len, pcm, frame_size, NULL);
1777 }
1778
1779 #ifndef DISABLE_FLOAT_API
1780 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size)
1781 {
1782    return celt_decode_with_ec_float(st, data, len, pcm, frame_size, NULL);
1783 }
1784 #endif /* DISABLE_FLOAT_API */
1785
1786 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
1787 {
1788    va_list ap;
1789
1790    if (check_decoder(st) != CELT_OK)
1791       return CELT_INVALID_STATE;
1792
1793    va_start(ap, request);
1794    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
1795      goto bad_mode;
1796    switch (request)
1797    {
1798       case CELT_GET_MODE_REQUEST:
1799       {
1800          const CELTMode ** value = va_arg(ap, const CELTMode**);
1801          if (value==0)
1802             goto bad_arg;
1803          *value=st->mode;
1804       }
1805       break;
1806       case CELT_SET_START_BAND_REQUEST:
1807       {
1808          celt_int32 value = va_arg(ap, celt_int32);
1809          if (value<0 || value>=st->mode->nbEBands)
1810             goto bad_arg;
1811          st->start = value;
1812       }
1813       break;
1814       case CELT_SET_END_BAND_REQUEST:
1815       {
1816          celt_int32 value = va_arg(ap, celt_int32);
1817          if (value<0 || value>=st->mode->nbEBands)
1818             goto bad_arg;
1819          st->end = value;
1820       }
1821       break;
1822       case CELT_RESET_STATE:
1823       {
1824          const CELTMode *mode = st->mode;
1825          int C = st->channels;
1826
1827          CELT_MEMSET(st->decode_mem, 0, (DECODE_BUFFER_SIZE+st->overlap)*C);
1828          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
1829
1830          CELT_MEMSET(st->preemph_memD, 0, C);
1831
1832          st->loss_count = 0;
1833
1834          CELT_MEMSET(st->lpc, 0, C*LPC_ORDER);
1835       }
1836       break;
1837       default:
1838          goto bad_request;
1839    }
1840    va_end(ap);
1841    return CELT_OK;
1842 bad_mode:
1843   va_end(ap);
1844   return CELT_INVALID_MODE;
1845 bad_arg:
1846    va_end(ap);
1847    return CELT_BAD_ARG;
1848 bad_request:
1849       va_end(ap);
1850   return CELT_UNIMPLEMENTED;
1851 }
1852
1853 const char *celt_strerror(int error)
1854 {
1855    static const char *error_strings[8] = {
1856       "success",
1857       "invalid argument",
1858       "invalid mode",
1859       "internal error",
1860       "corrupted stream",
1861       "request not implemented",
1862       "invalid state",
1863       "memory allocation failed"
1864    };
1865    if (error > 0 || error < -7)
1866       return "unknown error";
1867    else 
1868       return error_strings[-error];
1869 }
1870