Moving intra decision to quant_coarse_energy()
[opus.git] / libcelt / celt.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #define CELT_C
39
40 #include "os_support.h"
41 #include "mdct.h"
42 #include <math.h>
43 #include "celt.h"
44 #include "pitch.h"
45 #include "bands.h"
46 #include "modes.h"
47 #include "entcode.h"
48 #include "quant_bands.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54 #include "plc.h"
55
56 #ifdef FIXED_POINT
57 static const celt_word16 transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135f, 0.0337639f, 0.0748914f, 0.1304955f,
63    0.1986827f, 0.2771308f, 0.3631685f, 0.4538658f,
64    0.5461342f, 0.6368315f, 0.7228692f, 0.8013173f,
65    0.8695045f, 0.9251086f, 0.9662361f, 0.9914865f};
66 #endif
67
68 #define ENCODERVALID   0x4c434554
69 #define ENCODERPARTIAL 0x5445434c
70 #define ENCODERFREED   0x4c004500
71    
72 /** Encoder state 
73  @brief Encoder state
74  */
75 struct CELTEncoder {
76    celt_uint32 marker;
77    const CELTMode *mode;     /**< Mode used by the encoder */
78    int overlap;
79    int channels;
80    
81    int force_intra;
82    int delayedIntra;
83    celt_word16 tonal_average;
84    int fold_decision;
85    celt_word16 gain_prod;
86    celt_word32 frame_max;
87    int start, end;
88
89    /* VBR-related parameters */
90    celt_int32 vbr_reservoir;
91    celt_int32 vbr_drift;
92    celt_int32 vbr_offset;
93    celt_int32 vbr_count;
94
95    celt_int32 vbr_rate_norm; /* Target number of 16th bits per frame */
96    celt_word32 preemph_memE[2];
97    celt_word32 preemph_memD[2];
98
99    celt_sig *in_mem;
100    celt_sig *out_mem;
101
102    celt_word16 *oldBandE;
103 };
104
105 static int check_encoder(const CELTEncoder *st) 
106 {
107    if (st==NULL)
108    {
109       celt_warning("NULL passed as an encoder structure");  
110       return CELT_INVALID_STATE;
111    }
112    if (st->marker == ENCODERVALID)
113       return CELT_OK;
114    if (st->marker == ENCODERFREED)
115       celt_warning("Referencing an encoder that has already been freed");
116    else
117       celt_warning("This is not a valid CELT encoder structure");
118    return CELT_INVALID_STATE;
119 }
120
121 CELTEncoder *celt_encoder_create(const CELTMode *mode, int channels, int *error)
122 {
123    int C;
124    CELTEncoder *st;
125
126    if (check_mode(mode) != CELT_OK)
127    {
128       if (error)
129          *error = CELT_INVALID_MODE;
130       return NULL;
131    }
132
133    if (channels < 0 || channels > 2)
134    {
135       celt_warning("Only mono and stereo supported");
136       if (error)
137          *error = CELT_BAD_ARG;
138       return NULL;
139    }
140
141    C = channels;
142    st = celt_alloc(sizeof(CELTEncoder));
143    
144    if (st==NULL)
145    {
146       if (error)
147          *error = CELT_ALLOC_FAIL;
148       return NULL;
149    }
150    st->marker = ENCODERPARTIAL;
151    st->mode = mode;
152    st->overlap = mode->overlap;
153    st->channels = channels;
154
155    st->start = 0;
156    st->end = st->mode->effEBands;
157
158    st->vbr_rate_norm = 0;
159    st->force_intra  = 0;
160    st->delayedIntra = 1;
161    st->tonal_average = QCONST16(1.f,8);
162    st->fold_decision = 1;
163
164    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig));
165    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig));
166
167    st->oldBandE = (celt_word16*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16));
168
169    if ((st->in_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL))
170    {
171       if (error)
172          *error = CELT_OK;
173       st->marker   = ENCODERVALID;
174       return st;
175    }
176    /* If the setup fails for some reason deallocate it. */
177    celt_encoder_destroy(st);  
178    if (error)
179       *error = CELT_ALLOC_FAIL;
180    return NULL;
181 }
182
183 void celt_encoder_destroy(CELTEncoder *st)
184 {
185    if (st == NULL)
186    {
187       celt_warning("NULL passed to celt_encoder_destroy");
188       return;
189    }
190
191    if (st->marker == ENCODERFREED)
192    {
193       celt_warning("Freeing an encoder which has already been freed"); 
194       return;
195    }
196
197    if (st->marker != ENCODERVALID && st->marker != ENCODERPARTIAL)
198    {
199       celt_warning("This is not a valid CELT encoder structure");
200       return;
201    }
202    /*Check_mode is non-fatal here because we can still free
203     the encoder memory even if the mode is bad, although calling
204     the free functions in this order is a violation of the API.*/
205    check_mode(st->mode);
206    
207    celt_free(st->in_mem);
208    celt_free(st->out_mem);
209    celt_free(st->oldBandE);
210    
211    st->marker = ENCODERFREED;
212    
213    celt_free(st);
214 }
215
216 static inline celt_int16 FLOAT2INT16(float x)
217 {
218    x = x*CELT_SIG_SCALE;
219    x = MAX32(x, -32768);
220    x = MIN32(x, 32767);
221    return (celt_int16)float2int(x);
222 }
223
224 static inline celt_word16 SIG2WORD16(celt_sig x)
225 {
226 #ifdef FIXED_POINT
227    x = PSHR32(x, SIG_SHIFT);
228    x = MAX32(x, -32768);
229    x = MIN32(x, 32767);
230    return EXTRACT16(x);
231 #else
232    return (celt_word16)x;
233 #endif
234 }
235
236 static int transient_analysis(const celt_word32 * restrict in, int len, int C,
237                               int *transient_time, int *transient_shift,
238                               celt_word32 *frame_max, int overlap)
239 {
240    int i, n;
241    celt_word32 ratio;
242    celt_word32 threshold;
243    VARDECL(celt_word32, begin);
244    SAVE_STACK;
245    ALLOC(begin, len+1, celt_word32);
246    begin[0] = 0;
247    if (C==1)
248    {
249       for (i=0;i<len;i++)
250          begin[i+1] = MAX32(begin[i], ABS32(in[i]));
251    } else {
252       for (i=0;i<len;i++)
253          begin[i+1] = MAX32(begin[i], MAX32(ABS32(in[C*i]),
254                                             ABS32(in[C*i+1])));
255    }
256    n = -1;
257
258    threshold = MULT16_32_Q15(QCONST16(.4f,15),begin[len]);
259    /* If the following condition isn't met, there's just no way
260       we'll have a transient*/
261    if (*frame_max < threshold)
262    {
263       /* It's likely we have a transient, now find it */
264       for (i=8;i<len-8;i++)
265       {
266          if (begin[i+1] < threshold)
267             n=i;
268       }
269    }
270    if (n<32)
271    {
272       n = -1;
273       ratio = 0;
274    } else {
275       ratio = DIV32(begin[len],1+MAX32(*frame_max, begin[n-16]));
276    }
277
278    if (ratio > 45)
279       *transient_shift = 3;
280    else
281       *transient_shift = 0;
282    
283    *transient_time = n;
284    *frame_max = begin[len-overlap];
285
286    RESTORE_STACK;
287    return ratio > 0;
288 }
289
290 /** Apply window and compute the MDCT for all sub-frames and 
291     all channels in a frame */
292 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * restrict in, celt_sig * restrict out, int _C, int LM)
293 {
294    const int C = CHANNELS(_C);
295    if (C==1 && !shortBlocks)
296    {
297       const int overlap = OVERLAP(mode);
298       clt_mdct_forward(&mode->mdct, in, out, mode->window, overlap, mode->maxLM-LM);
299    } else {
300       const int overlap = OVERLAP(mode);
301       int N = mode->shortMdctSize<<LM;
302       int B = 1;
303       int b, c;
304       VARDECL(celt_word32, x);
305       VARDECL(celt_word32, tmp);
306       SAVE_STACK;
307       if (shortBlocks)
308       {
309          /*lookup = &mode->mdct[0];*/
310          N = mode->shortMdctSize;
311          B = shortBlocks;
312       }
313       ALLOC(x, N+overlap, celt_word32);
314       ALLOC(tmp, N, celt_word32);
315       for (c=0;c<C;c++)
316       {
317          for (b=0;b<B;b++)
318          {
319             int j;
320             for (j=0;j<N+overlap;j++)
321                x[j] = in[C*(b*N+j)+c];
322             clt_mdct_forward(&mode->mdct, x, tmp, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
323             /* Interleaving the sub-frames */
324             for (j=0;j<N;j++)
325                out[(j*B+b)+c*N*B] = tmp[j];
326          }
327       }
328       RESTORE_STACK;
329    }
330 }
331
332 /** Compute the IMDCT and apply window for all sub-frames and 
333     all channels in a frame */
334 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X, int transient_time, int transient_shift, celt_sig * restrict out_mem, int _C, int LM)
335 {
336    int c, N4;
337    const int C = CHANNELS(_C);
338    const int N = mode->shortMdctSize<<LM;
339    const int overlap = OVERLAP(mode);
340    N4 = (N-overlap)>>1;
341    for (c=0;c<C;c++)
342    {
343       int j;
344       if (transient_shift==0 && C==1 && !shortBlocks) {
345          clt_mdct_backward(&mode->mdct, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap, mode->maxLM-LM);
346       } else {
347          VARDECL(celt_word32, x);
348          VARDECL(celt_word32, tmp);
349          int b;
350          int N2 = N;
351          int B = 1;
352          int n4offset=0;
353          SAVE_STACK;
354          
355          ALLOC(x, 2*N, celt_word32);
356          ALLOC(tmp, N, celt_word32);
357
358          if (shortBlocks)
359          {
360             /*lookup = &mode->mdct[0];*/
361             N2 = mode->shortMdctSize;
362             B = shortBlocks;
363             n4offset = N4;
364          }
365          /* Prevents problems from the imdct doing the overlap-add */
366          CELT_MEMSET(x+N4, 0, N2);
367
368          for (b=0;b<B;b++)
369          {
370             /* De-interleaving the sub-frames */
371             for (j=0;j<N2;j++)
372                tmp[j] = X[(j*B+b)+c*N2*B];
373             clt_mdct_backward(&mode->mdct, tmp, x+n4offset+N2*b, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
374          }
375
376          if (transient_shift > 0)
377          {
378 #ifdef FIXED_POINT
379             for (j=0;j<16;j++)
380                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
381             for (j=transient_time;j<N+overlap;j++)
382                x[N4+j] = SHL32(x[N4+j], transient_shift);
383 #else
384             for (j=0;j<16;j++)
385                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
386             for (j=transient_time;j<N+overlap;j++)
387                x[N4+j] *= 1<<transient_shift;
388 #endif
389          }
390          /* The first and last part would need to be set to zero 
391             if we actually wanted to use them. */
392          for (j=0;j<overlap;j++)
393             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
394          for (j=0;j<overlap;j++)
395             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
396          for (j=0;j<2*N4;j++)
397             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
398          RESTORE_STACK;
399       }
400    }
401 }
402
403 static void deemphasis(celt_sig *in, celt_word16 *pcm, int N, int _C, const celt_word16 *coef, celt_sig *mem)
404 {
405    const int C = CHANNELS(_C);
406    int c;
407    for (c=0;c<C;c++)
408    {
409       int j;
410       celt_sig * restrict x;
411       celt_word16  * restrict y;
412       celt_sig m = mem[c];
413       x = &in[C*(MAX_PERIOD-N)+c];
414       y = pcm+c;
415       for (j=0;j<N;j++)
416       {
417          celt_sig tmp = *x + m;
418          m = MULT16_32_Q15(coef[0], tmp)
419            - MULT16_32_Q15(coef[1], *x);
420          tmp = SHL32(MULT16_32_Q15(coef[3], tmp), 2);
421          *y = SCALEOUT(SIG2WORD16(tmp));
422          x+=C;
423          y+=C;
424       }
425       mem[c] = m;
426    }
427 }
428
429 static void mdct_shape(const CELTMode *mode, celt_norm *X, int start,
430                        int end, int N,
431                        int mdct_weight_shift, int end_band, int _C, int renorm, int M)
432 {
433    int m, i, c;
434    const int C = CHANNELS(_C);
435    for (c=0;c<C;c++)
436       for (m=start;m<end;m++)
437          for (i=m+c*N;i<(c+1)*N;i+=M)
438 #ifdef FIXED_POINT
439             X[i] = SHR16(X[i], mdct_weight_shift);
440 #else
441             X[i] = (1.f/(1<<mdct_weight_shift))*X[i];
442 #endif
443    if (renorm)
444       renormalise_bands(mode, X, end_band, C, M);
445 }
446
447 static signed char tf_select_table[4][8] = {
448       {0, -1, 0, -1,    0,-1, 0,-1},
449       {0, -1, 0, -2,    1, 0, 1 -1},
450       {0, -2, 0, -3,    2, 0, 1 -1},
451       {0, -2, 0, -3,    2, 0, 1 -1},
452 };
453
454 static int tf_analysis(celt_word16 *bandLogE, celt_word16 *oldBandE, int len, int C, int isTransient, int *tf_res, int nbCompressedBytes)
455 {
456    int i;
457    celt_word16 threshold;
458    VARDECL(celt_word16, metric);
459    celt_word32 average=0;
460    celt_word32 cost0;
461    celt_word32 cost1;
462    VARDECL(int, path0);
463    VARDECL(int, path1);
464    celt_word16 lambda;
465    int tf_select=0;
466    SAVE_STACK;
467
468    /* FIXME: Should check number of bytes *left* */
469    if (nbCompressedBytes<15*C)
470    {
471       for (i=0;i<len;i++)
472          tf_res[i] = 0;
473       return 0;
474    }
475    if (nbCompressedBytes<40)
476       lambda = QCONST16(5.f, DB_SHIFT);
477    else if (nbCompressedBytes<60)
478       lambda = QCONST16(2.f, DB_SHIFT);
479    else if (nbCompressedBytes<100)
480       lambda = QCONST16(1.f, DB_SHIFT);
481    else
482       lambda = QCONST16(.5f, DB_SHIFT);
483
484    ALLOC(metric, len, celt_word16);
485    ALLOC(path0, len, int);
486    ALLOC(path1, len, int);
487    for (i=0;i<len;i++)
488    {
489       metric[i] = SUB16(bandLogE[i], oldBandE[i]);
490       average += metric[i];
491    }
492    if (C==2)
493    {
494       average = 0;
495       for (i=0;i<len;i++)
496       {
497          metric[i] = HALF32(metric[i]) + HALF32(SUB16(bandLogE[i+len], oldBandE[i+len]));
498          average += metric[i];
499       }
500    }
501    average = DIV32(average, len);
502    /*if (!isTransient)
503       printf ("%f\n", average);*/
504    if (isTransient)
505    {
506       threshold = QCONST16(1.f,DB_SHIFT);
507       tf_select = average > QCONST16(3.f,DB_SHIFT);
508    } else {
509       threshold = QCONST16(.5f,DB_SHIFT);
510       tf_select = average > QCONST16(1.f,DB_SHIFT);
511    }
512    cost0 = 0;
513    cost1 = lambda;
514    /* Viterbi forward pass */
515    for (i=1;i<len;i++)
516    {
517       celt_word32 curr0, curr1;
518       celt_word32 from0, from1;
519
520       from0 = cost0;
521       from1 = cost1 + lambda;
522       if (from0 < from1)
523       {
524          curr0 = from0;
525          path0[i]= 0;
526       } else {
527          curr0 = from1;
528          path0[i]= 1;
529       }
530
531       from0 = cost0 + lambda;
532       from1 = cost1;
533       if (from0 < from1)
534       {
535          curr1 = from0;
536          path1[i]= 0;
537       } else {
538          curr1 = from1;
539          path1[i]= 1;
540       }
541       cost0 = curr0 + (metric[i]-threshold);
542       cost1 = curr1;
543    }
544    tf_res[len-1] = cost0 < cost1 ? 0 : 1;
545    /* Viterbi backward pass to check the decisions */
546    for (i=len-2;i>=0;i--)
547    {
548       if (tf_res[i+1] == 1)
549          tf_res[i] = path1[i+1];
550       else
551          tf_res[i] = path0[i+1];
552    }
553    RESTORE_STACK;
554    return tf_select;
555 }
556
557 static void tf_encode(int start, int end, int isTransient, int *tf_res, int nbCompressedBytes, int LM, int tf_select, ec_enc *enc)
558 {
559    int curr, i;
560    if (8*nbCompressedBytes - ec_enc_tell(enc, 0) < 100)
561    {
562       for (i=start;i<end;i++)
563          tf_res[i] = isTransient;
564    } else {
565       ec_enc_bit_prob(enc, tf_res[start], isTransient ? 16384 : 4096);
566       curr = tf_res[start];
567       for (i=start+1;i<end;i++)
568       {
569          ec_enc_bit_prob(enc, tf_res[i] ^ curr, isTransient ? 4096 : 2048);
570          curr = tf_res[i];
571       }
572    }
573    ec_enc_bits(enc, tf_select, 1);
574    for (i=start;i<end;i++)
575       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
576 }
577
578 static void tf_decode(int start, int end, int C, int isTransient, int *tf_res, int nbCompressedBytes, int LM, ec_dec *dec)
579 {
580    int i, curr, tf_select;
581    if (8*nbCompressedBytes - ec_dec_tell(dec, 0) < 100)
582    {
583       for (i=start;i<end;i++)
584          tf_res[i] = isTransient;
585    } else {
586       tf_res[start] = ec_dec_bit_prob(dec, isTransient ? 16384 : 4096);
587       curr = tf_res[start];
588       for (i=start+1;i<end;i++)
589       {
590          tf_res[i] = ec_dec_bit_prob(dec, isTransient ? 4096 : 2048) ^ curr;
591          curr = tf_res[i];
592       }
593    }
594    tf_select = ec_dec_bits(dec, 1);
595    for (i=start;i<end;i++)
596       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
597 }
598
599 #ifdef FIXED_POINT
600 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
601 {
602 #else
603 int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
604 {
605 #endif
606    int i, c, N, NN, N4;
607    int bits;
608    int has_fold=1;
609    ec_byte_buffer buf;
610    ec_enc         _enc;
611    VARDECL(celt_sig, in);
612    VARDECL(celt_sig, freq);
613    VARDECL(celt_norm, X);
614    VARDECL(celt_ener, bandE);
615    VARDECL(celt_word16, bandLogE);
616    VARDECL(int, fine_quant);
617    VARDECL(celt_word16, error);
618    VARDECL(int, pulses);
619    VARDECL(int, offsets);
620    VARDECL(int, fine_priority);
621    VARDECL(int, tf_res);
622    int shortBlocks=0;
623    int isTransient=0;
624    int transient_time, transient_time_quant;
625    int transient_shift;
626    int resynth;
627    const int C = CHANNELS(st->channels);
628    int mdct_weight_shift = 0;
629    int mdct_weight_pos=0;
630    int LM, M;
631    int tf_select;
632    int nbFilledBytes, nbAvailableBytes;
633    int effEnd;
634    SAVE_STACK;
635
636    if (check_encoder(st) != CELT_OK)
637       return CELT_INVALID_STATE;
638
639    if (check_mode(st->mode) != CELT_OK)
640       return CELT_INVALID_MODE;
641
642    if (nbCompressedBytes<0 || pcm==NULL)
643      return CELT_BAD_ARG;
644
645    for (LM=0;LM<4;LM++)
646       if (st->mode->shortMdctSize<<LM==frame_size)
647          break;
648    if (LM>=MAX_CONFIG_SIZES)
649       return CELT_BAD_ARG;
650    M=1<<LM;
651
652    if (enc==NULL)
653    {
654       ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
655       ec_enc_init(&_enc,&buf);
656       enc = &_enc;
657       nbFilledBytes=0;
658    } else {
659       nbFilledBytes=(ec_enc_tell(enc, 0)+4)>>3;
660    }
661    nbAvailableBytes = nbCompressedBytes - nbFilledBytes;
662
663    effEnd = st->end;
664    if (effEnd > st->mode->effEBands)
665       effEnd = st->mode->effEBands;
666
667    N = M*st->mode->shortMdctSize;
668    N4 = (N-st->overlap)>>1;
669    ALLOC(in, 2*C*N-2*C*N4, celt_sig);
670
671    CELT_COPY(in, st->in_mem, C*st->overlap);
672    for (c=0;c<C;c++)
673    {
674       const celt_word16 * restrict pcmp = pcm+c;
675       celt_sig * restrict inp = in+C*st->overlap+c;
676       for (i=0;i<N;i++)
677       {
678          /* Apply pre-emphasis */
679          celt_sig tmp = MULT16_16(st->mode->preemph[2], SCALEIN(*pcmp));
680          *inp = tmp + st->preemph_memE[c];
681          st->preemph_memE[c] = MULT16_32_Q15(st->mode->preemph[1], *inp)
682                              - MULT16_32_Q15(st->mode->preemph[0], tmp);
683          inp += C;
684          pcmp += C;
685       }
686    }
687    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
688
689    /* Transient handling */
690    transient_time = -1;
691    transient_time_quant = -1;
692    transient_shift = 0;
693    isTransient = 0;
694
695    resynth = optional_resynthesis!=NULL;
696
697    if (M > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift, &st->frame_max, st->overlap))
698    {
699 #ifndef FIXED_POINT
700       float gain_1;
701 #endif
702       /* Apply the inverse shaping window */
703       if (transient_shift)
704       {
705          transient_time_quant = transient_time*(celt_int32)8000/st->mode->Fs;
706          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
707 #ifdef FIXED_POINT
708          for (c=0;c<C;c++)
709             for (i=0;i<16;i++)
710                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
711          for (c=0;c<C;c++)
712             for (i=transient_time;i<N+st->overlap;i++)
713                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
714 #else
715          for (c=0;c<C;c++)
716             for (i=0;i<16;i++)
717                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
718          gain_1 = 1.f/(1<<transient_shift);
719          for (c=0;c<C;c++)
720             for (i=transient_time;i<N+st->overlap;i++)
721                in[C*i+c] *= gain_1;
722 #endif
723       }
724       isTransient = 1;
725       has_fold = 1;
726    }
727
728    if (isTransient)
729       shortBlocks = M;
730    else
731       shortBlocks = 0;
732
733    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
734    ALLOC(bandE,st->mode->nbEBands*C, celt_ener);
735    ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16);
736    /* Compute MDCTs */
737    compute_mdcts(st->mode, shortBlocks, in, freq, C, LM);
738
739    ALLOC(X, C*N, celt_norm);         /**< Interleaved normalised MDCTs */
740
741    compute_band_energies(st->mode, freq, bandE, effEnd, C, M);
742
743    amp2Log2(st->mode, effEnd, st->end, bandE, bandLogE, C);
744
745    /* Band normalisation */
746    normalise_bands(st->mode, freq, X, bandE, effEnd, C, M);
747
748    NN = M*st->mode->eBands[effEnd];
749    if (shortBlocks && !transient_shift)
750    {
751       celt_word32 sum[8]={1,1,1,1,1,1,1,1};
752       int m;
753       for (c=0;c<C;c++)
754       {
755          m=0;
756          do {
757             celt_word32 tmp=0;
758             for (i=m+c*N;i<c*N+NN;i+=M)
759                tmp += ABS32(X[i]);
760             sum[m++] += tmp;
761          } while (m<M);
762       }
763       m=0;
764 #ifdef FIXED_POINT
765       do {
766          if (SHR32(sum[m+1],3) > sum[m])
767          {
768             mdct_weight_shift=2;
769             mdct_weight_pos = m;
770          } else if (SHR32(sum[m+1],1) > sum[m] && mdct_weight_shift < 2)
771          {
772             mdct_weight_shift=1;
773             mdct_weight_pos = m;
774          }
775          m++;
776       } while (m<M-1);
777 #else
778       do {
779          if (sum[m+1] > 8*sum[m])
780          {
781             mdct_weight_shift=2;
782             mdct_weight_pos = m;
783          } else if (sum[m+1] > 2*sum[m] && mdct_weight_shift < 2)
784          {
785             mdct_weight_shift=1;
786             mdct_weight_pos = m;
787          }
788          m++;
789       } while (m<M-1);
790 #endif
791       if (mdct_weight_shift)
792          mdct_shape(st->mode, X, mdct_weight_pos+1, M, N, mdct_weight_shift, effEnd, C, 0, M);
793    }
794
795    ALLOC(tf_res, st->mode->nbEBands, int);
796    /* Needs to be before coarse energy quantization because otherwise the energy gets modified */
797    tf_select = tf_analysis(bandLogE, st->oldBandE, effEnd, C, isTransient, tf_res, nbAvailableBytes);
798    for (i=effEnd;i<st->end;i++)
799       tf_res[i] = tf_res[effEnd-1];
800
801    ALLOC(error, C*st->mode->nbEBands, celt_word16);
802    quant_coarse_energy(st->mode, st->start, st->end, effEnd, bandLogE,
803          st->oldBandE, nbCompressedBytes*8, st->mode->prob,
804          error, enc, C, LM, nbAvailableBytes, st->force_intra, &st->delayedIntra);
805
806    ec_enc_bit_prob(enc, shortBlocks!=0, 8192);
807
808    if (shortBlocks)
809    {
810       if (transient_shift)
811       {
812          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
813          ec_enc_uint(enc, transient_shift, 4);
814          ec_enc_uint(enc, transient_time_quant, max_time);
815       } else {
816          ec_enc_uint(enc, mdct_weight_shift, 4);
817          if (mdct_weight_shift && M!=2)
818             ec_enc_uint(enc, mdct_weight_pos, M-1);
819       }
820    }
821
822    tf_encode(st->start, st->end, isTransient, tf_res, nbAvailableBytes, LM, tf_select, enc);
823
824    if (!shortBlocks && !folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision, effEnd, C, M))
825       has_fold = 0;
826    ec_enc_bit_prob(enc, has_fold>>1, 8192);
827    ec_enc_bit_prob(enc, has_fold&1, (has_fold>>1) ? 32768 : 49152);
828
829    /* Variable bitrate */
830    if (st->vbr_rate_norm>0)
831    {
832      celt_word16 alpha;
833      celt_int32 delta;
834      /* The target rate in 16th bits per frame */
835      celt_int32 vbr_rate;
836      celt_int32 target;
837      celt_int32 vbr_bound, max_allowed;
838
839      vbr_rate = M*st->vbr_rate_norm;
840
841      /* Computes the max bit-rate allowed in VBR more to avoid busting the budget */
842      vbr_bound = vbr_rate;
843      max_allowed = (vbr_rate + vbr_bound - st->vbr_reservoir)>>(BITRES+3);
844      if (max_allowed < 4)
845         max_allowed = 4;
846      if (max_allowed < nbAvailableBytes)
847         nbAvailableBytes = max_allowed;
848      target=vbr_rate;
849
850      /* Shortblocks get a large boost in bitrate, but since they 
851         are uncommon long blocks are not greatly effected */
852      if (shortBlocks)
853        target*=2;
854      else if (M > 1)
855        target-=(target+14)/28;
856
857      /* The average energy is removed from the target and the actual 
858         energy added*/
859      target=target+st->vbr_offset-588+ec_enc_tell(enc, BITRES);
860
861      /* In VBR mode the frame size must not be reduced so much that it would result in the coarse energy busting its budget */
862      target=IMIN(nbAvailableBytes,target);
863      /* Make the adaptation coef (alpha) higher at the beginning */
864      if (st->vbr_count < 990)
865      {
866         st->vbr_count++;
867         alpha = celt_rcp(SHL32(EXTEND32(st->vbr_count+10),16));
868         /*printf ("%d %d\n", st->vbr_count+10, alpha);*/
869      } else
870         alpha = QCONST16(.001f,15);
871
872      /* By how much did we "miss" the target on that frame */
873      delta = (8<<BITRES)*(celt_int32)target - vbr_rate;
874      /* How many bits have we used in excess of what we're allowed */
875      st->vbr_reservoir += delta;
876      /*printf ("%d\n", st->vbr_reservoir);*/
877
878      /* Compute the offset we need to apply in order to reach the target */
879      st->vbr_drift += MULT16_32_Q15(alpha,delta-st->vbr_offset-st->vbr_drift);
880      st->vbr_offset = -st->vbr_drift;
881      /*printf ("%d\n", st->vbr_drift);*/
882
883      /* We could use any multiple of vbr_rate as bound (depending on the delay) */
884      if (st->vbr_reservoir < 0)
885      {
886         /* We're under the min value -- increase rate */
887         int adjust = 1-(st->vbr_reservoir-1)/(8<<BITRES);
888         st->vbr_reservoir += adjust*(8<<BITRES);
889         target += adjust;
890         /*printf ("+%d\n", adjust);*/
891      }
892      if (target < nbAvailableBytes)
893         nbAvailableBytes = target;
894      nbCompressedBytes = nbAvailableBytes + nbFilledBytes;
895
896      /* This moves the raw bits to take into account the new compressed size */
897      ec_byte_shrink(&buf, nbCompressedBytes);
898    }
899
900    /* Bit allocation */
901    ALLOC(fine_quant, st->mode->nbEBands, int);
902    ALLOC(pulses, st->mode->nbEBands, int);
903    ALLOC(offsets, st->mode->nbEBands, int);
904    ALLOC(fine_priority, st->mode->nbEBands, int);
905
906    for (i=0;i<st->mode->nbEBands;i++)
907       offsets[i] = 0;
908    bits = nbCompressedBytes*8 - ec_enc_tell(enc, 0) - 1;
909    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
910
911    quant_fine_energy(st->mode, st->start, st->end, bandE, st->oldBandE, error, fine_quant, enc, C);
912
913 #ifdef MEASURE_NORM_MSE
914    float X0[3000];
915    float bandE0[60];
916    for (c=0;c<C;c++)
917       for (i=0;i<N;i++)
918          X0[i+c*N] = X[i+c*N];
919    for (i=0;i<C*st->mode->nbEBands;i++)
920       bandE0[i] = bandE[i];
921 #endif
922
923    /* Residual quantisation */
924    quant_all_bands(1, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, bandE, pulses, shortBlocks, has_fold, tf_res, resynth, nbCompressedBytes*8, enc, LM);
925
926    quant_energy_finalise(st->mode, st->start, st->end, bandE, st->oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
927
928    /* Re-synthesis of the coded audio if required */
929    if (resynth)
930    {
931       log2Amp(st->mode, st->start, st->end, bandE, st->oldBandE, C);
932
933 #ifdef MEASURE_NORM_MSE
934       measure_norm_mse(st->mode, X, X0, bandE, bandE0, M, N, C);
935 #endif
936
937       if (mdct_weight_shift)
938       {
939          mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
940       }
941
942       /* Synthesis */
943       denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
944
945       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
946
947       for (c=0;c<C;c++)
948          for (i=0;i<M*st->mode->eBands[st->start];i++)
949             freq[c*N+i] = 0;
950       for (c=0;c<C;c++)
951          for (i=M*st->mode->eBands[st->end];i<N;i++)
952             freq[c*N+i] = 0;
953
954       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem, C, LM);
955
956       /* De-emphasis and put everything back at the right place 
957          in the synthesis history */
958       if (optional_resynthesis != NULL) {
959          deemphasis(st->out_mem, optional_resynthesis, N, C, st->mode->preemph, st->preemph_memD);
960
961       }
962    }
963
964    /* If there's any room left (can only happen for very high rates),
965       fill it with zeros */
966    while (ec_enc_tell(enc,0) + 8 <= nbCompressedBytes*8)
967       ec_enc_bits(enc, 0, 8);
968    ec_enc_done(enc);
969    
970    RESTORE_STACK;
971    if (ec_enc_get_error(enc))
972       return CELT_CORRUPTED_DATA;
973    else
974       return nbCompressedBytes;
975 }
976
977 #ifdef FIXED_POINT
978 #ifndef DISABLE_FLOAT_API
979 int celt_encode_with_ec_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
980 {
981    int j, ret, C, N, LM, M;
982    VARDECL(celt_int16, in);
983    SAVE_STACK;
984
985    if (check_encoder(st) != CELT_OK)
986       return CELT_INVALID_STATE;
987
988    if (check_mode(st->mode) != CELT_OK)
989       return CELT_INVALID_MODE;
990
991    if (pcm==NULL)
992       return CELT_BAD_ARG;
993
994    for (LM=0;LM<4;LM++)
995       if (st->mode->shortMdctSize<<LM==frame_size)
996          break;
997    if (LM>=MAX_CONFIG_SIZES)
998       return CELT_BAD_ARG;
999    M=1<<LM;
1000
1001    C = CHANNELS(st->channels);
1002    N = M*st->mode->shortMdctSize;
1003    ALLOC(in, C*N, celt_int16);
1004
1005    for (j=0;j<C*N;j++)
1006      in[j] = FLOAT2INT16(pcm[j]);
1007
1008    if (optional_resynthesis != NULL) {
1009      ret=celt_encode_with_ec(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
1010       for (j=0;j<C*N;j++)
1011          optional_resynthesis[j]=in[j]*(1.f/32768.f);
1012    } else {
1013      ret=celt_encode_with_ec(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
1014    }
1015    RESTORE_STACK;
1016    return ret;
1017
1018 }
1019 #endif /*DISABLE_FLOAT_API*/
1020 #else
1021 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
1022 {
1023    int j, ret, C, N, LM, M;
1024    VARDECL(celt_sig, in);
1025    SAVE_STACK;
1026
1027    if (check_encoder(st) != CELT_OK)
1028       return CELT_INVALID_STATE;
1029
1030    if (check_mode(st->mode) != CELT_OK)
1031       return CELT_INVALID_MODE;
1032
1033    if (pcm==NULL)
1034       return CELT_BAD_ARG;
1035
1036    for (LM=0;LM<4;LM++)
1037       if (st->mode->shortMdctSize<<LM==frame_size)
1038          break;
1039    if (LM>=MAX_CONFIG_SIZES)
1040       return CELT_BAD_ARG;
1041    M=1<<LM;
1042
1043    C=CHANNELS(st->channels);
1044    N=M*st->mode->shortMdctSize;
1045    ALLOC(in, C*N, celt_sig);
1046    for (j=0;j<C*N;j++) {
1047      in[j] = SCALEOUT(pcm[j]);
1048    }
1049
1050    if (optional_resynthesis != NULL) {
1051       ret = celt_encode_with_ec_float(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
1052       for (j=0;j<C*N;j++)
1053          optional_resynthesis[j] = FLOAT2INT16(in[j]);
1054    } else {
1055       ret = celt_encode_with_ec_float(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
1056    }
1057    RESTORE_STACK;
1058    return ret;
1059 }
1060 #endif
1061
1062 int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1063 {
1064    return celt_encode_with_ec(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1065 }
1066
1067 #ifndef DISABLE_FLOAT_API
1068 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1069 {
1070    return celt_encode_with_ec_float(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1071 }
1072 #endif /* DISABLE_FLOAT_API */
1073
1074 int celt_encode_resynthesis(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1075 {
1076    return celt_encode_with_ec(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1077 }
1078
1079 #ifndef DISABLE_FLOAT_API
1080 int celt_encode_resynthesis_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1081 {
1082    return celt_encode_with_ec_float(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1083 }
1084 #endif /* DISABLE_FLOAT_API */
1085
1086
1087 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
1088 {
1089    va_list ap;
1090    
1091    if (check_encoder(st) != CELT_OK)
1092       return CELT_INVALID_STATE;
1093
1094    va_start(ap, request);
1095    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
1096      goto bad_mode;
1097    switch (request)
1098    {
1099       case CELT_GET_MODE_REQUEST:
1100       {
1101          const CELTMode ** value = va_arg(ap, const CELTMode**);
1102          if (value==0)
1103             goto bad_arg;
1104          *value=st->mode;
1105       }
1106       break;
1107       case CELT_SET_COMPLEXITY_REQUEST:
1108       {
1109          int value = va_arg(ap, celt_int32);
1110          if (value<0 || value>10)
1111             goto bad_arg;
1112       }
1113       break;
1114       case CELT_SET_START_BAND_REQUEST:
1115       {
1116          celt_int32 value = va_arg(ap, celt_int32);
1117          if (value<0 || value>=st->mode->nbEBands)
1118             goto bad_arg;
1119          st->start = value;
1120       }
1121       break;
1122       case CELT_SET_END_BAND_REQUEST:
1123       {
1124          celt_int32 value = va_arg(ap, celt_int32);
1125          if (value<0 || value>=st->mode->nbEBands)
1126             goto bad_arg;
1127          st->end = value;
1128       }
1129       break;
1130       case CELT_SET_PREDICTION_REQUEST:
1131       {
1132          int value = va_arg(ap, celt_int32);
1133          if (value<0 || value>2)
1134             goto bad_arg;
1135          if (value==0)
1136          {
1137             st->force_intra   = 1;
1138          } else if (value==1) {
1139             st->force_intra   = 0;
1140          } else {
1141             st->force_intra   = 0;
1142          }   
1143       }
1144       break;
1145       case CELT_SET_VBR_RATE_REQUEST:
1146       {
1147          celt_int32 value = va_arg(ap, celt_int32);
1148          int frame_rate;
1149          int N = st->mode->shortMdctSize;
1150          if (value<0)
1151             goto bad_arg;
1152          if (value>3072000)
1153             value = 3072000;
1154          frame_rate = ((st->mode->Fs<<3)+(N>>1))/N;
1155          st->vbr_rate_norm = ((value<<(BITRES+3))+(frame_rate>>1))/frame_rate;
1156       }
1157       break;
1158       case CELT_RESET_STATE:
1159       {
1160          const CELTMode *mode = st->mode;
1161          int C = st->channels;
1162
1163          CELT_MEMSET(st->in_mem, 0, st->overlap*C);
1164          CELT_MEMSET(st->out_mem, 0, (MAX_PERIOD+st->overlap)*C);
1165
1166          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
1167
1168          CELT_MEMSET(st->preemph_memE, 0, C);
1169          CELT_MEMSET(st->preemph_memD, 0, C);
1170          st->delayedIntra = 1;
1171
1172          st->fold_decision = 1;
1173          st->tonal_average = QCONST16(1.f,8);
1174          st->gain_prod = 0;
1175          st->vbr_reservoir = 0;
1176          st->vbr_drift = 0;
1177          st->vbr_offset = 0;
1178          st->vbr_count = 0;
1179          st->frame_max = 0;
1180       }
1181       break;
1182       default:
1183          goto bad_request;
1184    }
1185    va_end(ap);
1186    return CELT_OK;
1187 bad_mode:
1188   va_end(ap);
1189   return CELT_INVALID_MODE;
1190 bad_arg:
1191    va_end(ap);
1192    return CELT_BAD_ARG;
1193 bad_request:
1194    va_end(ap);
1195    return CELT_UNIMPLEMENTED;
1196 }
1197
1198 /**********************************************************************/
1199 /*                                                                    */
1200 /*                             DECODER                                */
1201 /*                                                                    */
1202 /**********************************************************************/
1203 #define DECODE_BUFFER_SIZE 2048
1204
1205 #define DECODERVALID   0x4c434454
1206 #define DECODERPARTIAL 0x5444434c
1207 #define DECODERFREED   0x4c004400
1208
1209 /** Decoder state 
1210  @brief Decoder state
1211  */
1212 struct CELTDecoder {
1213    celt_uint32 marker;
1214    const CELTMode *mode;
1215    int overlap;
1216    int channels;
1217
1218    int start, end;
1219
1220    celt_sig preemph_memD[2];
1221
1222    celt_sig *out_mem;
1223    celt_word32 *decode_mem;
1224
1225    celt_word16 *oldBandE;
1226    
1227    celt_word16 *lpc;
1228
1229    int last_pitch_index;
1230    int loss_count;
1231 };
1232
1233 int check_decoder(const CELTDecoder *st) 
1234 {
1235    if (st==NULL)
1236    {
1237       celt_warning("NULL passed a decoder structure");  
1238       return CELT_INVALID_STATE;
1239    }
1240    if (st->marker == DECODERVALID)
1241       return CELT_OK;
1242    if (st->marker == DECODERFREED)
1243       celt_warning("Referencing a decoder that has already been freed");
1244    else
1245       celt_warning("This is not a valid CELT decoder structure");
1246    return CELT_INVALID_STATE;
1247 }
1248
1249 CELTDecoder *celt_decoder_create(const CELTMode *mode, int channels, int *error)
1250 {
1251    int C;
1252    CELTDecoder *st;
1253
1254    if (check_mode(mode) != CELT_OK)
1255    {
1256       if (error)
1257          *error = CELT_INVALID_MODE;
1258       return NULL;
1259    }
1260
1261    if (channels < 0 || channels > 2)
1262    {
1263       celt_warning("Only mono and stereo supported");
1264       if (error)
1265          *error = CELT_BAD_ARG;
1266       return NULL;
1267    }
1268
1269    C = CHANNELS(channels);
1270    st = celt_alloc(sizeof(CELTDecoder));
1271
1272    if (st==NULL)
1273    {
1274       if (error)
1275          *error = CELT_ALLOC_FAIL;
1276       return NULL;
1277    }
1278
1279    st->marker = DECODERPARTIAL;
1280    st->mode = mode;
1281    st->overlap = mode->overlap;
1282    st->channels = channels;
1283
1284    st->start = 0;
1285    st->end = st->mode->effEBands;
1286
1287    st->decode_mem = (celt_sig*)celt_alloc((DECODE_BUFFER_SIZE+st->overlap)*C*sizeof(celt_sig));
1288    st->out_mem = st->decode_mem+DECODE_BUFFER_SIZE-MAX_PERIOD;
1289    
1290    st->oldBandE = (celt_word16*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16));
1291    
1292    st->lpc = (celt_word16*)celt_alloc(C*LPC_ORDER*sizeof(celt_word16));
1293
1294    st->loss_count = 0;
1295
1296    if ((st->decode_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL) &&
1297          (st->lpc!=NULL))
1298    {
1299       if (error)
1300          *error = CELT_OK;
1301       st->marker = DECODERVALID;
1302       return st;
1303    }
1304    /* If the setup fails for some reason deallocate it. */
1305    celt_decoder_destroy(st);
1306    if (error)
1307       *error = CELT_ALLOC_FAIL;
1308    return NULL;
1309 }
1310
1311 void celt_decoder_destroy(CELTDecoder *st)
1312 {
1313    if (st == NULL)
1314    {
1315       celt_warning("NULL passed to celt_decoder_destroy");
1316       return;
1317    }
1318
1319    if (st->marker == DECODERFREED) 
1320    {
1321       celt_warning("Freeing a decoder which has already been freed"); 
1322       return;
1323    }
1324    
1325    if (st->marker != DECODERVALID && st->marker != DECODERPARTIAL)
1326    {
1327       celt_warning("This is not a valid CELT decoder structure");
1328       return;
1329    }
1330    
1331    /*Check_mode is non-fatal here because we can still free
1332      the encoder memory even if the mode is bad, although calling
1333      the free functions in this order is a violation of the API.*/
1334    check_mode(st->mode);
1335    
1336    celt_free(st->decode_mem);
1337    celt_free(st->oldBandE);
1338    celt_free(st->lpc);
1339    
1340    st->marker = DECODERFREED;
1341    
1342    celt_free(st);
1343 }
1344
1345 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16 * restrict pcm, int N, int LM)
1346 {
1347    int c;
1348    int pitch_index;
1349    int overlap = st->mode->overlap;
1350    celt_word16 fade = Q15ONE;
1351    int i, len;
1352    const int C = CHANNELS(st->channels);
1353    int offset;
1354    SAVE_STACK;
1355    
1356    len = N+st->mode->overlap;
1357    
1358    if (st->loss_count == 0)
1359    {
1360       celt_word16 pitch_buf[MAX_PERIOD>>1];
1361       celt_word32 tmp=0;
1362       celt_word32 mem0[2]={0,0};
1363       celt_word16 mem1[2]={0,0};
1364       int len2 = len;
1365       /* FIXME: This is a kludge */
1366       if (len2>MAX_PERIOD>>1)
1367          len2 = MAX_PERIOD>>1;
1368       pitch_downsample(st->out_mem, pitch_buf, MAX_PERIOD, MAX_PERIOD,
1369                        C, mem0, mem1);
1370       pitch_search(st->mode, pitch_buf+((MAX_PERIOD-len2)>>1), pitch_buf, len2,
1371                    MAX_PERIOD-len2-100, &pitch_index, &tmp, 1<<LM);
1372       pitch_index = MAX_PERIOD-len2-pitch_index;
1373       st->last_pitch_index = pitch_index;
1374    } else {
1375       pitch_index = st->last_pitch_index;
1376       if (st->loss_count < 5)
1377          fade = QCONST16(.8f,15);
1378       else
1379          fade = 0;
1380    }
1381
1382    for (c=0;c<C;c++)
1383    {
1384       /* FIXME: This is more memory than necessary */
1385       celt_word32 e[2*MAX_PERIOD];
1386       celt_word16 exc[2*MAX_PERIOD];
1387       celt_word32 ac[LPC_ORDER+1];
1388       celt_word16 decay = 1;
1389       celt_word32 S1=0;
1390       celt_word16 mem[LPC_ORDER]={0};
1391
1392       offset = MAX_PERIOD-pitch_index;
1393       for (i=0;i<MAX_PERIOD;i++)
1394          exc[i] = ROUND16(st->out_mem[i*C+c], SIG_SHIFT);
1395
1396       if (st->loss_count == 0)
1397       {
1398          _celt_autocorr(exc, ac, st->mode->window, st->mode->overlap,
1399                         LPC_ORDER, MAX_PERIOD);
1400
1401          /* Noise floor -40 dB */
1402 #ifdef FIXED_POINT
1403          ac[0] += SHR32(ac[0],13);
1404 #else
1405          ac[0] *= 1.0001f;
1406 #endif
1407          /* Lag windowing */
1408          for (i=1;i<=LPC_ORDER;i++)
1409          {
1410             /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
1411 #ifdef FIXED_POINT
1412             ac[i] -= MULT16_32_Q15(2*i*i, ac[i]);
1413 #else
1414             ac[i] -= ac[i]*(.008f*i)*(.008f*i);
1415 #endif
1416          }
1417
1418          _celt_lpc(st->lpc+c*LPC_ORDER, ac, LPC_ORDER);
1419       }
1420       fir(exc, st->lpc+c*LPC_ORDER, exc, MAX_PERIOD, LPC_ORDER, mem);
1421       /*for (i=0;i<MAX_PERIOD;i++)printf("%d ", exc[i]); printf("\n");*/
1422       /* Check if the waveform is decaying (and if so how fast) */
1423       {
1424          celt_word32 E1=1, E2=1;
1425          int period;
1426          if (pitch_index <= MAX_PERIOD/2)
1427             period = pitch_index;
1428          else
1429             period = MAX_PERIOD/2;
1430          for (i=0;i<period;i++)
1431          {
1432             E1 += SHR32(MULT16_16(exc[MAX_PERIOD-period+i],exc[MAX_PERIOD-period+i]),8);
1433             E2 += SHR32(MULT16_16(exc[MAX_PERIOD-2*period+i],exc[MAX_PERIOD-2*period+i]),8);
1434          }
1435          if (E1 > E2)
1436             E1 = E2;
1437          decay = celt_sqrt(frac_div32(SHR(E1,1),E2));
1438       }
1439
1440       /* Copy excitation, taking decay into account */
1441       for (i=0;i<len+st->mode->overlap;i++)
1442       {
1443          if (offset+i >= MAX_PERIOD)
1444          {
1445             offset -= pitch_index;
1446             decay = MULT16_16_Q15(decay, decay);
1447          }
1448          e[i] = SHL32(EXTEND32(MULT16_16_Q15(decay, exc[offset+i])), SIG_SHIFT);
1449          S1 += SHR32(MULT16_16(st->out_mem[offset+i],st->out_mem[offset+i]),8);
1450       }
1451
1452       iir(e, st->lpc+c*LPC_ORDER, e, len+st->mode->overlap, LPC_ORDER, mem);
1453
1454       {
1455          celt_word32 S2=0;
1456          for (i=0;i<len+overlap;i++)
1457             S2 += SHR32(MULT16_16(e[i],e[i]),8);
1458          /* This checks for an "explosion" in the synthesis */
1459 #ifdef FIXED_POINT
1460          if (!(S1 > SHR32(S2,2)))
1461 #else
1462          /* Float test is written this way to catch NaNs at the same time */
1463          if (!(S1 > 0.2f*S2))
1464 #endif
1465          {
1466             for (i=0;i<len+overlap;i++)
1467                e[i] = 0;
1468          } else if (S1 < S2)
1469          {
1470             celt_word16 ratio = celt_sqrt(frac_div32(SHR32(S1,1)+1,S2+1));
1471             for (i=0;i<len+overlap;i++)
1472                e[i] = MULT16_16_Q15(ratio, e[i]);
1473          }
1474       }
1475
1476       for (i=0;i<MAX_PERIOD+st->mode->overlap-N;i++)
1477          st->out_mem[C*i+c] = st->out_mem[C*(N+i)+c];
1478
1479       /* Apply TDAC to the concealed audio so that it blends with the
1480          previous and next frames */
1481       for (i=0;i<overlap/2;i++)
1482       {
1483          celt_word32 tmp1, tmp2;
1484          tmp1 = MULT16_32_Q15(st->mode->window[i          ], e[i          ]) -
1485                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[overlap-i-1]);
1486          tmp2 = MULT16_32_Q15(st->mode->window[i],           e[N+overlap-1-i]) +
1487                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[N+i          ]);
1488          tmp1 = MULT16_32_Q15(fade, tmp1);
1489          tmp2 = MULT16_32_Q15(fade, tmp2);
1490          st->out_mem[C*(MAX_PERIOD+i)+c] = MULT16_32_Q15(st->mode->window[overlap-i-1], tmp2);
1491          st->out_mem[C*(MAX_PERIOD+overlap-i-1)+c] = MULT16_32_Q15(st->mode->window[i], tmp2);
1492          st->out_mem[C*(MAX_PERIOD-N+i)+c] += MULT16_32_Q15(st->mode->window[i], tmp1);
1493          st->out_mem[C*(MAX_PERIOD-N+overlap-i-1)+c] -= MULT16_32_Q15(st->mode->window[overlap-i-1], tmp1);
1494       }
1495       for (i=0;i<N-overlap;i++)
1496          st->out_mem[C*(MAX_PERIOD-N+overlap+i)+c] = MULT16_32_Q15(fade, e[overlap+i]);
1497    }
1498
1499    deemphasis(st->out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1500    
1501    st->loss_count++;
1502
1503    RESTORE_STACK;
1504 }
1505
1506 #ifdef FIXED_POINT
1507 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1508 {
1509 #else
1510 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig * restrict pcm, int frame_size, ec_dec *dec)
1511 {
1512 #endif
1513    int c, i, N, N4;
1514    int has_fold;
1515    int bits;
1516    ec_dec _dec;
1517    ec_byte_buffer buf;
1518    VARDECL(celt_sig, freq);
1519    VARDECL(celt_norm, X);
1520    VARDECL(celt_ener, bandE);
1521    VARDECL(int, fine_quant);
1522    VARDECL(int, pulses);
1523    VARDECL(int, offsets);
1524    VARDECL(int, fine_priority);
1525    VARDECL(int, tf_res);
1526
1527    int shortBlocks;
1528    int isTransient;
1529    int intra_ener;
1530    int transient_time;
1531    int transient_shift;
1532    int mdct_weight_shift=0;
1533    const int C = CHANNELS(st->channels);
1534    int mdct_weight_pos=0;
1535    int LM, M;
1536    int nbFilledBytes, nbAvailableBytes;
1537    int effEnd;
1538    SAVE_STACK;
1539
1540    if (check_decoder(st) != CELT_OK)
1541       return CELT_INVALID_STATE;
1542
1543    if (check_mode(st->mode) != CELT_OK)
1544       return CELT_INVALID_MODE;
1545
1546    if (pcm==NULL)
1547       return CELT_BAD_ARG;
1548
1549    for (LM=0;LM<4;LM++)
1550       if (st->mode->shortMdctSize<<LM==frame_size)
1551          break;
1552    if (LM>=MAX_CONFIG_SIZES)
1553       return CELT_BAD_ARG;
1554    M=1<<LM;
1555
1556    N = M*st->mode->shortMdctSize;
1557    N4 = (N-st->overlap)>>1;
1558
1559    effEnd = st->end;
1560    if (effEnd > st->mode->effEBands)
1561       effEnd = st->mode->effEBands;
1562
1563    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
1564    ALLOC(X, C*N, celt_norm);   /**< Interleaved normalised MDCTs */
1565    ALLOC(bandE, st->mode->nbEBands*C, celt_ener);
1566    for (c=0;c<C;c++)
1567       for (i=0;i<M*st->mode->eBands[st->start];i++)
1568          X[c*N+i] = 0;
1569    for (c=0;c<C;c++)
1570       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1571          X[c*N+i] = 0;
1572
1573    if (data == NULL)
1574    {
1575       celt_decode_lost(st, pcm, N, LM);
1576       RESTORE_STACK;
1577       return CELT_OK;
1578    }
1579    if (len<0) {
1580      RESTORE_STACK;
1581      return CELT_BAD_ARG;
1582    }
1583    
1584    if (dec == NULL)
1585    {
1586       ec_byte_readinit(&buf,(unsigned char*)data,len);
1587       ec_dec_init(&_dec,&buf);
1588       dec = &_dec;
1589       nbFilledBytes = 0;
1590    } else {
1591       nbFilledBytes = (ec_dec_tell(dec, 0)+4)>>3;
1592    }
1593    nbAvailableBytes = len-nbFilledBytes;
1594
1595    /* Decode the global flags (first symbols in the stream) */
1596    intra_ener = ec_dec_bit_prob(dec, 8192);
1597    /* Get band energies */
1598    unquant_coarse_energy(st->mode, st->start, st->end, bandE, st->oldBandE, intra_ener, st->mode->prob, dec, C, LM);
1599
1600    isTransient = ec_dec_bit_prob(dec, 8192);
1601
1602    if (isTransient)
1603       shortBlocks = M;
1604    else
1605       shortBlocks = 0;
1606
1607    if (isTransient)
1608    {
1609       transient_shift = ec_dec_uint(dec, 4);
1610       if (transient_shift == 3)
1611       {
1612          int transient_time_quant;
1613          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
1614          transient_time_quant = ec_dec_uint(dec, max_time);
1615          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
1616       } else {
1617          mdct_weight_shift = transient_shift;
1618          if (mdct_weight_shift && M>2)
1619             mdct_weight_pos = ec_dec_uint(dec, M-1);
1620          transient_shift = 0;
1621          transient_time = 0;
1622       }
1623    } else {
1624       transient_time = -1;
1625       transient_shift = 0;
1626    }
1627
1628    ALLOC(tf_res, st->mode->nbEBands, int);
1629    tf_decode(st->start, st->end, C, isTransient, tf_res, nbAvailableBytes, LM, dec);
1630
1631    has_fold = ec_dec_bit_prob(dec, 8192)<<1;
1632    has_fold |= ec_dec_bit_prob(dec, (has_fold>>1) ? 32768 : 49152);
1633
1634    ALLOC(pulses, st->mode->nbEBands, int);
1635    ALLOC(offsets, st->mode->nbEBands, int);
1636    ALLOC(fine_priority, st->mode->nbEBands, int);
1637
1638    for (i=0;i<st->mode->nbEBands;i++)
1639       offsets[i] = 0;
1640
1641    bits = len*8 - ec_dec_tell(dec, 0) - 1;
1642    ALLOC(fine_quant, st->mode->nbEBands, int);
1643    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
1644    /*bits = ec_dec_tell(dec, 0);
1645    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(dec, 0)-bits))/C);*/
1646    
1647    unquant_fine_energy(st->mode, st->start, st->end, bandE, st->oldBandE, fine_quant, dec, C);
1648
1649    /* Decode fixed codebook */
1650    quant_all_bands(0, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, NULL, pulses, shortBlocks, has_fold, tf_res, 1, len*8, dec, LM);
1651
1652    unquant_energy_finalise(st->mode, st->start, st->end, bandE, st->oldBandE, fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
1653
1654    log2Amp(st->mode, st->start, st->end, bandE, st->oldBandE, C);
1655
1656    if (mdct_weight_shift)
1657    {
1658       mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
1659    }
1660
1661    /* Synthesis */
1662    denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
1663
1664
1665    CELT_MOVE(st->decode_mem, st->decode_mem+C*N, C*(DECODE_BUFFER_SIZE+st->overlap-N));
1666
1667    for (c=0;c<C;c++)
1668       for (i=0;i<M*st->mode->eBands[st->start];i++)
1669          freq[c*N+i] = 0;
1670    for (c=0;c<C;c++)
1671       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1672          freq[c*N+i] = 0;
1673
1674    /* Compute inverse MDCTs */
1675    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem, C, LM);
1676
1677    deemphasis(st->out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1678    st->loss_count = 0;
1679    RESTORE_STACK;
1680    if (ec_dec_get_error(dec))
1681       return CELT_CORRUPTED_DATA;
1682    else
1683       return CELT_OK;
1684 }
1685
1686 #ifdef FIXED_POINT
1687 #ifndef DISABLE_FLOAT_API
1688 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size, ec_dec *dec)
1689 {
1690    int j, ret, C, N, LM, M;
1691    VARDECL(celt_int16, out);
1692    SAVE_STACK;
1693
1694    if (check_decoder(st) != CELT_OK)
1695       return CELT_INVALID_STATE;
1696
1697    if (check_mode(st->mode) != CELT_OK)
1698       return CELT_INVALID_MODE;
1699
1700    if (pcm==NULL)
1701       return CELT_BAD_ARG;
1702
1703    for (LM=0;LM<4;LM++)
1704       if (st->mode->shortMdctSize<<LM==frame_size)
1705          break;
1706    if (LM>=MAX_CONFIG_SIZES)
1707       return CELT_BAD_ARG;
1708    M=1<<LM;
1709
1710    C = CHANNELS(st->channels);
1711    N = M*st->mode->shortMdctSize;
1712    
1713    ALLOC(out, C*N, celt_int16);
1714    ret=celt_decode_with_ec(st, data, len, out, frame_size, dec);
1715    if (ret==0)
1716       for (j=0;j<C*N;j++)
1717          pcm[j]=out[j]*(1.f/32768.f);
1718      
1719    RESTORE_STACK;
1720    return ret;
1721 }
1722 #endif /*DISABLE_FLOAT_API*/
1723 #else
1724 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1725 {
1726    int j, ret, C, N, LM, M;
1727    VARDECL(celt_sig, out);
1728    SAVE_STACK;
1729
1730    if (check_decoder(st) != CELT_OK)
1731       return CELT_INVALID_STATE;
1732
1733    if (check_mode(st->mode) != CELT_OK)
1734       return CELT_INVALID_MODE;
1735
1736    if (pcm==NULL)
1737       return CELT_BAD_ARG;
1738
1739    for (LM=0;LM<4;LM++)
1740       if (st->mode->shortMdctSize<<LM==frame_size)
1741          break;
1742    if (LM>=MAX_CONFIG_SIZES)
1743       return CELT_BAD_ARG;
1744    M=1<<LM;
1745
1746    C = CHANNELS(st->channels);
1747    N = M*st->mode->shortMdctSize;
1748    ALLOC(out, C*N, celt_sig);
1749
1750    ret=celt_decode_with_ec_float(st, data, len, out, frame_size, dec);
1751
1752    if (ret==0)
1753       for (j=0;j<C*N;j++)
1754          pcm[j] = FLOAT2INT16 (out[j]);
1755    
1756    RESTORE_STACK;
1757    return ret;
1758 }
1759 #endif
1760
1761 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size)
1762 {
1763    return celt_decode_with_ec(st, data, len, pcm, frame_size, NULL);
1764 }
1765
1766 #ifndef DISABLE_FLOAT_API
1767 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size)
1768 {
1769    return celt_decode_with_ec_float(st, data, len, pcm, frame_size, NULL);
1770 }
1771 #endif /* DISABLE_FLOAT_API */
1772
1773 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
1774 {
1775    va_list ap;
1776
1777    if (check_decoder(st) != CELT_OK)
1778       return CELT_INVALID_STATE;
1779
1780    va_start(ap, request);
1781    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
1782      goto bad_mode;
1783    switch (request)
1784    {
1785       case CELT_GET_MODE_REQUEST:
1786       {
1787          const CELTMode ** value = va_arg(ap, const CELTMode**);
1788          if (value==0)
1789             goto bad_arg;
1790          *value=st->mode;
1791       }
1792       break;
1793       case CELT_SET_START_BAND_REQUEST:
1794       {
1795          celt_int32 value = va_arg(ap, celt_int32);
1796          if (value<0 || value>=st->mode->nbEBands)
1797             goto bad_arg;
1798          st->start = value;
1799       }
1800       break;
1801       case CELT_SET_END_BAND_REQUEST:
1802       {
1803          celt_int32 value = va_arg(ap, celt_int32);
1804          if (value<0 || value>=st->mode->nbEBands)
1805             goto bad_arg;
1806          st->end = value;
1807       }
1808       break;
1809       case CELT_RESET_STATE:
1810       {
1811          const CELTMode *mode = st->mode;
1812          int C = st->channels;
1813
1814          CELT_MEMSET(st->decode_mem, 0, (DECODE_BUFFER_SIZE+st->overlap)*C);
1815          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
1816
1817          CELT_MEMSET(st->preemph_memD, 0, C);
1818
1819          st->loss_count = 0;
1820
1821          CELT_MEMSET(st->lpc, 0, C*LPC_ORDER);
1822       }
1823       break;
1824       default:
1825          goto bad_request;
1826    }
1827    va_end(ap);
1828    return CELT_OK;
1829 bad_mode:
1830   va_end(ap);
1831   return CELT_INVALID_MODE;
1832 bad_arg:
1833    va_end(ap);
1834    return CELT_BAD_ARG;
1835 bad_request:
1836       va_end(ap);
1837   return CELT_UNIMPLEMENTED;
1838 }
1839
1840 const char *celt_strerror(int error)
1841 {
1842    static const char *error_strings[8] = {
1843       "success",
1844       "invalid argument",
1845       "invalid mode",
1846       "internal error",
1847       "corrupted stream",
1848       "request not implemented",
1849       "invalid state",
1850       "memory allocation failed"
1851    };
1852    if (error > 0 || error < -7)
1853       return "unknown error";
1854    else 
1855       return error_strings[-error];
1856 }
1857