Remove check_mode()
[opus.git] / libcelt / celt.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2010 Xiph.Org Foundation
3    Copyright (c) 2008 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #define CELT_C
39
40 #include "os_support.h"
41 #include "mdct.h"
42 #include <math.h>
43 #include "celt.h"
44 #include "pitch.h"
45 #include "bands.h"
46 #include "modes.h"
47 #include "entcode.h"
48 #include "quant_bands.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54 #include "plc.h"
55
56 #ifdef FIXED_POINT
57 static const celt_word16 transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135f, 0.0337639f, 0.0748914f, 0.1304955f,
63    0.1986827f, 0.2771308f, 0.3631685f, 0.4538658f,
64    0.5461342f, 0.6368315f, 0.7228692f, 0.8013173f,
65    0.8695045f, 0.9251086f, 0.9662361f, 0.9914865f};
66 #endif
67
68 /** Encoder state 
69  @brief Encoder state
70  */
71 struct CELTEncoder {
72    const CELTMode *mode;     /**< Mode used by the encoder */
73    int overlap;
74    int channels;
75    
76    int force_intra;
77    int delayedIntra;
78    celt_word16 tonal_average;
79    int fold_decision;
80    celt_word16 gain_prod;
81    celt_word32 frame_max;
82    int start, end;
83
84    /* VBR-related parameters */
85    celt_int32 vbr_reservoir;
86    celt_int32 vbr_drift;
87    celt_int32 vbr_offset;
88    celt_int32 vbr_count;
89
90    celt_int32 vbr_rate_norm; /* Target number of 16th bits per frame */
91    celt_word32 preemph_memE[2];
92    celt_word32 preemph_memD[2];
93
94    celt_sig in_mem[1];
95 };
96
97 int celt_encoder_get_size(const CELTMode *mode, int channels)
98 {
99    int size = sizeof(struct CELTEncoder)
100          + (2*channels*mode->overlap-1)*sizeof(celt_sig)
101          + channels*mode->nbEBands*sizeof(celt_word16);
102    return size;
103 }
104
105 CELTEncoder *celt_encoder_create(const CELTMode *mode, int channels, int *error)
106 {
107    CELTEncoder *st;
108
109    if (channels < 0 || channels > 2)
110    {
111       celt_warning("Only mono and stereo supported");
112       if (error)
113          *error = CELT_BAD_ARG;
114       return NULL;
115    }
116
117    st = celt_alloc(celt_encoder_get_size(mode, channels));
118    
119    if (st==NULL)
120    {
121       if (error)
122          *error = CELT_ALLOC_FAIL;
123       return NULL;
124    }
125    st->mode = mode;
126    st->overlap = mode->overlap;
127    st->channels = channels;
128
129    st->start = 0;
130    st->end = st->mode->effEBands;
131
132    st->vbr_rate_norm = 0;
133    st->force_intra  = 0;
134    st->delayedIntra = 1;
135    st->tonal_average = QCONST16(1.f,8);
136    st->fold_decision = 1;
137
138    if (error)
139       *error = CELT_OK;
140    return st;
141 }
142
143 void celt_encoder_destroy(CELTEncoder *st)
144 {
145    celt_free(st);
146 }
147
148 static inline celt_int16 FLOAT2INT16(float x)
149 {
150    x = x*CELT_SIG_SCALE;
151    x = MAX32(x, -32768);
152    x = MIN32(x, 32767);
153    return (celt_int16)float2int(x);
154 }
155
156 static inline celt_word16 SIG2WORD16(celt_sig x)
157 {
158 #ifdef FIXED_POINT
159    x = PSHR32(x, SIG_SHIFT);
160    x = MAX32(x, -32768);
161    x = MIN32(x, 32767);
162    return EXTRACT16(x);
163 #else
164    return (celt_word16)x;
165 #endif
166 }
167
168 static int transient_analysis(const celt_word32 * restrict in, int len, int C,
169                               int *transient_time, int *transient_shift,
170                               celt_word32 *frame_max, int overlap)
171 {
172    int i, n;
173    celt_word32 ratio;
174    celt_word32 threshold;
175    VARDECL(celt_word32, begin);
176    SAVE_STACK;
177    ALLOC(begin, len+1, celt_word32);
178    begin[0] = 0;
179    if (C==1)
180    {
181       for (i=0;i<len;i++)
182          begin[i+1] = MAX32(begin[i], ABS32(in[i]));
183    } else {
184       for (i=0;i<len;i++)
185          begin[i+1] = MAX32(begin[i], MAX32(ABS32(in[C*i]),
186                                             ABS32(in[C*i+1])));
187    }
188    n = -1;
189
190    threshold = MULT16_32_Q15(QCONST16(.4f,15),begin[len]);
191    /* If the following condition isn't met, there's just no way
192       we'll have a transient*/
193    if (*frame_max < threshold)
194    {
195       /* It's likely we have a transient, now find it */
196       for (i=8;i<len-8;i++)
197       {
198          if (begin[i+1] < threshold)
199             n=i;
200       }
201    }
202    if (n<32)
203    {
204       n = -1;
205       ratio = 0;
206    } else {
207       ratio = DIV32(begin[len],1+MAX32(*frame_max, begin[n-16]));
208    }
209
210    if (ratio > 45)
211       *transient_shift = 3;
212    else
213       *transient_shift = 0;
214    
215    *transient_time = n;
216    *frame_max = begin[len-overlap];
217
218    RESTORE_STACK;
219    return ratio > 0;
220 }
221
222 /** Apply window and compute the MDCT for all sub-frames and 
223     all channels in a frame */
224 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * restrict in, celt_sig * restrict out, int _C, int LM)
225 {
226    const int C = CHANNELS(_C);
227    if (C==1 && !shortBlocks)
228    {
229       const int overlap = OVERLAP(mode);
230       clt_mdct_forward(&mode->mdct, in, out, mode->window, overlap, mode->maxLM-LM);
231    } else {
232       const int overlap = OVERLAP(mode);
233       int N = mode->shortMdctSize<<LM;
234       int B = 1;
235       int b, c;
236       VARDECL(celt_word32, x);
237       VARDECL(celt_word32, tmp);
238       SAVE_STACK;
239       if (shortBlocks)
240       {
241          /*lookup = &mode->mdct[0];*/
242          N = mode->shortMdctSize;
243          B = shortBlocks;
244       }
245       ALLOC(x, N+overlap, celt_word32);
246       ALLOC(tmp, N, celt_word32);
247       for (c=0;c<C;c++)
248       {
249          for (b=0;b<B;b++)
250          {
251             int j;
252             for (j=0;j<N+overlap;j++)
253                x[j] = in[C*(b*N+j)+c];
254             clt_mdct_forward(&mode->mdct, x, tmp, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
255             /* Interleaving the sub-frames */
256             for (j=0;j<N;j++)
257                out[(j*B+b)+c*N*B] = tmp[j];
258          }
259       }
260       RESTORE_STACK;
261    }
262 }
263
264 /** Compute the IMDCT and apply window for all sub-frames and 
265     all channels in a frame */
266 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X,
267       int transient_time, int transient_shift, celt_sig * restrict out_mem[],
268       celt_sig * restrict overlap_mem[], int _C, int LM)
269 {
270    int c;
271    const int C = CHANNELS(_C);
272    const int N = mode->shortMdctSize<<LM;
273    const int overlap = OVERLAP(mode);
274    for (c=0;c<C;c++)
275    {
276       int j;
277          VARDECL(celt_word32, x);
278          VARDECL(celt_word32, tmp);
279          int b;
280          int N2 = N;
281          int B = 1;
282          SAVE_STACK;
283          
284          ALLOC(x, N+overlap, celt_word32);
285          ALLOC(tmp, N, celt_word32);
286
287          if (shortBlocks)
288          {
289             N2 = mode->shortMdctSize;
290             B = shortBlocks;
291          }
292          /* Prevents problems from the imdct doing the overlap-add */
293          CELT_MEMSET(x, 0, overlap);
294
295          for (b=0;b<B;b++)
296          {
297             /* De-interleaving the sub-frames */
298             for (j=0;j<N2;j++)
299                tmp[j] = X[(j*B+b)+c*N2*B];
300             clt_mdct_backward(&mode->mdct, tmp, x+N2*b, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
301          }
302
303          if (transient_shift > 0)
304          {
305 #ifdef FIXED_POINT
306             for (j=0;j<16;j++)
307                x[transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[transient_time+j-16],transient_shift));
308             for (j=transient_time;j<N+overlap;j++)
309                x[j] = SHL32(x[j], transient_shift);
310 #else
311             for (j=0;j<16;j++)
312                x[transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
313             for (j=transient_time;j<N+overlap;j++)
314                x[j] *= 1<<transient_shift;
315 #endif
316          }
317          for (j=0;j<overlap;j++)
318             out_mem[c][j] = x[j] + overlap_mem[c][j];
319          for (;j<N;j++)
320             out_mem[c][j] = x[j];
321          for (j=0;j<overlap;j++)
322             overlap_mem[c][j] = x[N+j];
323          RESTORE_STACK;
324    }
325 }
326
327 static void deemphasis(celt_sig *in[], celt_word16 *pcm, int N, int _C, const celt_word16 *coef, celt_sig *mem)
328 {
329    const int C = CHANNELS(_C);
330    int c;
331    for (c=0;c<C;c++)
332    {
333       int j;
334       celt_sig * restrict x;
335       celt_word16  * restrict y;
336       celt_sig m = mem[c];
337       x =in[c];
338       y = pcm+c;
339       for (j=0;j<N;j++)
340       {
341          celt_sig tmp = *x + m;
342          m = MULT16_32_Q15(coef[0], tmp)
343            - MULT16_32_Q15(coef[1], *x);
344          tmp = SHL32(MULT16_32_Q15(coef[3], tmp), 2);
345          *y = SCALEOUT(SIG2WORD16(tmp));
346          x++;
347          y+=C;
348       }
349       mem[c] = m;
350    }
351 }
352
353 static void mdct_shape(const CELTMode *mode, celt_norm *X, int start,
354                        int end, int N,
355                        int mdct_weight_shift, int end_band, int _C, int renorm, int M)
356 {
357    int m, i, c;
358    const int C = CHANNELS(_C);
359    for (c=0;c<C;c++)
360       for (m=start;m<end;m++)
361          for (i=m+c*N;i<(c+1)*N;i+=M)
362 #ifdef FIXED_POINT
363             X[i] = SHR16(X[i], mdct_weight_shift);
364 #else
365             X[i] = (1.f/(1<<mdct_weight_shift))*X[i];
366 #endif
367    if (renorm)
368       renormalise_bands(mode, X, end_band, C, M);
369 }
370
371 static signed char tf_select_table[4][8] = {
372       {0, -1, 0, -1,    0,-1, 0,-1},
373       {0, -1, 0, -2,    1, 0, 1 -1},
374       {0, -2, 0, -3,    2, 0, 1 -1},
375       {0, -2, 0, -3,    2, 0, 1 -1},
376 };
377
378 static int tf_analysis(celt_word16 *bandLogE, celt_word16 *oldBandE, int len, int C, int isTransient, int *tf_res, int nbCompressedBytes)
379 {
380    int i;
381    celt_word16 threshold;
382    VARDECL(celt_word16, metric);
383    celt_word32 average=0;
384    celt_word32 cost0;
385    celt_word32 cost1;
386    VARDECL(int, path0);
387    VARDECL(int, path1);
388    celt_word16 lambda;
389    int tf_select=0;
390    SAVE_STACK;
391
392    /* FIXME: Should check number of bytes *left* */
393    if (nbCompressedBytes<15*C)
394    {
395       for (i=0;i<len;i++)
396          tf_res[i] = 0;
397       return 0;
398    }
399    if (nbCompressedBytes<40)
400       lambda = QCONST16(5.f, DB_SHIFT);
401    else if (nbCompressedBytes<60)
402       lambda = QCONST16(2.f, DB_SHIFT);
403    else if (nbCompressedBytes<100)
404       lambda = QCONST16(1.f, DB_SHIFT);
405    else
406       lambda = QCONST16(.5f, DB_SHIFT);
407
408    ALLOC(metric, len, celt_word16);
409    ALLOC(path0, len, int);
410    ALLOC(path1, len, int);
411    for (i=0;i<len;i++)
412    {
413       metric[i] = SUB16(bandLogE[i], oldBandE[i]);
414       average += metric[i];
415    }
416    if (C==2)
417    {
418       average = 0;
419       for (i=0;i<len;i++)
420       {
421          metric[i] = HALF32(metric[i]) + HALF32(SUB16(bandLogE[i+len], oldBandE[i+len]));
422          average += metric[i];
423       }
424    }
425    average = DIV32(average, len);
426    /*if (!isTransient)
427       printf ("%f\n", average);*/
428    if (isTransient)
429    {
430       threshold = QCONST16(1.f,DB_SHIFT);
431       tf_select = average > QCONST16(3.f,DB_SHIFT);
432    } else {
433       threshold = QCONST16(.5f,DB_SHIFT);
434       tf_select = average > QCONST16(1.f,DB_SHIFT);
435    }
436    cost0 = 0;
437    cost1 = lambda;
438    /* Viterbi forward pass */
439    for (i=1;i<len;i++)
440    {
441       celt_word32 curr0, curr1;
442       celt_word32 from0, from1;
443
444       from0 = cost0;
445       from1 = cost1 + lambda;
446       if (from0 < from1)
447       {
448          curr0 = from0;
449          path0[i]= 0;
450       } else {
451          curr0 = from1;
452          path0[i]= 1;
453       }
454
455       from0 = cost0 + lambda;
456       from1 = cost1;
457       if (from0 < from1)
458       {
459          curr1 = from0;
460          path1[i]= 0;
461       } else {
462          curr1 = from1;
463          path1[i]= 1;
464       }
465       cost0 = curr0 + (metric[i]-threshold);
466       cost1 = curr1;
467    }
468    tf_res[len-1] = cost0 < cost1 ? 0 : 1;
469    /* Viterbi backward pass to check the decisions */
470    for (i=len-2;i>=0;i--)
471    {
472       if (tf_res[i+1] == 1)
473          tf_res[i] = path1[i+1];
474       else
475          tf_res[i] = path0[i+1];
476    }
477    RESTORE_STACK;
478    return tf_select;
479 }
480
481 static void tf_encode(int start, int end, int isTransient, int *tf_res, int nbCompressedBytes, int LM, int tf_select, ec_enc *enc)
482 {
483    int curr, i;
484    if (8*nbCompressedBytes - ec_enc_tell(enc, 0) < 100)
485    {
486       for (i=start;i<end;i++)
487          tf_res[i] = isTransient;
488    } else {
489       ec_enc_bit_prob(enc, tf_res[start], isTransient ? 16384 : 4096);
490       curr = tf_res[start];
491       for (i=start+1;i<end;i++)
492       {
493          ec_enc_bit_prob(enc, tf_res[i] ^ curr, isTransient ? 4096 : 2048);
494          curr = tf_res[i];
495       }
496    }
497    ec_enc_bits(enc, tf_select, 1);
498    for (i=start;i<end;i++)
499       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
500 }
501
502 static void tf_decode(int start, int end, int C, int isTransient, int *tf_res, int nbCompressedBytes, int LM, ec_dec *dec)
503 {
504    int i, curr, tf_select;
505    if (8*nbCompressedBytes - ec_dec_tell(dec, 0) < 100)
506    {
507       for (i=start;i<end;i++)
508          tf_res[i] = isTransient;
509    } else {
510       tf_res[start] = ec_dec_bit_prob(dec, isTransient ? 16384 : 4096);
511       curr = tf_res[start];
512       for (i=start+1;i<end;i++)
513       {
514          tf_res[i] = ec_dec_bit_prob(dec, isTransient ? 4096 : 2048) ^ curr;
515          curr = tf_res[i];
516       }
517    }
518    tf_select = ec_dec_bits(dec, 1);
519    for (i=start;i<end;i++)
520       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
521 }
522
523 #ifdef FIXED_POINT
524 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
525 {
526 #else
527 int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
528 {
529 #endif
530    int i, c, N, NN;
531    int bits;
532    int has_fold=1;
533    ec_byte_buffer buf;
534    ec_enc         _enc;
535    VARDECL(celt_sig, in);
536    VARDECL(celt_sig, freq);
537    VARDECL(celt_norm, X);
538    VARDECL(celt_ener, bandE);
539    VARDECL(celt_word16, bandLogE);
540    VARDECL(int, fine_quant);
541    VARDECL(celt_word16, error);
542    VARDECL(int, pulses);
543    VARDECL(int, offsets);
544    VARDECL(int, fine_priority);
545    VARDECL(int, tf_res);
546    celt_sig *_overlap_mem;
547    celt_word16 *oldBandE;
548    int shortBlocks=0;
549    int isTransient=0;
550    int transient_time, transient_time_quant;
551    int transient_shift;
552    int resynth;
553    const int C = CHANNELS(st->channels);
554    int mdct_weight_shift = 0;
555    int mdct_weight_pos=0;
556    int LM, M;
557    int tf_select;
558    int nbFilledBytes, nbAvailableBytes;
559    int effEnd;
560    SAVE_STACK;
561
562    if (nbCompressedBytes<0 || pcm==NULL)
563      return CELT_BAD_ARG;
564
565    for (LM=0;LM<4;LM++)
566       if (st->mode->shortMdctSize<<LM==frame_size)
567          break;
568    if (LM>=MAX_CONFIG_SIZES)
569       return CELT_BAD_ARG;
570    M=1<<LM;
571
572    _overlap_mem = st->in_mem+C*(st->overlap);
573    oldBandE = (celt_word16*)(st->in_mem+2*C*(st->overlap));
574
575    if (enc==NULL)
576    {
577       ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
578       ec_enc_init(&_enc,&buf);
579       enc = &_enc;
580       nbFilledBytes=0;
581    } else {
582       nbFilledBytes=(ec_enc_tell(enc, 0)+4)>>3;
583    }
584    nbAvailableBytes = nbCompressedBytes - nbFilledBytes;
585
586    effEnd = st->end;
587    if (effEnd > st->mode->effEBands)
588       effEnd = st->mode->effEBands;
589
590    N = M*st->mode->shortMdctSize;
591    ALLOC(in, C*(N+st->overlap), celt_sig);
592
593    CELT_COPY(in, st->in_mem, C*st->overlap);
594    for (c=0;c<C;c++)
595    {
596       const celt_word16 * restrict pcmp = pcm+c;
597       celt_sig * restrict inp = in+C*st->overlap+c;
598       for (i=0;i<N;i++)
599       {
600          /* Apply pre-emphasis */
601          celt_sig tmp = MULT16_16(st->mode->preemph[2], SCALEIN(*pcmp));
602          *inp = tmp + st->preemph_memE[c];
603          st->preemph_memE[c] = MULT16_32_Q15(st->mode->preemph[1], *inp)
604                              - MULT16_32_Q15(st->mode->preemph[0], tmp);
605          inp += C;
606          pcmp += C;
607       }
608    }
609    CELT_COPY(st->in_mem, in+C*N, C*st->overlap);
610
611    /* Transient handling */
612    transient_time = -1;
613    transient_time_quant = -1;
614    transient_shift = 0;
615    isTransient = 0;
616
617    resynth = optional_resynthesis!=NULL;
618
619    if (M > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift, &st->frame_max, st->overlap))
620    {
621 #ifndef FIXED_POINT
622       float gain_1;
623 #endif
624       /* Apply the inverse shaping window */
625       if (transient_shift)
626       {
627          transient_time_quant = transient_time*(celt_int32)8000/st->mode->Fs;
628          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
629 #ifdef FIXED_POINT
630          for (c=0;c<C;c++)
631             for (i=0;i<16;i++)
632                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
633          for (c=0;c<C;c++)
634             for (i=transient_time;i<N+st->overlap;i++)
635                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
636 #else
637          for (c=0;c<C;c++)
638             for (i=0;i<16;i++)
639                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
640          gain_1 = 1.f/(1<<transient_shift);
641          for (c=0;c<C;c++)
642             for (i=transient_time;i<N+st->overlap;i++)
643                in[C*i+c] *= gain_1;
644 #endif
645       }
646       isTransient = 1;
647       has_fold = 1;
648    }
649
650    if (isTransient)
651       shortBlocks = M;
652    else
653       shortBlocks = 0;
654
655    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
656    ALLOC(bandE,st->mode->nbEBands*C, celt_ener);
657    ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16);
658    /* Compute MDCTs */
659    compute_mdcts(st->mode, shortBlocks, in, freq, C, LM);
660
661    ALLOC(X, C*N, celt_norm);         /**< Interleaved normalised MDCTs */
662
663    compute_band_energies(st->mode, freq, bandE, effEnd, C, M);
664
665    amp2Log2(st->mode, effEnd, st->end, bandE, bandLogE, C);
666
667    /* Band normalisation */
668    normalise_bands(st->mode, freq, X, bandE, effEnd, C, M);
669
670    NN = M*st->mode->eBands[effEnd];
671    if (shortBlocks && !transient_shift)
672    {
673       celt_word32 sum[8]={1,1,1,1,1,1,1,1};
674       int m;
675       for (c=0;c<C;c++)
676       {
677          m=0;
678          do {
679             celt_word32 tmp=0;
680             for (i=m+c*N;i<c*N+NN;i+=M)
681                tmp += ABS32(X[i]);
682             sum[m++] += tmp;
683          } while (m<M);
684       }
685       m=0;
686 #ifdef FIXED_POINT
687       do {
688          if (SHR32(sum[m+1],3) > sum[m])
689          {
690             mdct_weight_shift=2;
691             mdct_weight_pos = m;
692          } else if (SHR32(sum[m+1],1) > sum[m] && mdct_weight_shift < 2)
693          {
694             mdct_weight_shift=1;
695             mdct_weight_pos = m;
696          }
697          m++;
698       } while (m<M-1);
699 #else
700       do {
701          if (sum[m+1] > 8*sum[m])
702          {
703             mdct_weight_shift=2;
704             mdct_weight_pos = m;
705          } else if (sum[m+1] > 2*sum[m] && mdct_weight_shift < 2)
706          {
707             mdct_weight_shift=1;
708             mdct_weight_pos = m;
709          }
710          m++;
711       } while (m<M-1);
712 #endif
713       if (mdct_weight_shift)
714          mdct_shape(st->mode, X, mdct_weight_pos+1, M, N, mdct_weight_shift, effEnd, C, 0, M);
715    }
716
717    ALLOC(tf_res, st->mode->nbEBands, int);
718    /* Needs to be before coarse energy quantization because otherwise the energy gets modified */
719    tf_select = tf_analysis(bandLogE, oldBandE, effEnd, C, isTransient, tf_res, nbAvailableBytes);
720    for (i=effEnd;i<st->end;i++)
721       tf_res[i] = tf_res[effEnd-1];
722
723    ALLOC(error, C*st->mode->nbEBands, celt_word16);
724    quant_coarse_energy(st->mode, st->start, st->end, effEnd, bandLogE,
725          oldBandE, nbCompressedBytes*8, st->mode->prob,
726          error, enc, C, LM, nbAvailableBytes, st->force_intra, &st->delayedIntra);
727
728    ec_enc_bit_prob(enc, shortBlocks!=0, 8192);
729
730    if (shortBlocks)
731    {
732       if (transient_shift)
733       {
734          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
735          ec_enc_uint(enc, transient_shift, 4);
736          ec_enc_uint(enc, transient_time_quant, max_time);
737       } else {
738          ec_enc_uint(enc, mdct_weight_shift, 4);
739          if (mdct_weight_shift && M!=2)
740             ec_enc_uint(enc, mdct_weight_pos, M-1);
741       }
742    }
743
744    tf_encode(st->start, st->end, isTransient, tf_res, nbAvailableBytes, LM, tf_select, enc);
745
746    if (!shortBlocks && !folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision, effEnd, C, M))
747       has_fold = 0;
748    ec_enc_bit_prob(enc, has_fold>>1, 8192);
749    ec_enc_bit_prob(enc, has_fold&1, (has_fold>>1) ? 32768 : 49152);
750
751    /* Variable bitrate */
752    if (st->vbr_rate_norm>0)
753    {
754      celt_word16 alpha;
755      celt_int32 delta;
756      /* The target rate in 16th bits per frame */
757      celt_int32 vbr_rate;
758      celt_int32 target;
759      celt_int32 vbr_bound, max_allowed;
760
761      vbr_rate = M*st->vbr_rate_norm;
762
763      /* Computes the max bit-rate allowed in VBR more to avoid busting the budget */
764      vbr_bound = vbr_rate;
765      max_allowed = (vbr_rate + vbr_bound - st->vbr_reservoir)>>(BITRES+3);
766      if (max_allowed < 4)
767         max_allowed = 4;
768      if (max_allowed < nbAvailableBytes)
769         nbAvailableBytes = max_allowed;
770      target=vbr_rate;
771
772      /* Shortblocks get a large boost in bitrate, but since they 
773         are uncommon long blocks are not greatly effected */
774      if (shortBlocks)
775        target*=2;
776      else if (M > 1)
777        target-=(target+14)/28;
778
779      /* The average energy is removed from the target and the actual 
780         energy added*/
781      target=target+st->vbr_offset-588+ec_enc_tell(enc, BITRES);
782
783      /* In VBR mode the frame size must not be reduced so much that it would result in the coarse energy busting its budget */
784      target=IMIN(nbAvailableBytes,target);
785      /* Make the adaptation coef (alpha) higher at the beginning */
786      if (st->vbr_count < 990)
787      {
788         st->vbr_count++;
789         alpha = celt_rcp(SHL32(EXTEND32(st->vbr_count+10),16));
790         /*printf ("%d %d\n", st->vbr_count+10, alpha);*/
791      } else
792         alpha = QCONST16(.001f,15);
793
794      /* By how much did we "miss" the target on that frame */
795      delta = (8<<BITRES)*(celt_int32)target - vbr_rate;
796      /* How many bits have we used in excess of what we're allowed */
797      st->vbr_reservoir += delta;
798      /*printf ("%d\n", st->vbr_reservoir);*/
799
800      /* Compute the offset we need to apply in order to reach the target */
801      st->vbr_drift += MULT16_32_Q15(alpha,delta-st->vbr_offset-st->vbr_drift);
802      st->vbr_offset = -st->vbr_drift;
803      /*printf ("%d\n", st->vbr_drift);*/
804
805      /* We could use any multiple of vbr_rate as bound (depending on the delay) */
806      if (st->vbr_reservoir < 0)
807      {
808         /* We're under the min value -- increase rate */
809         int adjust = 1-(st->vbr_reservoir-1)/(8<<BITRES);
810         st->vbr_reservoir += adjust*(8<<BITRES);
811         target += adjust;
812         /*printf ("+%d\n", adjust);*/
813      }
814      if (target < nbAvailableBytes)
815         nbAvailableBytes = target;
816      nbCompressedBytes = nbAvailableBytes + nbFilledBytes;
817
818      /* This moves the raw bits to take into account the new compressed size */
819      ec_byte_shrink(&buf, nbCompressedBytes);
820    }
821
822    /* Bit allocation */
823    ALLOC(fine_quant, st->mode->nbEBands, int);
824    ALLOC(pulses, st->mode->nbEBands, int);
825    ALLOC(offsets, st->mode->nbEBands, int);
826    ALLOC(fine_priority, st->mode->nbEBands, int);
827
828    for (i=0;i<st->mode->nbEBands;i++)
829       offsets[i] = 0;
830    bits = nbCompressedBytes*8 - ec_enc_tell(enc, 0) - 1;
831    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
832
833    quant_fine_energy(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, enc, C);
834
835 #ifdef MEASURE_NORM_MSE
836    float X0[3000];
837    float bandE0[60];
838    for (c=0;c<C;c++)
839       for (i=0;i<N;i++)
840          X0[i+c*N] = X[i+c*N];
841    for (i=0;i<C*st->mode->nbEBands;i++)
842       bandE0[i] = bandE[i];
843 #endif
844
845    /* Residual quantisation */
846    quant_all_bands(1, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, bandE, pulses, shortBlocks, has_fold, tf_res, resynth, nbCompressedBytes*8, enc, LM);
847
848    quant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
849
850    /* Re-synthesis of the coded audio if required */
851    if (resynth)
852    {
853       VARDECL(celt_sig, _out_mem);
854       celt_sig *out_mem[2];
855       celt_sig *overlap_mem[2];
856
857       log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
858
859 #ifdef MEASURE_NORM_MSE
860       measure_norm_mse(st->mode, X, X0, bandE, bandE0, M, N, C);
861 #endif
862
863       if (mdct_weight_shift)
864       {
865          mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
866       }
867
868       /* Synthesis */
869       denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
870
871       for (c=0;c<C;c++)
872          for (i=0;i<M*st->mode->eBands[st->start];i++)
873             freq[c*N+i] = 0;
874       for (c=0;c<C;c++)
875          for (i=M*st->mode->eBands[st->end];i<N;i++)
876             freq[c*N+i] = 0;
877
878       ALLOC(_out_mem, C*N, celt_sig);
879
880       for (c=0;c<C;c++)
881       {
882          overlap_mem[c] = _overlap_mem + c*st->overlap;
883          out_mem[c] = _out_mem+c*N;
884       }
885
886       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
887             transient_shift, out_mem, overlap_mem, C, LM);
888
889       /* De-emphasis and put everything back at the right place 
890          in the synthesis history */
891       if (optional_resynthesis != NULL) {
892          deemphasis(out_mem, optional_resynthesis, N, C, st->mode->preemph, st->preemph_memD);
893
894       }
895    }
896
897    /* If there's any room left (can only happen for very high rates),
898       fill it with zeros */
899    while (ec_enc_tell(enc,0) + 8 <= nbCompressedBytes*8)
900       ec_enc_bits(enc, 0, 8);
901    ec_enc_done(enc);
902    
903    RESTORE_STACK;
904    if (ec_enc_get_error(enc))
905       return CELT_CORRUPTED_DATA;
906    else
907       return nbCompressedBytes;
908 }
909
910 #ifdef FIXED_POINT
911 #ifndef DISABLE_FLOAT_API
912 int celt_encode_with_ec_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
913 {
914    int j, ret, C, N, LM, M;
915    VARDECL(celt_int16, in);
916    SAVE_STACK;
917
918    if (pcm==NULL)
919       return CELT_BAD_ARG;
920
921    for (LM=0;LM<4;LM++)
922       if (st->mode->shortMdctSize<<LM==frame_size)
923          break;
924    if (LM>=MAX_CONFIG_SIZES)
925       return CELT_BAD_ARG;
926    M=1<<LM;
927
928    C = CHANNELS(st->channels);
929    N = M*st->mode->shortMdctSize;
930    ALLOC(in, C*N, celt_int16);
931
932    for (j=0;j<C*N;j++)
933      in[j] = FLOAT2INT16(pcm[j]);
934
935    if (optional_resynthesis != NULL) {
936      ret=celt_encode_with_ec(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
937       for (j=0;j<C*N;j++)
938          optional_resynthesis[j]=in[j]*(1.f/32768.f);
939    } else {
940      ret=celt_encode_with_ec(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
941    }
942    RESTORE_STACK;
943    return ret;
944
945 }
946 #endif /*DISABLE_FLOAT_API*/
947 #else
948 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
949 {
950    int j, ret, C, N, LM, M;
951    VARDECL(celt_sig, in);
952    SAVE_STACK;
953
954    if (pcm==NULL)
955       return CELT_BAD_ARG;
956
957    for (LM=0;LM<4;LM++)
958       if (st->mode->shortMdctSize<<LM==frame_size)
959          break;
960    if (LM>=MAX_CONFIG_SIZES)
961       return CELT_BAD_ARG;
962    M=1<<LM;
963
964    C=CHANNELS(st->channels);
965    N=M*st->mode->shortMdctSize;
966    ALLOC(in, C*N, celt_sig);
967    for (j=0;j<C*N;j++) {
968      in[j] = SCALEOUT(pcm[j]);
969    }
970
971    if (optional_resynthesis != NULL) {
972       ret = celt_encode_with_ec_float(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
973       for (j=0;j<C*N;j++)
974          optional_resynthesis[j] = FLOAT2INT16(in[j]);
975    } else {
976       ret = celt_encode_with_ec_float(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
977    }
978    RESTORE_STACK;
979    return ret;
980 }
981 #endif
982
983 int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
984 {
985    return celt_encode_with_ec(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
986 }
987
988 #ifndef DISABLE_FLOAT_API
989 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
990 {
991    return celt_encode_with_ec_float(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
992 }
993 #endif /* DISABLE_FLOAT_API */
994
995 int celt_encode_resynthesis(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
996 {
997    return celt_encode_with_ec(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
998 }
999
1000 #ifndef DISABLE_FLOAT_API
1001 int celt_encode_resynthesis_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1002 {
1003    return celt_encode_with_ec_float(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1004 }
1005 #endif /* DISABLE_FLOAT_API */
1006
1007
1008 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
1009 {
1010    va_list ap;
1011    
1012    va_start(ap, request);
1013    switch (request)
1014    {
1015       case CELT_GET_MODE_REQUEST:
1016       {
1017          const CELTMode ** value = va_arg(ap, const CELTMode**);
1018          if (value==0)
1019             goto bad_arg;
1020          *value=st->mode;
1021       }
1022       break;
1023       case CELT_SET_COMPLEXITY_REQUEST:
1024       {
1025          int value = va_arg(ap, celt_int32);
1026          if (value<0 || value>10)
1027             goto bad_arg;
1028       }
1029       break;
1030       case CELT_SET_START_BAND_REQUEST:
1031       {
1032          celt_int32 value = va_arg(ap, celt_int32);
1033          if (value<0 || value>=st->mode->nbEBands)
1034             goto bad_arg;
1035          st->start = value;
1036       }
1037       break;
1038       case CELT_SET_END_BAND_REQUEST:
1039       {
1040          celt_int32 value = va_arg(ap, celt_int32);
1041          if (value<0 || value>=st->mode->nbEBands)
1042             goto bad_arg;
1043          st->end = value;
1044       }
1045       break;
1046       case CELT_SET_PREDICTION_REQUEST:
1047       {
1048          int value = va_arg(ap, celt_int32);
1049          if (value<0 || value>2)
1050             goto bad_arg;
1051          if (value==0)
1052          {
1053             st->force_intra   = 1;
1054          } else if (value==1) {
1055             st->force_intra   = 0;
1056          } else {
1057             st->force_intra   = 0;
1058          }   
1059       }
1060       break;
1061       case CELT_SET_VBR_RATE_REQUEST:
1062       {
1063          celt_int32 value = va_arg(ap, celt_int32);
1064          int frame_rate;
1065          int N = st->mode->shortMdctSize;
1066          if (value<0)
1067             goto bad_arg;
1068          if (value>3072000)
1069             value = 3072000;
1070          frame_rate = ((st->mode->Fs<<3)+(N>>1))/N;
1071          st->vbr_rate_norm = ((value<<(BITRES+3))+(frame_rate>>1))/frame_rate;
1072       }
1073       break;
1074       case CELT_RESET_STATE:
1075       {
1076          celt_sig *overlap_mem;
1077          celt_word16 *oldBandE;
1078          const CELTMode *mode = st->mode;
1079          int C = st->channels;
1080          overlap_mem = st->in_mem+C*(st->overlap);
1081          oldBandE = (celt_word16*)(st->in_mem+2*C*(st->overlap));
1082
1083          CELT_MEMSET(st->in_mem, 0, st->overlap*C);
1084          CELT_MEMSET(overlap_mem, 0, st->overlap*C);
1085
1086          CELT_MEMSET(oldBandE, 0, C*mode->nbEBands);
1087
1088          CELT_MEMSET(st->preemph_memE, 0, C);
1089          CELT_MEMSET(st->preemph_memD, 0, C);
1090          st->delayedIntra = 1;
1091
1092          st->fold_decision = 1;
1093          st->tonal_average = QCONST16(1.f,8);
1094          st->gain_prod = 0;
1095          st->vbr_reservoir = 0;
1096          st->vbr_drift = 0;
1097          st->vbr_offset = 0;
1098          st->vbr_count = 0;
1099          st->frame_max = 0;
1100       }
1101       break;
1102       default:
1103          goto bad_request;
1104    }
1105    va_end(ap);
1106    return CELT_OK;
1107 bad_arg:
1108    va_end(ap);
1109    return CELT_BAD_ARG;
1110 bad_request:
1111    va_end(ap);
1112    return CELT_UNIMPLEMENTED;
1113 }
1114
1115 /**********************************************************************/
1116 /*                                                                    */
1117 /*                             DECODER                                */
1118 /*                                                                    */
1119 /**********************************************************************/
1120 #define DECODE_BUFFER_SIZE 2048
1121
1122 #define DECODERVALID   0x4c434454
1123 #define DECODERPARTIAL 0x5444434c
1124 #define DECODERFREED   0x4c004400
1125
1126 /** Decoder state 
1127  @brief Decoder state
1128  */
1129 struct CELTDecoder {
1130    const CELTMode *mode;
1131    int overlap;
1132    int channels;
1133
1134    int start, end;
1135    int last_pitch_index;
1136    int loss_count;
1137
1138    celt_sig preemph_memD[2];
1139    
1140    celt_sig _decode_mem[1];
1141 };
1142
1143 int celt_decoder_get_size(const CELTMode *mode, int channels)
1144 {
1145    int size = sizeof(struct CELTDecoder)
1146             + (channels*(DECODE_BUFFER_SIZE+mode->overlap)-1)*sizeof(celt_sig)
1147             + channels*LPC_ORDER*sizeof(celt_word16)
1148             + channels*mode->nbEBands*sizeof(celt_word16);
1149    return size;
1150 }
1151
1152 CELTDecoder *celt_decoder_create(const CELTMode *mode, int channels, int *error)
1153 {
1154    int C;
1155    CELTDecoder *st;
1156
1157    if (channels < 0 || channels > 2)
1158    {
1159       celt_warning("Only mono and stereo supported");
1160       if (error)
1161          *error = CELT_BAD_ARG;
1162       return NULL;
1163    }
1164
1165    C = CHANNELS(channels);
1166    st = celt_alloc(celt_decoder_get_size(mode, channels));
1167
1168    if (st==NULL)
1169    {
1170       if (error)
1171          *error = CELT_ALLOC_FAIL;
1172       return NULL;
1173    }
1174
1175    st->mode = mode;
1176    st->overlap = mode->overlap;
1177    st->channels = channels;
1178
1179    st->start = 0;
1180    st->end = st->mode->effEBands;
1181
1182    st->loss_count = 0;
1183
1184    if (error)
1185       *error = CELT_OK;
1186    return st;
1187 }
1188
1189 void celt_decoder_destroy(CELTDecoder *st)
1190 {
1191    celt_free(st);
1192 }
1193
1194 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16 * restrict pcm, int N, int LM)
1195 {
1196    int c;
1197    int pitch_index;
1198    int overlap = st->mode->overlap;
1199    celt_word16 fade = Q15ONE;
1200    int i, len;
1201    const int C = CHANNELS(st->channels);
1202    int offset;
1203    celt_sig *out_mem[2];
1204    celt_sig *decode_mem[2];
1205    celt_sig *overlap_mem[2];
1206    celt_word16 *lpc;
1207    celt_word16 *oldBandE;
1208    SAVE_STACK;
1209    
1210    for (c=0;c<C;c++)
1211    {
1212       decode_mem[c] = st->_decode_mem + c*(DECODE_BUFFER_SIZE+st->overlap);
1213       out_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE-MAX_PERIOD;
1214       overlap_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE;
1215    }
1216    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1217    oldBandE = lpc+C*LPC_ORDER;
1218
1219    len = N+st->mode->overlap;
1220    
1221    if (st->loss_count == 0)
1222    {
1223       celt_word16 pitch_buf[MAX_PERIOD>>1];
1224       celt_word32 tmp=0;
1225       celt_word32 mem0[2]={0,0};
1226       celt_word16 mem1[2]={0,0};
1227       int len2 = len;
1228       /* FIXME: This is a kludge */
1229       if (len2>MAX_PERIOD>>1)
1230          len2 = MAX_PERIOD>>1;
1231       pitch_downsample(out_mem, pitch_buf, MAX_PERIOD, MAX_PERIOD,
1232                        C, mem0, mem1);
1233       pitch_search(st->mode, pitch_buf+((MAX_PERIOD-len2)>>1), pitch_buf, len2,
1234                    MAX_PERIOD-len2-100, &pitch_index, &tmp, 1<<LM);
1235       pitch_index = MAX_PERIOD-len2-pitch_index;
1236       st->last_pitch_index = pitch_index;
1237    } else {
1238       pitch_index = st->last_pitch_index;
1239       if (st->loss_count < 5)
1240          fade = QCONST16(.8f,15);
1241       else
1242          fade = 0;
1243    }
1244
1245    for (c=0;c<C;c++)
1246    {
1247       /* FIXME: This is more memory than necessary */
1248       celt_word32 e[2*MAX_PERIOD];
1249       celt_word16 exc[2*MAX_PERIOD];
1250       celt_word32 ac[LPC_ORDER+1];
1251       celt_word16 decay = 1;
1252       celt_word32 S1=0;
1253       celt_word16 mem[LPC_ORDER]={0};
1254
1255       offset = MAX_PERIOD-pitch_index;
1256       for (i=0;i<MAX_PERIOD;i++)
1257          exc[i] = ROUND16(out_mem[c][i], SIG_SHIFT);
1258
1259       if (st->loss_count == 0)
1260       {
1261          _celt_autocorr(exc, ac, st->mode->window, st->mode->overlap,
1262                         LPC_ORDER, MAX_PERIOD);
1263
1264          /* Noise floor -40 dB */
1265 #ifdef FIXED_POINT
1266          ac[0] += SHR32(ac[0],13);
1267 #else
1268          ac[0] *= 1.0001f;
1269 #endif
1270          /* Lag windowing */
1271          for (i=1;i<=LPC_ORDER;i++)
1272          {
1273             /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
1274 #ifdef FIXED_POINT
1275             ac[i] -= MULT16_32_Q15(2*i*i, ac[i]);
1276 #else
1277             ac[i] -= ac[i]*(.008f*i)*(.008f*i);
1278 #endif
1279          }
1280
1281          _celt_lpc(lpc+c*LPC_ORDER, ac, LPC_ORDER);
1282       }
1283       fir(exc, lpc+c*LPC_ORDER, exc, MAX_PERIOD, LPC_ORDER, mem);
1284       /*for (i=0;i<MAX_PERIOD;i++)printf("%d ", exc[i]); printf("\n");*/
1285       /* Check if the waveform is decaying (and if so how fast) */
1286       {
1287          celt_word32 E1=1, E2=1;
1288          int period;
1289          if (pitch_index <= MAX_PERIOD/2)
1290             period = pitch_index;
1291          else
1292             period = MAX_PERIOD/2;
1293          for (i=0;i<period;i++)
1294          {
1295             E1 += SHR32(MULT16_16(exc[MAX_PERIOD-period+i],exc[MAX_PERIOD-period+i]),8);
1296             E2 += SHR32(MULT16_16(exc[MAX_PERIOD-2*period+i],exc[MAX_PERIOD-2*period+i]),8);
1297          }
1298          if (E1 > E2)
1299             E1 = E2;
1300          decay = celt_sqrt(frac_div32(SHR(E1,1),E2));
1301       }
1302
1303       /* Copy excitation, taking decay into account */
1304       for (i=0;i<len+st->mode->overlap;i++)
1305       {
1306          if (offset+i >= MAX_PERIOD)
1307          {
1308             offset -= pitch_index;
1309             decay = MULT16_16_Q15(decay, decay);
1310          }
1311          e[i] = SHL32(EXTEND32(MULT16_16_Q15(decay, exc[offset+i])), SIG_SHIFT);
1312          S1 += SHR32(MULT16_16(out_mem[c][offset+i],out_mem[c][offset+i]),8);
1313       }
1314
1315       iir(e, lpc+c*LPC_ORDER, e, len+st->mode->overlap, LPC_ORDER, mem);
1316
1317       {
1318          celt_word32 S2=0;
1319          for (i=0;i<len+overlap;i++)
1320             S2 += SHR32(MULT16_16(e[i],e[i]),8);
1321          /* This checks for an "explosion" in the synthesis */
1322 #ifdef FIXED_POINT
1323          if (!(S1 > SHR32(S2,2)))
1324 #else
1325          /* Float test is written this way to catch NaNs at the same time */
1326          if (!(S1 > 0.2f*S2))
1327 #endif
1328          {
1329             for (i=0;i<len+overlap;i++)
1330                e[i] = 0;
1331          } else if (S1 < S2)
1332          {
1333             celt_word16 ratio = celt_sqrt(frac_div32(SHR32(S1,1)+1,S2+1));
1334             for (i=0;i<len+overlap;i++)
1335                e[i] = MULT16_16_Q15(ratio, e[i]);
1336          }
1337       }
1338
1339       for (i=0;i<MAX_PERIOD+st->mode->overlap-N;i++)
1340          out_mem[c][i] = out_mem[c][N+i];
1341
1342       /* Apply TDAC to the concealed audio so that it blends with the
1343          previous and next frames */
1344       for (i=0;i<overlap/2;i++)
1345       {
1346          celt_word32 tmp1, tmp2;
1347          tmp1 = MULT16_32_Q15(st->mode->window[i          ], e[i          ]) -
1348                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[overlap-i-1]);
1349          tmp2 = MULT16_32_Q15(st->mode->window[i],           e[N+overlap-1-i]) +
1350                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[N+i          ]);
1351          tmp1 = MULT16_32_Q15(fade, tmp1);
1352          tmp2 = MULT16_32_Q15(fade, tmp2);
1353          out_mem[c][MAX_PERIOD+i] = MULT16_32_Q15(st->mode->window[overlap-i-1], tmp2);
1354          out_mem[c][MAX_PERIOD+overlap-i-1] = MULT16_32_Q15(st->mode->window[i], tmp2);
1355          out_mem[c][MAX_PERIOD-N+i] += MULT16_32_Q15(st->mode->window[i], tmp1);
1356          out_mem[c][MAX_PERIOD-N+overlap-i-1] -= MULT16_32_Q15(st->mode->window[overlap-i-1], tmp1);
1357       }
1358       for (i=0;i<N-overlap;i++)
1359          out_mem[c][MAX_PERIOD-N+overlap+i] = MULT16_32_Q15(fade, e[overlap+i]);
1360    }
1361
1362    deemphasis(out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1363    
1364    st->loss_count++;
1365
1366    RESTORE_STACK;
1367 }
1368
1369 #ifdef FIXED_POINT
1370 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1371 {
1372 #else
1373 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig * restrict pcm, int frame_size, ec_dec *dec)
1374 {
1375 #endif
1376    int c, i, N;
1377    int has_fold;
1378    int bits;
1379    ec_dec _dec;
1380    ec_byte_buffer buf;
1381    VARDECL(celt_sig, freq);
1382    VARDECL(celt_norm, X);
1383    VARDECL(celt_ener, bandE);
1384    VARDECL(int, fine_quant);
1385    VARDECL(int, pulses);
1386    VARDECL(int, offsets);
1387    VARDECL(int, fine_priority);
1388    VARDECL(int, tf_res);
1389    celt_sig *out_mem[2];
1390    celt_sig *decode_mem[2];
1391    celt_sig *overlap_mem[2];
1392    celt_sig *out_syn[2];
1393    celt_word16 *lpc;
1394    celt_word16 *oldBandE;
1395
1396    int shortBlocks;
1397    int isTransient;
1398    int intra_ener;
1399    int transient_time;
1400    int transient_shift;
1401    int mdct_weight_shift=0;
1402    const int C = CHANNELS(st->channels);
1403    int mdct_weight_pos=0;
1404    int LM, M;
1405    int nbFilledBytes, nbAvailableBytes;
1406    int effEnd;
1407    SAVE_STACK;
1408
1409    if (pcm==NULL)
1410       return CELT_BAD_ARG;
1411
1412    for (LM=0;LM<4;LM++)
1413       if (st->mode->shortMdctSize<<LM==frame_size)
1414          break;
1415    if (LM>=MAX_CONFIG_SIZES)
1416       return CELT_BAD_ARG;
1417    M=1<<LM;
1418
1419    for (c=0;c<C;c++)
1420    {
1421       decode_mem[c] = st->_decode_mem + c*(DECODE_BUFFER_SIZE+st->overlap);
1422       out_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE-MAX_PERIOD;
1423       overlap_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE;
1424    }
1425    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1426    oldBandE = lpc+C*LPC_ORDER;
1427
1428    N = M*st->mode->shortMdctSize;
1429
1430    effEnd = st->end;
1431    if (effEnd > st->mode->effEBands)
1432       effEnd = st->mode->effEBands;
1433
1434    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
1435    ALLOC(X, C*N, celt_norm);   /**< Interleaved normalised MDCTs */
1436    ALLOC(bandE, st->mode->nbEBands*C, celt_ener);
1437    for (c=0;c<C;c++)
1438       for (i=0;i<M*st->mode->eBands[st->start];i++)
1439          X[c*N+i] = 0;
1440    for (c=0;c<C;c++)
1441       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1442          X[c*N+i] = 0;
1443
1444    if (data == NULL)
1445    {
1446       celt_decode_lost(st, pcm, N, LM);
1447       RESTORE_STACK;
1448       return CELT_OK;
1449    }
1450    if (len<0) {
1451      RESTORE_STACK;
1452      return CELT_BAD_ARG;
1453    }
1454    
1455    if (dec == NULL)
1456    {
1457       ec_byte_readinit(&buf,(unsigned char*)data,len);
1458       ec_dec_init(&_dec,&buf);
1459       dec = &_dec;
1460       nbFilledBytes = 0;
1461    } else {
1462       nbFilledBytes = (ec_dec_tell(dec, 0)+4)>>3;
1463    }
1464    nbAvailableBytes = len-nbFilledBytes;
1465
1466    /* Decode the global flags (first symbols in the stream) */
1467    intra_ener = ec_dec_bit_prob(dec, 8192);
1468    /* Get band energies */
1469    unquant_coarse_energy(st->mode, st->start, st->end, bandE, oldBandE,
1470          intra_ener, st->mode->prob, dec, C, LM);
1471
1472    isTransient = ec_dec_bit_prob(dec, 8192);
1473
1474    if (isTransient)
1475       shortBlocks = M;
1476    else
1477       shortBlocks = 0;
1478
1479    if (isTransient)
1480    {
1481       transient_shift = ec_dec_uint(dec, 4);
1482       if (transient_shift == 3)
1483       {
1484          int transient_time_quant;
1485          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
1486          transient_time_quant = ec_dec_uint(dec, max_time);
1487          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
1488       } else {
1489          mdct_weight_shift = transient_shift;
1490          if (mdct_weight_shift && M>2)
1491             mdct_weight_pos = ec_dec_uint(dec, M-1);
1492          transient_shift = 0;
1493          transient_time = 0;
1494       }
1495    } else {
1496       transient_time = -1;
1497       transient_shift = 0;
1498    }
1499
1500    ALLOC(tf_res, st->mode->nbEBands, int);
1501    tf_decode(st->start, st->end, C, isTransient, tf_res, nbAvailableBytes, LM, dec);
1502
1503    has_fold = ec_dec_bit_prob(dec, 8192)<<1;
1504    has_fold |= ec_dec_bit_prob(dec, (has_fold>>1) ? 32768 : 49152);
1505
1506    ALLOC(pulses, st->mode->nbEBands, int);
1507    ALLOC(offsets, st->mode->nbEBands, int);
1508    ALLOC(fine_priority, st->mode->nbEBands, int);
1509
1510    for (i=0;i<st->mode->nbEBands;i++)
1511       offsets[i] = 0;
1512
1513    bits = len*8 - ec_dec_tell(dec, 0) - 1;
1514    ALLOC(fine_quant, st->mode->nbEBands, int);
1515    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
1516    /*bits = ec_dec_tell(dec, 0);
1517    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(dec, 0)-bits))/C);*/
1518    
1519    unquant_fine_energy(st->mode, st->start, st->end, bandE, oldBandE, fine_quant, dec, C);
1520
1521    /* Decode fixed codebook */
1522    quant_all_bands(0, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, NULL, pulses, shortBlocks, has_fold, tf_res, 1, len*8, dec, LM);
1523
1524    unquant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE,
1525          fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
1526
1527    log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
1528
1529    if (mdct_weight_shift)
1530    {
1531       mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
1532    }
1533
1534    /* Synthesis */
1535    denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
1536
1537    CELT_MOVE(decode_mem[0], decode_mem[0]+N, DECODE_BUFFER_SIZE-N);
1538    if (C==2)
1539       CELT_MOVE(decode_mem[1], decode_mem[1]+N, DECODE_BUFFER_SIZE-N);
1540
1541    for (c=0;c<C;c++)
1542       for (i=0;i<M*st->mode->eBands[st->start];i++)
1543          freq[c*N+i] = 0;
1544    for (c=0;c<C;c++)
1545       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1546          freq[c*N+i] = 0;
1547
1548    out_syn[0] = out_mem[0]+MAX_PERIOD-N;
1549    if (C==2)
1550       out_syn[1] = out_mem[1]+MAX_PERIOD-N;
1551
1552    /* Compute inverse MDCTs */
1553    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
1554          transient_shift, out_syn, overlap_mem, C, LM);
1555
1556    deemphasis(out_syn, pcm, N, C, st->mode->preemph, st->preemph_memD);
1557    st->loss_count = 0;
1558    RESTORE_STACK;
1559    if (ec_dec_get_error(dec))
1560       return CELT_CORRUPTED_DATA;
1561    else
1562       return CELT_OK;
1563 }
1564
1565 #ifdef FIXED_POINT
1566 #ifndef DISABLE_FLOAT_API
1567 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size, ec_dec *dec)
1568 {
1569    int j, ret, C, N, LM, M;
1570    VARDECL(celt_int16, out);
1571    SAVE_STACK;
1572
1573    if (pcm==NULL)
1574       return CELT_BAD_ARG;
1575
1576    for (LM=0;LM<4;LM++)
1577       if (st->mode->shortMdctSize<<LM==frame_size)
1578          break;
1579    if (LM>=MAX_CONFIG_SIZES)
1580       return CELT_BAD_ARG;
1581    M=1<<LM;
1582
1583    C = CHANNELS(st->channels);
1584    N = M*st->mode->shortMdctSize;
1585    
1586    ALLOC(out, C*N, celt_int16);
1587    ret=celt_decode_with_ec(st, data, len, out, frame_size, dec);
1588    if (ret==0)
1589       for (j=0;j<C*N;j++)
1590          pcm[j]=out[j]*(1.f/32768.f);
1591      
1592    RESTORE_STACK;
1593    return ret;
1594 }
1595 #endif /*DISABLE_FLOAT_API*/
1596 #else
1597 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1598 {
1599    int j, ret, C, N, LM, M;
1600    VARDECL(celt_sig, out);
1601    SAVE_STACK;
1602
1603    if (pcm==NULL)
1604       return CELT_BAD_ARG;
1605
1606    for (LM=0;LM<4;LM++)
1607       if (st->mode->shortMdctSize<<LM==frame_size)
1608          break;
1609    if (LM>=MAX_CONFIG_SIZES)
1610       return CELT_BAD_ARG;
1611    M=1<<LM;
1612
1613    C = CHANNELS(st->channels);
1614    N = M*st->mode->shortMdctSize;
1615    ALLOC(out, C*N, celt_sig);
1616
1617    ret=celt_decode_with_ec_float(st, data, len, out, frame_size, dec);
1618
1619    if (ret==0)
1620       for (j=0;j<C*N;j++)
1621          pcm[j] = FLOAT2INT16 (out[j]);
1622    
1623    RESTORE_STACK;
1624    return ret;
1625 }
1626 #endif
1627
1628 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size)
1629 {
1630    return celt_decode_with_ec(st, data, len, pcm, frame_size, NULL);
1631 }
1632
1633 #ifndef DISABLE_FLOAT_API
1634 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size)
1635 {
1636    return celt_decode_with_ec_float(st, data, len, pcm, frame_size, NULL);
1637 }
1638 #endif /* DISABLE_FLOAT_API */
1639
1640 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
1641 {
1642    va_list ap;
1643
1644    va_start(ap, request);
1645    switch (request)
1646    {
1647       case CELT_GET_MODE_REQUEST:
1648       {
1649          const CELTMode ** value = va_arg(ap, const CELTMode**);
1650          if (value==0)
1651             goto bad_arg;
1652          *value=st->mode;
1653       }
1654       break;
1655       case CELT_SET_START_BAND_REQUEST:
1656       {
1657          celt_int32 value = va_arg(ap, celt_int32);
1658          if (value<0 || value>=st->mode->nbEBands)
1659             goto bad_arg;
1660          st->start = value;
1661       }
1662       break;
1663       case CELT_SET_END_BAND_REQUEST:
1664       {
1665          celt_int32 value = va_arg(ap, celt_int32);
1666          if (value<0 || value>=st->mode->nbEBands)
1667             goto bad_arg;
1668          st->end = value;
1669       }
1670       break;
1671       case CELT_RESET_STATE:
1672       {
1673          const CELTMode *mode = st->mode;
1674          int C = st->channels;
1675          celt_word16 *lpc;
1676          celt_word16 *oldBandE;
1677
1678          lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1679          oldBandE = lpc+C*LPC_ORDER;
1680
1681          CELT_MEMSET(st->_decode_mem, 0, (DECODE_BUFFER_SIZE+st->overlap)*C);
1682          CELT_MEMSET(oldBandE, 0, C*mode->nbEBands);
1683
1684          CELT_MEMSET(st->preemph_memD, 0, C);
1685
1686          st->loss_count = 0;
1687
1688          CELT_MEMSET(lpc, 0, C*LPC_ORDER);
1689       }
1690       break;
1691       default:
1692          goto bad_request;
1693    }
1694    va_end(ap);
1695    return CELT_OK;
1696 bad_arg:
1697    va_end(ap);
1698    return CELT_BAD_ARG;
1699 bad_request:
1700       va_end(ap);
1701   return CELT_UNIMPLEMENTED;
1702 }
1703
1704 const char *celt_strerror(int error)
1705 {
1706    static const char *error_strings[8] = {
1707       "success",
1708       "invalid argument",
1709       "invalid mode",
1710       "internal error",
1711       "corrupted stream",
1712       "request not implemented",
1713       "invalid state",
1714       "memory allocation failed"
1715    };
1716    if (error > 0 || error < -7)
1717       return "unknown error";
1718    else 
1719       return error_strings[-error];
1720 }
1721