Simplifies the implementation of RESET by placing all the data that needs
[opus.git] / libcelt / celt.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2010 Xiph.Org Foundation
3    Copyright (c) 2008 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #define CELT_C
39
40 #include "os_support.h"
41 #include "mdct.h"
42 #include <math.h>
43 #include "celt.h"
44 #include "pitch.h"
45 #include "bands.h"
46 #include "modes.h"
47 #include "entcode.h"
48 #include "quant_bands.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54 #include "plc.h"
55
56 #ifdef FIXED_POINT
57 static const celt_word16 transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135f, 0.0337639f, 0.0748914f, 0.1304955f,
63    0.1986827f, 0.2771308f, 0.3631685f, 0.4538658f,
64    0.5461342f, 0.6368315f, 0.7228692f, 0.8013173f,
65    0.8695045f, 0.9251086f, 0.9662361f, 0.9914865f};
66 #endif
67
68 /** Encoder state 
69  @brief Encoder state
70  */
71 struct CELTEncoder {
72    const CELTMode *mode;     /**< Mode used by the encoder */
73    int overlap;
74    int channels;
75    
76    int force_intra;
77    int start, end;
78
79    celt_int32 vbr_rate_norm; /* Target number of 16th bits per frame */
80
81    /* Everything beyond this point gets cleared on a reset */
82 #define ENCODER_RESET_START frame_max
83
84    celt_word32 frame_max;
85    int fold_decision;
86    int delayedIntra;
87    celt_word16 tonal_average;
88
89    /* VBR-related parameters */
90    celt_int32 vbr_reservoir;
91    celt_int32 vbr_drift;
92    celt_int32 vbr_offset;
93    celt_int32 vbr_count;
94
95    celt_word32 preemph_memE[2];
96    celt_word32 preemph_memD[2];
97
98    celt_sig in_mem[1]; /* Size = channels*mode->overlap */
99    /* celt_sig overlap_mem[],  Size = channels*mode->overlap */
100    /* celt_word16 oldEBands[], Size = channels*mode->nbEBands */
101 };
102
103 int celt_encoder_get_size(const CELTMode *mode, int channels)
104 {
105    int size = sizeof(struct CELTEncoder)
106          + (2*channels*mode->overlap-1)*sizeof(celt_sig)
107          + channels*mode->nbEBands*sizeof(celt_word16);
108    return size;
109 }
110
111 CELTEncoder *celt_encoder_create(const CELTMode *mode, int channels, int *error)
112 {
113    return celt_encoder_init(
114          (CELTEncoder *)celt_alloc(celt_encoder_get_size(mode, channels)),
115          mode, channels, error);
116 }
117
118 CELTEncoder *celt_encoder_init(CELTEncoder *st, const CELTMode *mode, int channels, int *error)
119 {
120    if (channels < 0 || channels > 2)
121    {
122       celt_warning("Only mono and stereo supported");
123       if (error)
124          *error = CELT_BAD_ARG;
125       return NULL;
126    }
127
128    CELT_MEMSET((char*)st, 0, celt_encoder_get_size(mode, channels));
129    
130    if (st==NULL)
131    {
132       if (error)
133          *error = CELT_ALLOC_FAIL;
134       return NULL;
135    }
136    st->mode = mode;
137    st->overlap = mode->overlap;
138    st->channels = channels;
139
140    st->start = 0;
141    st->end = st->mode->effEBands;
142
143    st->vbr_rate_norm = 0;
144    st->force_intra  = 0;
145    st->delayedIntra = 1;
146    st->tonal_average = QCONST16(1.f,8);
147    st->fold_decision = 1;
148
149    if (error)
150       *error = CELT_OK;
151    return st;
152 }
153
154 void celt_encoder_destroy(CELTEncoder *st)
155 {
156    celt_free(st);
157 }
158
159 static inline celt_int16 FLOAT2INT16(float x)
160 {
161    x = x*CELT_SIG_SCALE;
162    x = MAX32(x, -32768);
163    x = MIN32(x, 32767);
164    return (celt_int16)float2int(x);
165 }
166
167 static inline celt_word16 SIG2WORD16(celt_sig x)
168 {
169 #ifdef FIXED_POINT
170    x = PSHR32(x, SIG_SHIFT);
171    x = MAX32(x, -32768);
172    x = MIN32(x, 32767);
173    return EXTRACT16(x);
174 #else
175    return (celt_word16)x;
176 #endif
177 }
178
179 static int transient_analysis(const celt_word32 * restrict in, int len, int C,
180                               int *transient_time, int *transient_shift,
181                               celt_word32 *frame_max, int overlap)
182 {
183    int i, n;
184    celt_word32 ratio;
185    celt_word32 threshold;
186    VARDECL(celt_word32, begin);
187    SAVE_STACK;
188    ALLOC(begin, len+1, celt_word32);
189    begin[0] = 0;
190    if (C==1)
191    {
192       for (i=0;i<len;i++)
193          begin[i+1] = MAX32(begin[i], ABS32(in[i]));
194    } else {
195       for (i=0;i<len;i++)
196          begin[i+1] = MAX32(begin[i], MAX32(ABS32(in[C*i]),
197                                             ABS32(in[C*i+1])));
198    }
199    n = -1;
200
201    threshold = MULT16_32_Q15(QCONST16(.4f,15),begin[len]);
202    /* If the following condition isn't met, there's just no way
203       we'll have a transient*/
204    if (*frame_max < threshold)
205    {
206       /* It's likely we have a transient, now find it */
207       for (i=8;i<len-8;i++)
208       {
209          if (begin[i+1] < threshold)
210             n=i;
211       }
212    }
213    if (n<32)
214    {
215       n = -1;
216       ratio = 0;
217    } else {
218       ratio = DIV32(begin[len],1+MAX32(*frame_max, begin[n-16]));
219    }
220
221    if (ratio > 45)
222       *transient_shift = 3;
223    else
224       *transient_shift = 0;
225    
226    *transient_time = n;
227    *frame_max = begin[len-overlap];
228
229    RESTORE_STACK;
230    return ratio > 0;
231 }
232
233 /** Apply window and compute the MDCT for all sub-frames and 
234     all channels in a frame */
235 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * restrict in, celt_sig * restrict out, int _C, int LM)
236 {
237    const int C = CHANNELS(_C);
238    if (C==1 && !shortBlocks)
239    {
240       const int overlap = OVERLAP(mode);
241       clt_mdct_forward(&mode->mdct, in, out, mode->window, overlap, mode->maxLM-LM);
242    } else {
243       const int overlap = OVERLAP(mode);
244       int N = mode->shortMdctSize<<LM;
245       int B = 1;
246       int b, c;
247       VARDECL(celt_word32, x);
248       VARDECL(celt_word32, tmp);
249       SAVE_STACK;
250       if (shortBlocks)
251       {
252          /*lookup = &mode->mdct[0];*/
253          N = mode->shortMdctSize;
254          B = shortBlocks;
255       }
256       ALLOC(x, N+overlap, celt_word32);
257       ALLOC(tmp, N, celt_word32);
258       for (c=0;c<C;c++)
259       {
260          for (b=0;b<B;b++)
261          {
262             int j;
263             for (j=0;j<N+overlap;j++)
264                x[j] = in[C*(b*N+j)+c];
265             clt_mdct_forward(&mode->mdct, x, tmp, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
266             /* Interleaving the sub-frames */
267             for (j=0;j<N;j++)
268                out[(j*B+b)+c*N*B] = tmp[j];
269          }
270       }
271       RESTORE_STACK;
272    }
273 }
274
275 /** Compute the IMDCT and apply window for all sub-frames and 
276     all channels in a frame */
277 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X,
278       int transient_time, int transient_shift, celt_sig * restrict out_mem[],
279       celt_sig * restrict overlap_mem[], int _C, int LM)
280 {
281    int c;
282    const int C = CHANNELS(_C);
283    const int N = mode->shortMdctSize<<LM;
284    const int overlap = OVERLAP(mode);
285    for (c=0;c<C;c++)
286    {
287       int j;
288          VARDECL(celt_word32, x);
289          VARDECL(celt_word32, tmp);
290          int b;
291          int N2 = N;
292          int B = 1;
293          SAVE_STACK;
294          
295          ALLOC(x, N+overlap, celt_word32);
296          ALLOC(tmp, N, celt_word32);
297
298          if (shortBlocks)
299          {
300             N2 = mode->shortMdctSize;
301             B = shortBlocks;
302          }
303          /* Prevents problems from the imdct doing the overlap-add */
304          CELT_MEMSET(x, 0, overlap);
305
306          for (b=0;b<B;b++)
307          {
308             /* De-interleaving the sub-frames */
309             for (j=0;j<N2;j++)
310                tmp[j] = X[(j*B+b)+c*N2*B];
311             clt_mdct_backward(&mode->mdct, tmp, x+N2*b, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
312          }
313
314          if (transient_shift > 0)
315          {
316 #ifdef FIXED_POINT
317             for (j=0;j<16;j++)
318                x[transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[transient_time+j-16],transient_shift));
319             for (j=transient_time;j<N+overlap;j++)
320                x[j] = SHL32(x[j], transient_shift);
321 #else
322             for (j=0;j<16;j++)
323                x[transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
324             for (j=transient_time;j<N+overlap;j++)
325                x[j] *= 1<<transient_shift;
326 #endif
327          }
328          for (j=0;j<overlap;j++)
329             out_mem[c][j] = x[j] + overlap_mem[c][j];
330          for (;j<N;j++)
331             out_mem[c][j] = x[j];
332          for (j=0;j<overlap;j++)
333             overlap_mem[c][j] = x[N+j];
334          RESTORE_STACK;
335    }
336 }
337
338 static void deemphasis(celt_sig *in[], celt_word16 *pcm, int N, int _C, const celt_word16 *coef, celt_sig *mem)
339 {
340    const int C = CHANNELS(_C);
341    int c;
342    for (c=0;c<C;c++)
343    {
344       int j;
345       celt_sig * restrict x;
346       celt_word16  * restrict y;
347       celt_sig m = mem[c];
348       x =in[c];
349       y = pcm+c;
350       for (j=0;j<N;j++)
351       {
352          celt_sig tmp = *x + m;
353          m = MULT16_32_Q15(coef[0], tmp)
354            - MULT16_32_Q15(coef[1], *x);
355          tmp = SHL32(MULT16_32_Q15(coef[3], tmp), 2);
356          *y = SCALEOUT(SIG2WORD16(tmp));
357          x++;
358          y+=C;
359       }
360       mem[c] = m;
361    }
362 }
363
364 static void mdct_shape(const CELTMode *mode, celt_norm *X, int start,
365                        int end, int N,
366                        int mdct_weight_shift, int end_band, int _C, int renorm, int M)
367 {
368    int m, i, c;
369    const int C = CHANNELS(_C);
370    for (c=0;c<C;c++)
371       for (m=start;m<end;m++)
372          for (i=m+c*N;i<(c+1)*N;i+=M)
373 #ifdef FIXED_POINT
374             X[i] = SHR16(X[i], mdct_weight_shift);
375 #else
376             X[i] = (1.f/(1<<mdct_weight_shift))*X[i];
377 #endif
378    if (renorm)
379       renormalise_bands(mode, X, end_band, C, M);
380 }
381
382 static signed char tf_select_table[4][8] = {
383       {0, -1, 0, -1,    0,-1, 0,-1},
384       {0, -1, 0, -2,    1, 0, 1 -1},
385       {0, -2, 0, -3,    2, 0, 1 -1},
386       {0, -2, 0, -3,    2, 0, 1 -1},
387 };
388
389 static int tf_analysis(celt_word16 *bandLogE, celt_word16 *oldBandE, int len, int C, int isTransient, int *tf_res, int nbCompressedBytes)
390 {
391    int i;
392    celt_word16 threshold;
393    VARDECL(celt_word16, metric);
394    celt_word32 average=0;
395    celt_word32 cost0;
396    celt_word32 cost1;
397    VARDECL(int, path0);
398    VARDECL(int, path1);
399    celt_word16 lambda;
400    int tf_select=0;
401    SAVE_STACK;
402
403    /* FIXME: Should check number of bytes *left* */
404    if (nbCompressedBytes<15*C)
405    {
406       for (i=0;i<len;i++)
407          tf_res[i] = 0;
408       return 0;
409    }
410    if (nbCompressedBytes<40)
411       lambda = QCONST16(5.f, DB_SHIFT);
412    else if (nbCompressedBytes<60)
413       lambda = QCONST16(2.f, DB_SHIFT);
414    else if (nbCompressedBytes<100)
415       lambda = QCONST16(1.f, DB_SHIFT);
416    else
417       lambda = QCONST16(.5f, DB_SHIFT);
418
419    ALLOC(metric, len, celt_word16);
420    ALLOC(path0, len, int);
421    ALLOC(path1, len, int);
422    for (i=0;i<len;i++)
423    {
424       metric[i] = SUB16(bandLogE[i], oldBandE[i]);
425       average += metric[i];
426    }
427    if (C==2)
428    {
429       average = 0;
430       for (i=0;i<len;i++)
431       {
432          metric[i] = HALF32(metric[i]) + HALF32(SUB16(bandLogE[i+len], oldBandE[i+len]));
433          average += metric[i];
434       }
435    }
436    average = DIV32(average, len);
437    /*if (!isTransient)
438       printf ("%f\n", average);*/
439    if (isTransient)
440    {
441       threshold = QCONST16(1.f,DB_SHIFT);
442       tf_select = average > QCONST16(3.f,DB_SHIFT);
443    } else {
444       threshold = QCONST16(.5f,DB_SHIFT);
445       tf_select = average > QCONST16(1.f,DB_SHIFT);
446    }
447    cost0 = 0;
448    cost1 = lambda;
449    /* Viterbi forward pass */
450    for (i=1;i<len;i++)
451    {
452       celt_word32 curr0, curr1;
453       celt_word32 from0, from1;
454
455       from0 = cost0;
456       from1 = cost1 + lambda;
457       if (from0 < from1)
458       {
459          curr0 = from0;
460          path0[i]= 0;
461       } else {
462          curr0 = from1;
463          path0[i]= 1;
464       }
465
466       from0 = cost0 + lambda;
467       from1 = cost1;
468       if (from0 < from1)
469       {
470          curr1 = from0;
471          path1[i]= 0;
472       } else {
473          curr1 = from1;
474          path1[i]= 1;
475       }
476       cost0 = curr0 + (metric[i]-threshold);
477       cost1 = curr1;
478    }
479    tf_res[len-1] = cost0 < cost1 ? 0 : 1;
480    /* Viterbi backward pass to check the decisions */
481    for (i=len-2;i>=0;i--)
482    {
483       if (tf_res[i+1] == 1)
484          tf_res[i] = path1[i+1];
485       else
486          tf_res[i] = path0[i+1];
487    }
488    RESTORE_STACK;
489    return tf_select;
490 }
491
492 static void tf_encode(int start, int end, int isTransient, int *tf_res, int nbCompressedBytes, int LM, int tf_select, ec_enc *enc)
493 {
494    int curr, i;
495    if (8*nbCompressedBytes - ec_enc_tell(enc, 0) < 100)
496    {
497       for (i=start;i<end;i++)
498          tf_res[i] = isTransient;
499    } else {
500       ec_enc_bit_prob(enc, tf_res[start], isTransient ? 16384 : 4096);
501       curr = tf_res[start];
502       for (i=start+1;i<end;i++)
503       {
504          ec_enc_bit_prob(enc, tf_res[i] ^ curr, isTransient ? 4096 : 2048);
505          curr = tf_res[i];
506       }
507    }
508    ec_enc_bits(enc, tf_select, 1);
509    for (i=start;i<end;i++)
510       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
511 }
512
513 static void tf_decode(int start, int end, int C, int isTransient, int *tf_res, int nbCompressedBytes, int LM, ec_dec *dec)
514 {
515    int i, curr, tf_select;
516    if (8*nbCompressedBytes - ec_dec_tell(dec, 0) < 100)
517    {
518       for (i=start;i<end;i++)
519          tf_res[i] = isTransient;
520    } else {
521       tf_res[start] = ec_dec_bit_prob(dec, isTransient ? 16384 : 4096);
522       curr = tf_res[start];
523       for (i=start+1;i<end;i++)
524       {
525          tf_res[i] = ec_dec_bit_prob(dec, isTransient ? 4096 : 2048) ^ curr;
526          curr = tf_res[i];
527       }
528    }
529    tf_select = ec_dec_bits(dec, 1);
530    for (i=start;i<end;i++)
531       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
532 }
533
534 #ifdef FIXED_POINT
535 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
536 {
537 #else
538 int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
539 {
540 #endif
541    int i, c, N, NN;
542    int bits;
543    int has_fold=1;
544    ec_byte_buffer buf;
545    ec_enc         _enc;
546    VARDECL(celt_sig, in);
547    VARDECL(celt_sig, freq);
548    VARDECL(celt_norm, X);
549    VARDECL(celt_ener, bandE);
550    VARDECL(celt_word16, bandLogE);
551    VARDECL(int, fine_quant);
552    VARDECL(celt_word16, error);
553    VARDECL(int, pulses);
554    VARDECL(int, offsets);
555    VARDECL(int, fine_priority);
556    VARDECL(int, tf_res);
557    celt_sig *_overlap_mem;
558    celt_word16 *oldBandE;
559    int shortBlocks=0;
560    int isTransient=0;
561    int transient_time, transient_time_quant;
562    int transient_shift;
563    int resynth;
564    const int C = CHANNELS(st->channels);
565    int mdct_weight_shift = 0;
566    int mdct_weight_pos=0;
567    int LM, M;
568    int tf_select;
569    int nbFilledBytes, nbAvailableBytes;
570    int effEnd;
571    SAVE_STACK;
572
573    if (nbCompressedBytes<0 || pcm==NULL)
574      return CELT_BAD_ARG;
575
576    for (LM=0;LM<4;LM++)
577       if (st->mode->shortMdctSize<<LM==frame_size)
578          break;
579    if (LM>=MAX_CONFIG_SIZES)
580       return CELT_BAD_ARG;
581    M=1<<LM;
582
583    _overlap_mem = st->in_mem+C*(st->overlap);
584    oldBandE = (celt_word16*)(st->in_mem+2*C*(st->overlap));
585
586    if (enc==NULL)
587    {
588       ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
589       ec_enc_init(&_enc,&buf);
590       enc = &_enc;
591       nbFilledBytes=0;
592    } else {
593       nbFilledBytes=(ec_enc_tell(enc, 0)+4)>>3;
594    }
595    nbAvailableBytes = nbCompressedBytes - nbFilledBytes;
596
597    effEnd = st->end;
598    if (effEnd > st->mode->effEBands)
599       effEnd = st->mode->effEBands;
600
601    N = M*st->mode->shortMdctSize;
602    ALLOC(in, C*(N+st->overlap), celt_sig);
603
604    CELT_COPY(in, st->in_mem, C*st->overlap);
605    for (c=0;c<C;c++)
606    {
607       const celt_word16 * restrict pcmp = pcm+c;
608       celt_sig * restrict inp = in+C*st->overlap+c;
609       for (i=0;i<N;i++)
610       {
611          /* Apply pre-emphasis */
612          celt_sig tmp = MULT16_16(st->mode->preemph[2], SCALEIN(*pcmp));
613          *inp = tmp + st->preemph_memE[c];
614          st->preemph_memE[c] = MULT16_32_Q15(st->mode->preemph[1], *inp)
615                              - MULT16_32_Q15(st->mode->preemph[0], tmp);
616          inp += C;
617          pcmp += C;
618       }
619    }
620    CELT_COPY(st->in_mem, in+C*N, C*st->overlap);
621
622    /* Transient handling */
623    transient_time = -1;
624    transient_time_quant = -1;
625    transient_shift = 0;
626    isTransient = 0;
627
628    resynth = optional_resynthesis!=NULL;
629
630    if (M > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift, &st->frame_max, st->overlap))
631    {
632 #ifndef FIXED_POINT
633       float gain_1;
634 #endif
635       /* Apply the inverse shaping window */
636       if (transient_shift)
637       {
638          transient_time_quant = transient_time*(celt_int32)8000/st->mode->Fs;
639          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
640 #ifdef FIXED_POINT
641          for (c=0;c<C;c++)
642             for (i=0;i<16;i++)
643                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
644          for (c=0;c<C;c++)
645             for (i=transient_time;i<N+st->overlap;i++)
646                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
647 #else
648          for (c=0;c<C;c++)
649             for (i=0;i<16;i++)
650                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
651          gain_1 = 1.f/(1<<transient_shift);
652          for (c=0;c<C;c++)
653             for (i=transient_time;i<N+st->overlap;i++)
654                in[C*i+c] *= gain_1;
655 #endif
656       }
657       isTransient = 1;
658       has_fold = 1;
659    }
660
661    if (isTransient)
662       shortBlocks = M;
663    else
664       shortBlocks = 0;
665
666    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
667    ALLOC(bandE,st->mode->nbEBands*C, celt_ener);
668    ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16);
669    /* Compute MDCTs */
670    compute_mdcts(st->mode, shortBlocks, in, freq, C, LM);
671
672    ALLOC(X, C*N, celt_norm);         /**< Interleaved normalised MDCTs */
673
674    compute_band_energies(st->mode, freq, bandE, effEnd, C, M);
675
676    amp2Log2(st->mode, effEnd, st->end, bandE, bandLogE, C);
677
678    /* Band normalisation */
679    normalise_bands(st->mode, freq, X, bandE, effEnd, C, M);
680
681    NN = M*st->mode->eBands[effEnd];
682    if (shortBlocks && !transient_shift)
683    {
684       celt_word32 sum[8]={1,1,1,1,1,1,1,1};
685       int m;
686       for (c=0;c<C;c++)
687       {
688          m=0;
689          do {
690             celt_word32 tmp=0;
691             for (i=m+c*N;i<c*N+NN;i+=M)
692                tmp += ABS32(X[i]);
693             sum[m++] += tmp;
694          } while (m<M);
695       }
696       m=0;
697 #ifdef FIXED_POINT
698       do {
699          if (SHR32(sum[m+1],3) > sum[m])
700          {
701             mdct_weight_shift=2;
702             mdct_weight_pos = m;
703          } else if (SHR32(sum[m+1],1) > sum[m] && mdct_weight_shift < 2)
704          {
705             mdct_weight_shift=1;
706             mdct_weight_pos = m;
707          }
708          m++;
709       } while (m<M-1);
710 #else
711       do {
712          if (sum[m+1] > 8*sum[m])
713          {
714             mdct_weight_shift=2;
715             mdct_weight_pos = m;
716          } else if (sum[m+1] > 2*sum[m] && mdct_weight_shift < 2)
717          {
718             mdct_weight_shift=1;
719             mdct_weight_pos = m;
720          }
721          m++;
722       } while (m<M-1);
723 #endif
724       if (mdct_weight_shift)
725          mdct_shape(st->mode, X, mdct_weight_pos+1, M, N, mdct_weight_shift, effEnd, C, 0, M);
726    }
727
728    ALLOC(tf_res, st->mode->nbEBands, int);
729    /* Needs to be before coarse energy quantization because otherwise the energy gets modified */
730    tf_select = tf_analysis(bandLogE, oldBandE, effEnd, C, isTransient, tf_res, nbAvailableBytes);
731    for (i=effEnd;i<st->end;i++)
732       tf_res[i] = tf_res[effEnd-1];
733
734    ALLOC(error, C*st->mode->nbEBands, celt_word16);
735    quant_coarse_energy(st->mode, st->start, st->end, effEnd, bandLogE,
736          oldBandE, nbCompressedBytes*8, st->mode->prob,
737          error, enc, C, LM, nbAvailableBytes, st->force_intra, &st->delayedIntra);
738
739    ec_enc_bit_prob(enc, shortBlocks!=0, 8192);
740
741    if (shortBlocks)
742    {
743       if (transient_shift)
744       {
745          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
746          ec_enc_uint(enc, transient_shift, 4);
747          ec_enc_uint(enc, transient_time_quant, max_time);
748       } else {
749          ec_enc_uint(enc, mdct_weight_shift, 4);
750          if (mdct_weight_shift && M!=2)
751             ec_enc_uint(enc, mdct_weight_pos, M-1);
752       }
753    }
754
755    tf_encode(st->start, st->end, isTransient, tf_res, nbAvailableBytes, LM, tf_select, enc);
756
757    if (!shortBlocks && !folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision, effEnd, C, M))
758       has_fold = 0;
759    ec_enc_bit_prob(enc, has_fold>>1, 8192);
760    ec_enc_bit_prob(enc, has_fold&1, (has_fold>>1) ? 32768 : 49152);
761
762    /* Variable bitrate */
763    if (st->vbr_rate_norm>0)
764    {
765      celt_word16 alpha;
766      celt_int32 delta;
767      /* The target rate in 16th bits per frame */
768      celt_int32 vbr_rate;
769      celt_int32 target;
770      celt_int32 vbr_bound, max_allowed;
771
772      vbr_rate = M*st->vbr_rate_norm;
773
774      /* Computes the max bit-rate allowed in VBR more to avoid busting the budget */
775      vbr_bound = vbr_rate;
776      max_allowed = (vbr_rate + vbr_bound - st->vbr_reservoir)>>(BITRES+3);
777      if (max_allowed < 4)
778         max_allowed = 4;
779      if (max_allowed < nbAvailableBytes)
780         nbAvailableBytes = max_allowed;
781      target=vbr_rate;
782
783      /* Shortblocks get a large boost in bitrate, but since they 
784         are uncommon long blocks are not greatly effected */
785      if (shortBlocks)
786        target*=2;
787      else if (M > 1)
788        target-=(target+14)/28;
789
790      /* The average energy is removed from the target and the actual 
791         energy added*/
792      target=target+st->vbr_offset-588+ec_enc_tell(enc, BITRES);
793
794      /* In VBR mode the frame size must not be reduced so much that it would result in the coarse energy busting its budget */
795      target=IMIN(nbAvailableBytes,target);
796      /* Make the adaptation coef (alpha) higher at the beginning */
797      if (st->vbr_count < 990)
798      {
799         st->vbr_count++;
800         alpha = celt_rcp(SHL32(EXTEND32(st->vbr_count+10),16));
801         /*printf ("%d %d\n", st->vbr_count+10, alpha);*/
802      } else
803         alpha = QCONST16(.001f,15);
804
805      /* By how much did we "miss" the target on that frame */
806      delta = (8<<BITRES)*(celt_int32)target - vbr_rate;
807      /* How many bits have we used in excess of what we're allowed */
808      st->vbr_reservoir += delta;
809      /*printf ("%d\n", st->vbr_reservoir);*/
810
811      /* Compute the offset we need to apply in order to reach the target */
812      st->vbr_drift += MULT16_32_Q15(alpha,delta-st->vbr_offset-st->vbr_drift);
813      st->vbr_offset = -st->vbr_drift;
814      /*printf ("%d\n", st->vbr_drift);*/
815
816      /* We could use any multiple of vbr_rate as bound (depending on the delay) */
817      if (st->vbr_reservoir < 0)
818      {
819         /* We're under the min value -- increase rate */
820         int adjust = 1-(st->vbr_reservoir-1)/(8<<BITRES);
821         st->vbr_reservoir += adjust*(8<<BITRES);
822         target += adjust;
823         /*printf ("+%d\n", adjust);*/
824      }
825      if (target < nbAvailableBytes)
826         nbAvailableBytes = target;
827      nbCompressedBytes = nbAvailableBytes + nbFilledBytes;
828
829      /* This moves the raw bits to take into account the new compressed size */
830      ec_byte_shrink(&buf, nbCompressedBytes);
831    }
832
833    /* Bit allocation */
834    ALLOC(fine_quant, st->mode->nbEBands, int);
835    ALLOC(pulses, st->mode->nbEBands, int);
836    ALLOC(offsets, st->mode->nbEBands, int);
837    ALLOC(fine_priority, st->mode->nbEBands, int);
838
839    for (i=0;i<st->mode->nbEBands;i++)
840       offsets[i] = 0;
841    bits = nbCompressedBytes*8 - ec_enc_tell(enc, 0) - 1;
842    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
843
844    quant_fine_energy(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, enc, C);
845
846 #ifdef MEASURE_NORM_MSE
847    float X0[3000];
848    float bandE0[60];
849    for (c=0;c<C;c++)
850       for (i=0;i<N;i++)
851          X0[i+c*N] = X[i+c*N];
852    for (i=0;i<C*st->mode->nbEBands;i++)
853       bandE0[i] = bandE[i];
854 #endif
855
856    /* Residual quantisation */
857    quant_all_bands(1, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, bandE, pulses, shortBlocks, has_fold, tf_res, resynth, nbCompressedBytes*8, enc, LM);
858
859    quant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
860
861    /* Re-synthesis of the coded audio if required */
862    if (resynth)
863    {
864       VARDECL(celt_sig, _out_mem);
865       celt_sig *out_mem[2];
866       celt_sig *overlap_mem[2];
867
868       log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
869
870 #ifdef MEASURE_NORM_MSE
871       measure_norm_mse(st->mode, X, X0, bandE, bandE0, M, N, C);
872 #endif
873
874       if (mdct_weight_shift)
875       {
876          mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
877       }
878
879       /* Synthesis */
880       denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
881
882       for (c=0;c<C;c++)
883          for (i=0;i<M*st->mode->eBands[st->start];i++)
884             freq[c*N+i] = 0;
885       for (c=0;c<C;c++)
886          for (i=M*st->mode->eBands[st->end];i<N;i++)
887             freq[c*N+i] = 0;
888
889       ALLOC(_out_mem, C*N, celt_sig);
890
891       for (c=0;c<C;c++)
892       {
893          overlap_mem[c] = _overlap_mem + c*st->overlap;
894          out_mem[c] = _out_mem+c*N;
895       }
896
897       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
898             transient_shift, out_mem, overlap_mem, C, LM);
899
900       /* De-emphasis and put everything back at the right place 
901          in the synthesis history */
902       if (optional_resynthesis != NULL) {
903          deemphasis(out_mem, optional_resynthesis, N, C, st->mode->preemph, st->preemph_memD);
904
905       }
906    }
907
908    /* If there's any room left (can only happen for very high rates),
909       fill it with zeros */
910    while (ec_enc_tell(enc,0) + 8 <= nbCompressedBytes*8)
911       ec_enc_bits(enc, 0, 8);
912    ec_enc_done(enc);
913    
914    RESTORE_STACK;
915    if (ec_enc_get_error(enc))
916       return CELT_CORRUPTED_DATA;
917    else
918       return nbCompressedBytes;
919 }
920
921 #ifdef FIXED_POINT
922 #ifndef DISABLE_FLOAT_API
923 int celt_encode_with_ec_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
924 {
925    int j, ret, C, N, LM, M;
926    VARDECL(celt_int16, in);
927    SAVE_STACK;
928
929    if (pcm==NULL)
930       return CELT_BAD_ARG;
931
932    for (LM=0;LM<4;LM++)
933       if (st->mode->shortMdctSize<<LM==frame_size)
934          break;
935    if (LM>=MAX_CONFIG_SIZES)
936       return CELT_BAD_ARG;
937    M=1<<LM;
938
939    C = CHANNELS(st->channels);
940    N = M*st->mode->shortMdctSize;
941    ALLOC(in, C*N, celt_int16);
942
943    for (j=0;j<C*N;j++)
944      in[j] = FLOAT2INT16(pcm[j]);
945
946    if (optional_resynthesis != NULL) {
947      ret=celt_encode_with_ec(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
948       for (j=0;j<C*N;j++)
949          optional_resynthesis[j]=in[j]*(1.f/32768.f);
950    } else {
951      ret=celt_encode_with_ec(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
952    }
953    RESTORE_STACK;
954    return ret;
955
956 }
957 #endif /*DISABLE_FLOAT_API*/
958 #else
959 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
960 {
961    int j, ret, C, N, LM, M;
962    VARDECL(celt_sig, in);
963    SAVE_STACK;
964
965    if (pcm==NULL)
966       return CELT_BAD_ARG;
967
968    for (LM=0;LM<4;LM++)
969       if (st->mode->shortMdctSize<<LM==frame_size)
970          break;
971    if (LM>=MAX_CONFIG_SIZES)
972       return CELT_BAD_ARG;
973    M=1<<LM;
974
975    C=CHANNELS(st->channels);
976    N=M*st->mode->shortMdctSize;
977    ALLOC(in, C*N, celt_sig);
978    for (j=0;j<C*N;j++) {
979      in[j] = SCALEOUT(pcm[j]);
980    }
981
982    if (optional_resynthesis != NULL) {
983       ret = celt_encode_with_ec_float(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
984       for (j=0;j<C*N;j++)
985          optional_resynthesis[j] = FLOAT2INT16(in[j]);
986    } else {
987       ret = celt_encode_with_ec_float(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
988    }
989    RESTORE_STACK;
990    return ret;
991 }
992 #endif
993
994 int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
995 {
996    return celt_encode_with_ec(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
997 }
998
999 #ifndef DISABLE_FLOAT_API
1000 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1001 {
1002    return celt_encode_with_ec_float(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1003 }
1004 #endif /* DISABLE_FLOAT_API */
1005
1006 int celt_encode_resynthesis(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1007 {
1008    return celt_encode_with_ec(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1009 }
1010
1011 #ifndef DISABLE_FLOAT_API
1012 int celt_encode_resynthesis_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1013 {
1014    return celt_encode_with_ec_float(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1015 }
1016 #endif /* DISABLE_FLOAT_API */
1017
1018
1019 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
1020 {
1021    va_list ap;
1022    
1023    va_start(ap, request);
1024    switch (request)
1025    {
1026       case CELT_GET_MODE_REQUEST:
1027       {
1028          const CELTMode ** value = va_arg(ap, const CELTMode**);
1029          if (value==0)
1030             goto bad_arg;
1031          *value=st->mode;
1032       }
1033       break;
1034       case CELT_SET_COMPLEXITY_REQUEST:
1035       {
1036          int value = va_arg(ap, celt_int32);
1037          if (value<0 || value>10)
1038             goto bad_arg;
1039       }
1040       break;
1041       case CELT_SET_START_BAND_REQUEST:
1042       {
1043          celt_int32 value = va_arg(ap, celt_int32);
1044          if (value<0 || value>=st->mode->nbEBands)
1045             goto bad_arg;
1046          st->start = value;
1047       }
1048       break;
1049       case CELT_SET_END_BAND_REQUEST:
1050       {
1051          celt_int32 value = va_arg(ap, celt_int32);
1052          if (value<0 || value>=st->mode->nbEBands)
1053             goto bad_arg;
1054          st->end = value;
1055       }
1056       break;
1057       case CELT_SET_PREDICTION_REQUEST:
1058       {
1059          int value = va_arg(ap, celt_int32);
1060          if (value<0 || value>2)
1061             goto bad_arg;
1062          if (value==0)
1063          {
1064             st->force_intra   = 1;
1065          } else if (value==1) {
1066             st->force_intra   = 0;
1067          } else {
1068             st->force_intra   = 0;
1069          }   
1070       }
1071       break;
1072       case CELT_SET_VBR_RATE_REQUEST:
1073       {
1074          celt_int32 value = va_arg(ap, celt_int32);
1075          int frame_rate;
1076          int N = st->mode->shortMdctSize;
1077          if (value<0)
1078             goto bad_arg;
1079          if (value>3072000)
1080             value = 3072000;
1081          frame_rate = ((st->mode->Fs<<3)+(N>>1))/N;
1082          st->vbr_rate_norm = ((value<<(BITRES+3))+(frame_rate>>1))/frame_rate;
1083       }
1084       break;
1085       case CELT_RESET_STATE:
1086       {
1087          CELT_MEMSET((char*)&st->ENCODER_RESET_START, 0,
1088                celt_encoder_get_size(st->mode, st->channels)-
1089                ((char*)&st->ENCODER_RESET_START - (char*)st));
1090          st->delayedIntra = 1;
1091          st->fold_decision = 1;
1092          st->tonal_average = QCONST16(1.f,8);
1093       }
1094       break;
1095       default:
1096          goto bad_request;
1097    }
1098    va_end(ap);
1099    return CELT_OK;
1100 bad_arg:
1101    va_end(ap);
1102    return CELT_BAD_ARG;
1103 bad_request:
1104    va_end(ap);
1105    return CELT_UNIMPLEMENTED;
1106 }
1107
1108 /**********************************************************************/
1109 /*                                                                    */
1110 /*                             DECODER                                */
1111 /*                                                                    */
1112 /**********************************************************************/
1113 #define DECODE_BUFFER_SIZE 2048
1114
1115 /** Decoder state 
1116  @brief Decoder state
1117  */
1118 struct CELTDecoder {
1119    const CELTMode *mode;
1120    int overlap;
1121    int channels;
1122
1123    int start, end;
1124
1125    /* Everything beyond this point gets cleared on a reset */
1126 #define DECODER_RESET_START last_pitch_index
1127
1128    int last_pitch_index;
1129    int loss_count;
1130
1131    celt_sig preemph_memD[2];
1132    
1133    celt_sig _decode_mem[1]; /* Size = channels*(DECODE_BUFFER_SIZE+mode->overlap) */
1134    /* celt_word16 lpc[],  Size = channels*LPC_ORDER */
1135    /* celt_word16 oldEBands[], Size = channels*mode->nbEBands */
1136 };
1137
1138 int celt_decoder_get_size(const CELTMode *mode, int channels)
1139 {
1140    int size = sizeof(struct CELTDecoder)
1141             + (channels*(DECODE_BUFFER_SIZE+mode->overlap)-1)*sizeof(celt_sig)
1142             + channels*LPC_ORDER*sizeof(celt_word16)
1143             + channels*mode->nbEBands*sizeof(celt_word16);
1144    return size;
1145 }
1146
1147 CELTDecoder *celt_decoder_create(const CELTMode *mode, int channels, int *error)
1148 {
1149    return celt_decoder_init(
1150          (CELTDecoder *)celt_alloc(celt_decoder_get_size(mode, channels)),
1151          mode, channels, error);
1152 }
1153
1154 CELTDecoder *celt_decoder_init(CELTDecoder *st, const CELTMode *mode, int channels, int *error)
1155 {
1156    if (channels < 0 || channels > 2)
1157    {
1158       celt_warning("Only mono and stereo supported");
1159       if (error)
1160          *error = CELT_BAD_ARG;
1161       return NULL;
1162    }
1163
1164    CELT_MEMSET((char*)st, 0, celt_decoder_get_size(mode, channels));
1165
1166    if (st==NULL)
1167    {
1168       if (error)
1169          *error = CELT_ALLOC_FAIL;
1170       return NULL;
1171    }
1172
1173    st->mode = mode;
1174    st->overlap = mode->overlap;
1175    st->channels = channels;
1176
1177    st->start = 0;
1178    st->end = st->mode->effEBands;
1179
1180    st->loss_count = 0;
1181
1182    if (error)
1183       *error = CELT_OK;
1184    return st;
1185 }
1186
1187 void celt_decoder_destroy(CELTDecoder *st)
1188 {
1189    celt_free(st);
1190 }
1191
1192 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16 * restrict pcm, int N, int LM)
1193 {
1194    int c;
1195    int pitch_index;
1196    int overlap = st->mode->overlap;
1197    celt_word16 fade = Q15ONE;
1198    int i, len;
1199    const int C = CHANNELS(st->channels);
1200    int offset;
1201    celt_sig *out_mem[2];
1202    celt_sig *decode_mem[2];
1203    celt_sig *overlap_mem[2];
1204    celt_word16 *lpc;
1205    celt_word16 *oldBandE;
1206    SAVE_STACK;
1207    
1208    for (c=0;c<C;c++)
1209    {
1210       decode_mem[c] = st->_decode_mem + c*(DECODE_BUFFER_SIZE+st->overlap);
1211       out_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE-MAX_PERIOD;
1212       overlap_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE;
1213    }
1214    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1215    oldBandE = lpc+C*LPC_ORDER;
1216
1217    len = N+st->mode->overlap;
1218    
1219    if (st->loss_count == 0)
1220    {
1221       celt_word16 pitch_buf[MAX_PERIOD>>1];
1222       celt_word32 tmp=0;
1223       celt_word32 mem0[2]={0,0};
1224       celt_word16 mem1[2]={0,0};
1225       int len2 = len;
1226       /* FIXME: This is a kludge */
1227       if (len2>MAX_PERIOD>>1)
1228          len2 = MAX_PERIOD>>1;
1229       pitch_downsample(out_mem, pitch_buf, MAX_PERIOD, MAX_PERIOD,
1230                        C, mem0, mem1);
1231       pitch_search(st->mode, pitch_buf+((MAX_PERIOD-len2)>>1), pitch_buf, len2,
1232                    MAX_PERIOD-len2-100, &pitch_index, &tmp, 1<<LM);
1233       pitch_index = MAX_PERIOD-len2-pitch_index;
1234       st->last_pitch_index = pitch_index;
1235    } else {
1236       pitch_index = st->last_pitch_index;
1237       if (st->loss_count < 5)
1238          fade = QCONST16(.8f,15);
1239       else
1240          fade = 0;
1241    }
1242
1243    for (c=0;c<C;c++)
1244    {
1245       /* FIXME: This is more memory than necessary */
1246       celt_word32 e[2*MAX_PERIOD];
1247       celt_word16 exc[2*MAX_PERIOD];
1248       celt_word32 ac[LPC_ORDER+1];
1249       celt_word16 decay = 1;
1250       celt_word32 S1=0;
1251       celt_word16 mem[LPC_ORDER]={0};
1252
1253       offset = MAX_PERIOD-pitch_index;
1254       for (i=0;i<MAX_PERIOD;i++)
1255          exc[i] = ROUND16(out_mem[c][i], SIG_SHIFT);
1256
1257       if (st->loss_count == 0)
1258       {
1259          _celt_autocorr(exc, ac, st->mode->window, st->mode->overlap,
1260                         LPC_ORDER, MAX_PERIOD);
1261
1262          /* Noise floor -40 dB */
1263 #ifdef FIXED_POINT
1264          ac[0] += SHR32(ac[0],13);
1265 #else
1266          ac[0] *= 1.0001f;
1267 #endif
1268          /* Lag windowing */
1269          for (i=1;i<=LPC_ORDER;i++)
1270          {
1271             /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
1272 #ifdef FIXED_POINT
1273             ac[i] -= MULT16_32_Q15(2*i*i, ac[i]);
1274 #else
1275             ac[i] -= ac[i]*(.008f*i)*(.008f*i);
1276 #endif
1277          }
1278
1279          _celt_lpc(lpc+c*LPC_ORDER, ac, LPC_ORDER);
1280       }
1281       fir(exc, lpc+c*LPC_ORDER, exc, MAX_PERIOD, LPC_ORDER, mem);
1282       /*for (i=0;i<MAX_PERIOD;i++)printf("%d ", exc[i]); printf("\n");*/
1283       /* Check if the waveform is decaying (and if so how fast) */
1284       {
1285          celt_word32 E1=1, E2=1;
1286          int period;
1287          if (pitch_index <= MAX_PERIOD/2)
1288             period = pitch_index;
1289          else
1290             period = MAX_PERIOD/2;
1291          for (i=0;i<period;i++)
1292          {
1293             E1 += SHR32(MULT16_16(exc[MAX_PERIOD-period+i],exc[MAX_PERIOD-period+i]),8);
1294             E2 += SHR32(MULT16_16(exc[MAX_PERIOD-2*period+i],exc[MAX_PERIOD-2*period+i]),8);
1295          }
1296          if (E1 > E2)
1297             E1 = E2;
1298          decay = celt_sqrt(frac_div32(SHR(E1,1),E2));
1299       }
1300
1301       /* Copy excitation, taking decay into account */
1302       for (i=0;i<len+st->mode->overlap;i++)
1303       {
1304          if (offset+i >= MAX_PERIOD)
1305          {
1306             offset -= pitch_index;
1307             decay = MULT16_16_Q15(decay, decay);
1308          }
1309          e[i] = SHL32(EXTEND32(MULT16_16_Q15(decay, exc[offset+i])), SIG_SHIFT);
1310          S1 += SHR32(MULT16_16(out_mem[c][offset+i],out_mem[c][offset+i]),8);
1311       }
1312
1313       iir(e, lpc+c*LPC_ORDER, e, len+st->mode->overlap, LPC_ORDER, mem);
1314
1315       {
1316          celt_word32 S2=0;
1317          for (i=0;i<len+overlap;i++)
1318             S2 += SHR32(MULT16_16(e[i],e[i]),8);
1319          /* This checks for an "explosion" in the synthesis */
1320 #ifdef FIXED_POINT
1321          if (!(S1 > SHR32(S2,2)))
1322 #else
1323          /* Float test is written this way to catch NaNs at the same time */
1324          if (!(S1 > 0.2f*S2))
1325 #endif
1326          {
1327             for (i=0;i<len+overlap;i++)
1328                e[i] = 0;
1329          } else if (S1 < S2)
1330          {
1331             celt_word16 ratio = celt_sqrt(frac_div32(SHR32(S1,1)+1,S2+1));
1332             for (i=0;i<len+overlap;i++)
1333                e[i] = MULT16_16_Q15(ratio, e[i]);
1334          }
1335       }
1336
1337       for (i=0;i<MAX_PERIOD+st->mode->overlap-N;i++)
1338          out_mem[c][i] = out_mem[c][N+i];
1339
1340       /* Apply TDAC to the concealed audio so that it blends with the
1341          previous and next frames */
1342       for (i=0;i<overlap/2;i++)
1343       {
1344          celt_word32 tmp1, tmp2;
1345          tmp1 = MULT16_32_Q15(st->mode->window[i          ], e[i          ]) -
1346                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[overlap-i-1]);
1347          tmp2 = MULT16_32_Q15(st->mode->window[i],           e[N+overlap-1-i]) +
1348                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[N+i          ]);
1349          tmp1 = MULT16_32_Q15(fade, tmp1);
1350          tmp2 = MULT16_32_Q15(fade, tmp2);
1351          out_mem[c][MAX_PERIOD+i] = MULT16_32_Q15(st->mode->window[overlap-i-1], tmp2);
1352          out_mem[c][MAX_PERIOD+overlap-i-1] = MULT16_32_Q15(st->mode->window[i], tmp2);
1353          out_mem[c][MAX_PERIOD-N+i] += MULT16_32_Q15(st->mode->window[i], tmp1);
1354          out_mem[c][MAX_PERIOD-N+overlap-i-1] -= MULT16_32_Q15(st->mode->window[overlap-i-1], tmp1);
1355       }
1356       for (i=0;i<N-overlap;i++)
1357          out_mem[c][MAX_PERIOD-N+overlap+i] = MULT16_32_Q15(fade, e[overlap+i]);
1358    }
1359
1360    deemphasis(out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1361    
1362    st->loss_count++;
1363
1364    RESTORE_STACK;
1365 }
1366
1367 #ifdef FIXED_POINT
1368 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1369 {
1370 #else
1371 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig * restrict pcm, int frame_size, ec_dec *dec)
1372 {
1373 #endif
1374    int c, i, N;
1375    int has_fold;
1376    int bits;
1377    ec_dec _dec;
1378    ec_byte_buffer buf;
1379    VARDECL(celt_sig, freq);
1380    VARDECL(celt_norm, X);
1381    VARDECL(celt_ener, bandE);
1382    VARDECL(int, fine_quant);
1383    VARDECL(int, pulses);
1384    VARDECL(int, offsets);
1385    VARDECL(int, fine_priority);
1386    VARDECL(int, tf_res);
1387    celt_sig *out_mem[2];
1388    celt_sig *decode_mem[2];
1389    celt_sig *overlap_mem[2];
1390    celt_sig *out_syn[2];
1391    celt_word16 *lpc;
1392    celt_word16 *oldBandE;
1393
1394    int shortBlocks;
1395    int isTransient;
1396    int intra_ener;
1397    int transient_time;
1398    int transient_shift;
1399    int mdct_weight_shift=0;
1400    const int C = CHANNELS(st->channels);
1401    int mdct_weight_pos=0;
1402    int LM, M;
1403    int nbFilledBytes, nbAvailableBytes;
1404    int effEnd;
1405    SAVE_STACK;
1406
1407    if (pcm==NULL)
1408       return CELT_BAD_ARG;
1409
1410    for (LM=0;LM<4;LM++)
1411       if (st->mode->shortMdctSize<<LM==frame_size)
1412          break;
1413    if (LM>=MAX_CONFIG_SIZES)
1414       return CELT_BAD_ARG;
1415    M=1<<LM;
1416
1417    for (c=0;c<C;c++)
1418    {
1419       decode_mem[c] = st->_decode_mem + c*(DECODE_BUFFER_SIZE+st->overlap);
1420       out_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE-MAX_PERIOD;
1421       overlap_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE;
1422    }
1423    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1424    oldBandE = lpc+C*LPC_ORDER;
1425
1426    N = M*st->mode->shortMdctSize;
1427
1428    effEnd = st->end;
1429    if (effEnd > st->mode->effEBands)
1430       effEnd = st->mode->effEBands;
1431
1432    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
1433    ALLOC(X, C*N, celt_norm);   /**< Interleaved normalised MDCTs */
1434    ALLOC(bandE, st->mode->nbEBands*C, celt_ener);
1435    for (c=0;c<C;c++)
1436       for (i=0;i<M*st->mode->eBands[st->start];i++)
1437          X[c*N+i] = 0;
1438    for (c=0;c<C;c++)
1439       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1440          X[c*N+i] = 0;
1441
1442    if (data == NULL)
1443    {
1444       celt_decode_lost(st, pcm, N, LM);
1445       RESTORE_STACK;
1446       return CELT_OK;
1447    }
1448    if (len<0) {
1449      RESTORE_STACK;
1450      return CELT_BAD_ARG;
1451    }
1452    
1453    if (dec == NULL)
1454    {
1455       ec_byte_readinit(&buf,(unsigned char*)data,len);
1456       ec_dec_init(&_dec,&buf);
1457       dec = &_dec;
1458       nbFilledBytes = 0;
1459    } else {
1460       nbFilledBytes = (ec_dec_tell(dec, 0)+4)>>3;
1461    }
1462    nbAvailableBytes = len-nbFilledBytes;
1463
1464    /* Decode the global flags (first symbols in the stream) */
1465    intra_ener = ec_dec_bit_prob(dec, 8192);
1466    /* Get band energies */
1467    unquant_coarse_energy(st->mode, st->start, st->end, bandE, oldBandE,
1468          intra_ener, st->mode->prob, dec, C, LM);
1469
1470    isTransient = ec_dec_bit_prob(dec, 8192);
1471
1472    if (isTransient)
1473       shortBlocks = M;
1474    else
1475       shortBlocks = 0;
1476
1477    if (isTransient)
1478    {
1479       transient_shift = ec_dec_uint(dec, 4);
1480       if (transient_shift == 3)
1481       {
1482          int transient_time_quant;
1483          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
1484          transient_time_quant = ec_dec_uint(dec, max_time);
1485          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
1486       } else {
1487          mdct_weight_shift = transient_shift;
1488          if (mdct_weight_shift && M>2)
1489             mdct_weight_pos = ec_dec_uint(dec, M-1);
1490          transient_shift = 0;
1491          transient_time = 0;
1492       }
1493    } else {
1494       transient_time = -1;
1495       transient_shift = 0;
1496    }
1497
1498    ALLOC(tf_res, st->mode->nbEBands, int);
1499    tf_decode(st->start, st->end, C, isTransient, tf_res, nbAvailableBytes, LM, dec);
1500
1501    has_fold = ec_dec_bit_prob(dec, 8192)<<1;
1502    has_fold |= ec_dec_bit_prob(dec, (has_fold>>1) ? 32768 : 49152);
1503
1504    ALLOC(pulses, st->mode->nbEBands, int);
1505    ALLOC(offsets, st->mode->nbEBands, int);
1506    ALLOC(fine_priority, st->mode->nbEBands, int);
1507
1508    for (i=0;i<st->mode->nbEBands;i++)
1509       offsets[i] = 0;
1510
1511    bits = len*8 - ec_dec_tell(dec, 0) - 1;
1512    ALLOC(fine_quant, st->mode->nbEBands, int);
1513    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
1514    /*bits = ec_dec_tell(dec, 0);
1515    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(dec, 0)-bits))/C);*/
1516    
1517    unquant_fine_energy(st->mode, st->start, st->end, bandE, oldBandE, fine_quant, dec, C);
1518
1519    /* Decode fixed codebook */
1520    quant_all_bands(0, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, NULL, pulses, shortBlocks, has_fold, tf_res, 1, len*8, dec, LM);
1521
1522    unquant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE,
1523          fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
1524
1525    log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
1526
1527    if (mdct_weight_shift)
1528    {
1529       mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
1530    }
1531
1532    /* Synthesis */
1533    denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
1534
1535    CELT_MOVE(decode_mem[0], decode_mem[0]+N, DECODE_BUFFER_SIZE-N);
1536    if (C==2)
1537       CELT_MOVE(decode_mem[1], decode_mem[1]+N, DECODE_BUFFER_SIZE-N);
1538
1539    for (c=0;c<C;c++)
1540       for (i=0;i<M*st->mode->eBands[st->start];i++)
1541          freq[c*N+i] = 0;
1542    for (c=0;c<C;c++)
1543       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1544          freq[c*N+i] = 0;
1545
1546    out_syn[0] = out_mem[0]+MAX_PERIOD-N;
1547    if (C==2)
1548       out_syn[1] = out_mem[1]+MAX_PERIOD-N;
1549
1550    /* Compute inverse MDCTs */
1551    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
1552          transient_shift, out_syn, overlap_mem, C, LM);
1553
1554    deemphasis(out_syn, pcm, N, C, st->mode->preemph, st->preemph_memD);
1555    st->loss_count = 0;
1556    RESTORE_STACK;
1557    if (ec_dec_get_error(dec))
1558       return CELT_CORRUPTED_DATA;
1559    else
1560       return CELT_OK;
1561 }
1562
1563 #ifdef FIXED_POINT
1564 #ifndef DISABLE_FLOAT_API
1565 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size, ec_dec *dec)
1566 {
1567    int j, ret, C, N, LM, M;
1568    VARDECL(celt_int16, out);
1569    SAVE_STACK;
1570
1571    if (pcm==NULL)
1572       return CELT_BAD_ARG;
1573
1574    for (LM=0;LM<4;LM++)
1575       if (st->mode->shortMdctSize<<LM==frame_size)
1576          break;
1577    if (LM>=MAX_CONFIG_SIZES)
1578       return CELT_BAD_ARG;
1579    M=1<<LM;
1580
1581    C = CHANNELS(st->channels);
1582    N = M*st->mode->shortMdctSize;
1583    
1584    ALLOC(out, C*N, celt_int16);
1585    ret=celt_decode_with_ec(st, data, len, out, frame_size, dec);
1586    if (ret==0)
1587       for (j=0;j<C*N;j++)
1588          pcm[j]=out[j]*(1.f/32768.f);
1589      
1590    RESTORE_STACK;
1591    return ret;
1592 }
1593 #endif /*DISABLE_FLOAT_API*/
1594 #else
1595 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1596 {
1597    int j, ret, C, N, LM, M;
1598    VARDECL(celt_sig, out);
1599    SAVE_STACK;
1600
1601    if (pcm==NULL)
1602       return CELT_BAD_ARG;
1603
1604    for (LM=0;LM<4;LM++)
1605       if (st->mode->shortMdctSize<<LM==frame_size)
1606          break;
1607    if (LM>=MAX_CONFIG_SIZES)
1608       return CELT_BAD_ARG;
1609    M=1<<LM;
1610
1611    C = CHANNELS(st->channels);
1612    N = M*st->mode->shortMdctSize;
1613    ALLOC(out, C*N, celt_sig);
1614
1615    ret=celt_decode_with_ec_float(st, data, len, out, frame_size, dec);
1616
1617    if (ret==0)
1618       for (j=0;j<C*N;j++)
1619          pcm[j] = FLOAT2INT16 (out[j]);
1620    
1621    RESTORE_STACK;
1622    return ret;
1623 }
1624 #endif
1625
1626 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size)
1627 {
1628    return celt_decode_with_ec(st, data, len, pcm, frame_size, NULL);
1629 }
1630
1631 #ifndef DISABLE_FLOAT_API
1632 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size)
1633 {
1634    return celt_decode_with_ec_float(st, data, len, pcm, frame_size, NULL);
1635 }
1636 #endif /* DISABLE_FLOAT_API */
1637
1638 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
1639 {
1640    va_list ap;
1641
1642    va_start(ap, request);
1643    switch (request)
1644    {
1645       case CELT_GET_MODE_REQUEST:
1646       {
1647          const CELTMode ** value = va_arg(ap, const CELTMode**);
1648          if (value==0)
1649             goto bad_arg;
1650          *value=st->mode;
1651       }
1652       break;
1653       case CELT_SET_START_BAND_REQUEST:
1654       {
1655          celt_int32 value = va_arg(ap, celt_int32);
1656          if (value<0 || value>=st->mode->nbEBands)
1657             goto bad_arg;
1658          st->start = value;
1659       }
1660       break;
1661       case CELT_SET_END_BAND_REQUEST:
1662       {
1663          celt_int32 value = va_arg(ap, celt_int32);
1664          if (value<0 || value>=st->mode->nbEBands)
1665             goto bad_arg;
1666          st->end = value;
1667       }
1668       break;
1669       case CELT_RESET_STATE:
1670       {
1671          CELT_MEMSET((char*)&st->DECODER_RESET_START, 0,
1672                celt_decoder_get_size(st->mode, st->channels)-
1673                ((char*)&st->DECODER_RESET_START - (char*)st));
1674       }
1675       break;
1676       default:
1677          goto bad_request;
1678    }
1679    va_end(ap);
1680    return CELT_OK;
1681 bad_arg:
1682    va_end(ap);
1683    return CELT_BAD_ARG;
1684 bad_request:
1685       va_end(ap);
1686   return CELT_UNIMPLEMENTED;
1687 }
1688
1689 const char *celt_strerror(int error)
1690 {
1691    static const char *error_strings[8] = {
1692       "success",
1693       "invalid argument",
1694       "invalid mode",
1695       "internal error",
1696       "corrupted stream",
1697       "request not implemented",
1698       "invalid state",
1699       "memory allocation failed"
1700    };
1701    if (error > 0 || error < -7)
1702       return "unknown error";
1703    else 
1704       return error_strings[-error];
1705 }
1706