c4c3831083bde249bbb0d376f47d160fbe6389e1
[opus.git] / libcelt / celt.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2010 Xiph.Org Foundation
3    Copyright (c) 2008 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #define CELT_C
39
40 #include "os_support.h"
41 #include "mdct.h"
42 #include <math.h>
43 #include "celt.h"
44 #include "pitch.h"
45 #include "bands.h"
46 #include "modes.h"
47 #include "entcode.h"
48 #include "quant_bands.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54 #include "plc.h"
55
56 #ifdef FIXED_POINT
57 static const celt_word16 transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135f, 0.0337639f, 0.0748914f, 0.1304955f,
63    0.1986827f, 0.2771308f, 0.3631685f, 0.4538658f,
64    0.5461342f, 0.6368315f, 0.7228692f, 0.8013173f,
65    0.8695045f, 0.9251086f, 0.9662361f, 0.9914865f};
66 #endif
67
68 /** Encoder state 
69  @brief Encoder state
70  */
71 struct CELTEncoder {
72    const CELTMode *mode;     /**< Mode used by the encoder */
73    int overlap;
74    int channels;
75    
76    int force_intra;
77    int delayedIntra;
78    celt_word16 tonal_average;
79    int fold_decision;
80    celt_word16 gain_prod;
81    celt_word32 frame_max;
82    int start, end;
83
84    /* VBR-related parameters */
85    celt_int32 vbr_reservoir;
86    celt_int32 vbr_drift;
87    celt_int32 vbr_offset;
88    celt_int32 vbr_count;
89
90    celt_int32 vbr_rate_norm; /* Target number of 16th bits per frame */
91    celt_word32 preemph_memE[2];
92    celt_word32 preemph_memD[2];
93
94    celt_sig in_mem[1]; /* Size = channels*mode->overlap */
95    /* celt_sig overlap_mem[],  Size = channels*mode->overlap */
96    /* celt_word16 oldEBands[], Size = channels*mode->nbEBands */
97 };
98
99 int celt_encoder_get_size(const CELTMode *mode, int channels)
100 {
101    int size = sizeof(struct CELTEncoder)
102          + (2*channels*mode->overlap-1)*sizeof(celt_sig)
103          + channels*mode->nbEBands*sizeof(celt_word16);
104    return size;
105 }
106
107 CELTEncoder *celt_encoder_create(const CELTMode *mode, int channels, int *error)
108 {
109    return celt_encoder_init(
110          (CELTEncoder *)celt_alloc(celt_encoder_get_size(mode, channels)),
111          mode, channels, error);
112 }
113
114 CELTEncoder *celt_encoder_init(CELTEncoder *st, const CELTMode *mode, int channels, int *error)
115 {
116    if (channels < 0 || channels > 2)
117    {
118       celt_warning("Only mono and stereo supported");
119       if (error)
120          *error = CELT_BAD_ARG;
121       return NULL;
122    }
123
124    CELT_MEMSET((char*)st, 0, celt_encoder_get_size(mode, channels));
125    
126    if (st==NULL)
127    {
128       if (error)
129          *error = CELT_ALLOC_FAIL;
130       return NULL;
131    }
132    st->mode = mode;
133    st->overlap = mode->overlap;
134    st->channels = channels;
135
136    st->start = 0;
137    st->end = st->mode->effEBands;
138
139    st->vbr_rate_norm = 0;
140    st->force_intra  = 0;
141    st->delayedIntra = 1;
142    st->tonal_average = QCONST16(1.f,8);
143    st->fold_decision = 1;
144
145    if (error)
146       *error = CELT_OK;
147    return st;
148 }
149
150 void celt_encoder_destroy(CELTEncoder *st)
151 {
152    celt_free(st);
153 }
154
155 static inline celt_int16 FLOAT2INT16(float x)
156 {
157    x = x*CELT_SIG_SCALE;
158    x = MAX32(x, -32768);
159    x = MIN32(x, 32767);
160    return (celt_int16)float2int(x);
161 }
162
163 static inline celt_word16 SIG2WORD16(celt_sig x)
164 {
165 #ifdef FIXED_POINT
166    x = PSHR32(x, SIG_SHIFT);
167    x = MAX32(x, -32768);
168    x = MIN32(x, 32767);
169    return EXTRACT16(x);
170 #else
171    return (celt_word16)x;
172 #endif
173 }
174
175 static int transient_analysis(const celt_word32 * restrict in, int len, int C,
176                               int *transient_time, int *transient_shift,
177                               celt_word32 *frame_max, int overlap)
178 {
179    int i, n;
180    celt_word32 ratio;
181    celt_word32 threshold;
182    VARDECL(celt_word32, begin);
183    SAVE_STACK;
184    ALLOC(begin, len+1, celt_word32);
185    begin[0] = 0;
186    if (C==1)
187    {
188       for (i=0;i<len;i++)
189          begin[i+1] = MAX32(begin[i], ABS32(in[i]));
190    } else {
191       for (i=0;i<len;i++)
192          begin[i+1] = MAX32(begin[i], MAX32(ABS32(in[C*i]),
193                                             ABS32(in[C*i+1])));
194    }
195    n = -1;
196
197    threshold = MULT16_32_Q15(QCONST16(.4f,15),begin[len]);
198    /* If the following condition isn't met, there's just no way
199       we'll have a transient*/
200    if (*frame_max < threshold)
201    {
202       /* It's likely we have a transient, now find it */
203       for (i=8;i<len-8;i++)
204       {
205          if (begin[i+1] < threshold)
206             n=i;
207       }
208    }
209    if (n<32)
210    {
211       n = -1;
212       ratio = 0;
213    } else {
214       ratio = DIV32(begin[len],1+MAX32(*frame_max, begin[n-16]));
215    }
216
217    if (ratio > 45)
218       *transient_shift = 3;
219    else
220       *transient_shift = 0;
221    
222    *transient_time = n;
223    *frame_max = begin[len-overlap];
224
225    RESTORE_STACK;
226    return ratio > 0;
227 }
228
229 /** Apply window and compute the MDCT for all sub-frames and 
230     all channels in a frame */
231 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * restrict in, celt_sig * restrict out, int _C, int LM)
232 {
233    const int C = CHANNELS(_C);
234    if (C==1 && !shortBlocks)
235    {
236       const int overlap = OVERLAP(mode);
237       clt_mdct_forward(&mode->mdct, in, out, mode->window, overlap, mode->maxLM-LM);
238    } else {
239       const int overlap = OVERLAP(mode);
240       int N = mode->shortMdctSize<<LM;
241       int B = 1;
242       int b, c;
243       VARDECL(celt_word32, x);
244       VARDECL(celt_word32, tmp);
245       SAVE_STACK;
246       if (shortBlocks)
247       {
248          /*lookup = &mode->mdct[0];*/
249          N = mode->shortMdctSize;
250          B = shortBlocks;
251       }
252       ALLOC(x, N+overlap, celt_word32);
253       ALLOC(tmp, N, celt_word32);
254       for (c=0;c<C;c++)
255       {
256          for (b=0;b<B;b++)
257          {
258             int j;
259             for (j=0;j<N+overlap;j++)
260                x[j] = in[C*(b*N+j)+c];
261             clt_mdct_forward(&mode->mdct, x, tmp, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
262             /* Interleaving the sub-frames */
263             for (j=0;j<N;j++)
264                out[(j*B+b)+c*N*B] = tmp[j];
265          }
266       }
267       RESTORE_STACK;
268    }
269 }
270
271 /** Compute the IMDCT and apply window for all sub-frames and 
272     all channels in a frame */
273 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X,
274       int transient_time, int transient_shift, celt_sig * restrict out_mem[],
275       celt_sig * restrict overlap_mem[], int _C, int LM)
276 {
277    int c;
278    const int C = CHANNELS(_C);
279    const int N = mode->shortMdctSize<<LM;
280    const int overlap = OVERLAP(mode);
281    for (c=0;c<C;c++)
282    {
283       int j;
284          VARDECL(celt_word32, x);
285          VARDECL(celt_word32, tmp);
286          int b;
287          int N2 = N;
288          int B = 1;
289          SAVE_STACK;
290          
291          ALLOC(x, N+overlap, celt_word32);
292          ALLOC(tmp, N, celt_word32);
293
294          if (shortBlocks)
295          {
296             N2 = mode->shortMdctSize;
297             B = shortBlocks;
298          }
299          /* Prevents problems from the imdct doing the overlap-add */
300          CELT_MEMSET(x, 0, overlap);
301
302          for (b=0;b<B;b++)
303          {
304             /* De-interleaving the sub-frames */
305             for (j=0;j<N2;j++)
306                tmp[j] = X[(j*B+b)+c*N2*B];
307             clt_mdct_backward(&mode->mdct, tmp, x+N2*b, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
308          }
309
310          if (transient_shift > 0)
311          {
312 #ifdef FIXED_POINT
313             for (j=0;j<16;j++)
314                x[transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[transient_time+j-16],transient_shift));
315             for (j=transient_time;j<N+overlap;j++)
316                x[j] = SHL32(x[j], transient_shift);
317 #else
318             for (j=0;j<16;j++)
319                x[transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
320             for (j=transient_time;j<N+overlap;j++)
321                x[j] *= 1<<transient_shift;
322 #endif
323          }
324          for (j=0;j<overlap;j++)
325             out_mem[c][j] = x[j] + overlap_mem[c][j];
326          for (;j<N;j++)
327             out_mem[c][j] = x[j];
328          for (j=0;j<overlap;j++)
329             overlap_mem[c][j] = x[N+j];
330          RESTORE_STACK;
331    }
332 }
333
334 static void deemphasis(celt_sig *in[], celt_word16 *pcm, int N, int _C, const celt_word16 *coef, celt_sig *mem)
335 {
336    const int C = CHANNELS(_C);
337    int c;
338    for (c=0;c<C;c++)
339    {
340       int j;
341       celt_sig * restrict x;
342       celt_word16  * restrict y;
343       celt_sig m = mem[c];
344       x =in[c];
345       y = pcm+c;
346       for (j=0;j<N;j++)
347       {
348          celt_sig tmp = *x + m;
349          m = MULT16_32_Q15(coef[0], tmp)
350            - MULT16_32_Q15(coef[1], *x);
351          tmp = SHL32(MULT16_32_Q15(coef[3], tmp), 2);
352          *y = SCALEOUT(SIG2WORD16(tmp));
353          x++;
354          y+=C;
355       }
356       mem[c] = m;
357    }
358 }
359
360 static void mdct_shape(const CELTMode *mode, celt_norm *X, int start,
361                        int end, int N,
362                        int mdct_weight_shift, int end_band, int _C, int renorm, int M)
363 {
364    int m, i, c;
365    const int C = CHANNELS(_C);
366    for (c=0;c<C;c++)
367       for (m=start;m<end;m++)
368          for (i=m+c*N;i<(c+1)*N;i+=M)
369 #ifdef FIXED_POINT
370             X[i] = SHR16(X[i], mdct_weight_shift);
371 #else
372             X[i] = (1.f/(1<<mdct_weight_shift))*X[i];
373 #endif
374    if (renorm)
375       renormalise_bands(mode, X, end_band, C, M);
376 }
377
378 static signed char tf_select_table[4][8] = {
379       {0, -1, 0, -1,    0,-1, 0,-1},
380       {0, -1, 0, -2,    1, 0, 1 -1},
381       {0, -2, 0, -3,    2, 0, 1 -1},
382       {0, -2, 0, -3,    2, 0, 1 -1},
383 };
384
385 static int tf_analysis(celt_word16 *bandLogE, celt_word16 *oldBandE, int len, int C, int isTransient, int *tf_res, int nbCompressedBytes)
386 {
387    int i;
388    celt_word16 threshold;
389    VARDECL(celt_word16, metric);
390    celt_word32 average=0;
391    celt_word32 cost0;
392    celt_word32 cost1;
393    VARDECL(int, path0);
394    VARDECL(int, path1);
395    celt_word16 lambda;
396    int tf_select=0;
397    SAVE_STACK;
398
399    /* FIXME: Should check number of bytes *left* */
400    if (nbCompressedBytes<15*C)
401    {
402       for (i=0;i<len;i++)
403          tf_res[i] = 0;
404       return 0;
405    }
406    if (nbCompressedBytes<40)
407       lambda = QCONST16(5.f, DB_SHIFT);
408    else if (nbCompressedBytes<60)
409       lambda = QCONST16(2.f, DB_SHIFT);
410    else if (nbCompressedBytes<100)
411       lambda = QCONST16(1.f, DB_SHIFT);
412    else
413       lambda = QCONST16(.5f, DB_SHIFT);
414
415    ALLOC(metric, len, celt_word16);
416    ALLOC(path0, len, int);
417    ALLOC(path1, len, int);
418    for (i=0;i<len;i++)
419    {
420       metric[i] = SUB16(bandLogE[i], oldBandE[i]);
421       average += metric[i];
422    }
423    if (C==2)
424    {
425       average = 0;
426       for (i=0;i<len;i++)
427       {
428          metric[i] = HALF32(metric[i]) + HALF32(SUB16(bandLogE[i+len], oldBandE[i+len]));
429          average += metric[i];
430       }
431    }
432    average = DIV32(average, len);
433    /*if (!isTransient)
434       printf ("%f\n", average);*/
435    if (isTransient)
436    {
437       threshold = QCONST16(1.f,DB_SHIFT);
438       tf_select = average > QCONST16(3.f,DB_SHIFT);
439    } else {
440       threshold = QCONST16(.5f,DB_SHIFT);
441       tf_select = average > QCONST16(1.f,DB_SHIFT);
442    }
443    cost0 = 0;
444    cost1 = lambda;
445    /* Viterbi forward pass */
446    for (i=1;i<len;i++)
447    {
448       celt_word32 curr0, curr1;
449       celt_word32 from0, from1;
450
451       from0 = cost0;
452       from1 = cost1 + lambda;
453       if (from0 < from1)
454       {
455          curr0 = from0;
456          path0[i]= 0;
457       } else {
458          curr0 = from1;
459          path0[i]= 1;
460       }
461
462       from0 = cost0 + lambda;
463       from1 = cost1;
464       if (from0 < from1)
465       {
466          curr1 = from0;
467          path1[i]= 0;
468       } else {
469          curr1 = from1;
470          path1[i]= 1;
471       }
472       cost0 = curr0 + (metric[i]-threshold);
473       cost1 = curr1;
474    }
475    tf_res[len-1] = cost0 < cost1 ? 0 : 1;
476    /* Viterbi backward pass to check the decisions */
477    for (i=len-2;i>=0;i--)
478    {
479       if (tf_res[i+1] == 1)
480          tf_res[i] = path1[i+1];
481       else
482          tf_res[i] = path0[i+1];
483    }
484    RESTORE_STACK;
485    return tf_select;
486 }
487
488 static void tf_encode(int start, int end, int isTransient, int *tf_res, int nbCompressedBytes, int LM, int tf_select, ec_enc *enc)
489 {
490    int curr, i;
491    if (8*nbCompressedBytes - ec_enc_tell(enc, 0) < 100)
492    {
493       for (i=start;i<end;i++)
494          tf_res[i] = isTransient;
495    } else {
496       ec_enc_bit_prob(enc, tf_res[start], isTransient ? 16384 : 4096);
497       curr = tf_res[start];
498       for (i=start+1;i<end;i++)
499       {
500          ec_enc_bit_prob(enc, tf_res[i] ^ curr, isTransient ? 4096 : 2048);
501          curr = tf_res[i];
502       }
503    }
504    ec_enc_bits(enc, tf_select, 1);
505    for (i=start;i<end;i++)
506       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
507 }
508
509 static void tf_decode(int start, int end, int C, int isTransient, int *tf_res, int nbCompressedBytes, int LM, ec_dec *dec)
510 {
511    int i, curr, tf_select;
512    if (8*nbCompressedBytes - ec_dec_tell(dec, 0) < 100)
513    {
514       for (i=start;i<end;i++)
515          tf_res[i] = isTransient;
516    } else {
517       tf_res[start] = ec_dec_bit_prob(dec, isTransient ? 16384 : 4096);
518       curr = tf_res[start];
519       for (i=start+1;i<end;i++)
520       {
521          tf_res[i] = ec_dec_bit_prob(dec, isTransient ? 4096 : 2048) ^ curr;
522          curr = tf_res[i];
523       }
524    }
525    tf_select = ec_dec_bits(dec, 1);
526    for (i=start;i<end;i++)
527       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
528 }
529
530 #ifdef FIXED_POINT
531 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
532 {
533 #else
534 int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
535 {
536 #endif
537    int i, c, N, NN;
538    int bits;
539    int has_fold=1;
540    ec_byte_buffer buf;
541    ec_enc         _enc;
542    VARDECL(celt_sig, in);
543    VARDECL(celt_sig, freq);
544    VARDECL(celt_norm, X);
545    VARDECL(celt_ener, bandE);
546    VARDECL(celt_word16, bandLogE);
547    VARDECL(int, fine_quant);
548    VARDECL(celt_word16, error);
549    VARDECL(int, pulses);
550    VARDECL(int, offsets);
551    VARDECL(int, fine_priority);
552    VARDECL(int, tf_res);
553    celt_sig *_overlap_mem;
554    celt_word16 *oldBandE;
555    int shortBlocks=0;
556    int isTransient=0;
557    int transient_time, transient_time_quant;
558    int transient_shift;
559    int resynth;
560    const int C = CHANNELS(st->channels);
561    int mdct_weight_shift = 0;
562    int mdct_weight_pos=0;
563    int LM, M;
564    int tf_select;
565    int nbFilledBytes, nbAvailableBytes;
566    int effEnd;
567    SAVE_STACK;
568
569    if (nbCompressedBytes<0 || pcm==NULL)
570      return CELT_BAD_ARG;
571
572    for (LM=0;LM<4;LM++)
573       if (st->mode->shortMdctSize<<LM==frame_size)
574          break;
575    if (LM>=MAX_CONFIG_SIZES)
576       return CELT_BAD_ARG;
577    M=1<<LM;
578
579    _overlap_mem = st->in_mem+C*(st->overlap);
580    oldBandE = (celt_word16*)(st->in_mem+2*C*(st->overlap));
581
582    if (enc==NULL)
583    {
584       ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
585       ec_enc_init(&_enc,&buf);
586       enc = &_enc;
587       nbFilledBytes=0;
588    } else {
589       nbFilledBytes=(ec_enc_tell(enc, 0)+4)>>3;
590    }
591    nbAvailableBytes = nbCompressedBytes - nbFilledBytes;
592
593    effEnd = st->end;
594    if (effEnd > st->mode->effEBands)
595       effEnd = st->mode->effEBands;
596
597    N = M*st->mode->shortMdctSize;
598    ALLOC(in, C*(N+st->overlap), celt_sig);
599
600    CELT_COPY(in, st->in_mem, C*st->overlap);
601    for (c=0;c<C;c++)
602    {
603       const celt_word16 * restrict pcmp = pcm+c;
604       celt_sig * restrict inp = in+C*st->overlap+c;
605       for (i=0;i<N;i++)
606       {
607          /* Apply pre-emphasis */
608          celt_sig tmp = MULT16_16(st->mode->preemph[2], SCALEIN(*pcmp));
609          *inp = tmp + st->preemph_memE[c];
610          st->preemph_memE[c] = MULT16_32_Q15(st->mode->preemph[1], *inp)
611                              - MULT16_32_Q15(st->mode->preemph[0], tmp);
612          inp += C;
613          pcmp += C;
614       }
615    }
616    CELT_COPY(st->in_mem, in+C*N, C*st->overlap);
617
618    /* Transient handling */
619    transient_time = -1;
620    transient_time_quant = -1;
621    transient_shift = 0;
622    isTransient = 0;
623
624    resynth = optional_resynthesis!=NULL;
625
626    if (M > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift, &st->frame_max, st->overlap))
627    {
628 #ifndef FIXED_POINT
629       float gain_1;
630 #endif
631       /* Apply the inverse shaping window */
632       if (transient_shift)
633       {
634          transient_time_quant = transient_time*(celt_int32)8000/st->mode->Fs;
635          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
636 #ifdef FIXED_POINT
637          for (c=0;c<C;c++)
638             for (i=0;i<16;i++)
639                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
640          for (c=0;c<C;c++)
641             for (i=transient_time;i<N+st->overlap;i++)
642                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
643 #else
644          for (c=0;c<C;c++)
645             for (i=0;i<16;i++)
646                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
647          gain_1 = 1.f/(1<<transient_shift);
648          for (c=0;c<C;c++)
649             for (i=transient_time;i<N+st->overlap;i++)
650                in[C*i+c] *= gain_1;
651 #endif
652       }
653       isTransient = 1;
654       has_fold = 1;
655    }
656
657    if (isTransient)
658       shortBlocks = M;
659    else
660       shortBlocks = 0;
661
662    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
663    ALLOC(bandE,st->mode->nbEBands*C, celt_ener);
664    ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16);
665    /* Compute MDCTs */
666    compute_mdcts(st->mode, shortBlocks, in, freq, C, LM);
667
668    ALLOC(X, C*N, celt_norm);         /**< Interleaved normalised MDCTs */
669
670    compute_band_energies(st->mode, freq, bandE, effEnd, C, M);
671
672    amp2Log2(st->mode, effEnd, st->end, bandE, bandLogE, C);
673
674    /* Band normalisation */
675    normalise_bands(st->mode, freq, X, bandE, effEnd, C, M);
676
677    NN = M*st->mode->eBands[effEnd];
678    if (shortBlocks && !transient_shift)
679    {
680       celt_word32 sum[8]={1,1,1,1,1,1,1,1};
681       int m;
682       for (c=0;c<C;c++)
683       {
684          m=0;
685          do {
686             celt_word32 tmp=0;
687             for (i=m+c*N;i<c*N+NN;i+=M)
688                tmp += ABS32(X[i]);
689             sum[m++] += tmp;
690          } while (m<M);
691       }
692       m=0;
693 #ifdef FIXED_POINT
694       do {
695          if (SHR32(sum[m+1],3) > sum[m])
696          {
697             mdct_weight_shift=2;
698             mdct_weight_pos = m;
699          } else if (SHR32(sum[m+1],1) > sum[m] && mdct_weight_shift < 2)
700          {
701             mdct_weight_shift=1;
702             mdct_weight_pos = m;
703          }
704          m++;
705       } while (m<M-1);
706 #else
707       do {
708          if (sum[m+1] > 8*sum[m])
709          {
710             mdct_weight_shift=2;
711             mdct_weight_pos = m;
712          } else if (sum[m+1] > 2*sum[m] && mdct_weight_shift < 2)
713          {
714             mdct_weight_shift=1;
715             mdct_weight_pos = m;
716          }
717          m++;
718       } while (m<M-1);
719 #endif
720       if (mdct_weight_shift)
721          mdct_shape(st->mode, X, mdct_weight_pos+1, M, N, mdct_weight_shift, effEnd, C, 0, M);
722    }
723
724    ALLOC(tf_res, st->mode->nbEBands, int);
725    /* Needs to be before coarse energy quantization because otherwise the energy gets modified */
726    tf_select = tf_analysis(bandLogE, oldBandE, effEnd, C, isTransient, tf_res, nbAvailableBytes);
727    for (i=effEnd;i<st->end;i++)
728       tf_res[i] = tf_res[effEnd-1];
729
730    ALLOC(error, C*st->mode->nbEBands, celt_word16);
731    quant_coarse_energy(st->mode, st->start, st->end, effEnd, bandLogE,
732          oldBandE, nbCompressedBytes*8, st->mode->prob,
733          error, enc, C, LM, nbAvailableBytes, st->force_intra, &st->delayedIntra);
734
735    ec_enc_bit_prob(enc, shortBlocks!=0, 8192);
736
737    if (shortBlocks)
738    {
739       if (transient_shift)
740       {
741          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
742          ec_enc_uint(enc, transient_shift, 4);
743          ec_enc_uint(enc, transient_time_quant, max_time);
744       } else {
745          ec_enc_uint(enc, mdct_weight_shift, 4);
746          if (mdct_weight_shift && M!=2)
747             ec_enc_uint(enc, mdct_weight_pos, M-1);
748       }
749    }
750
751    tf_encode(st->start, st->end, isTransient, tf_res, nbAvailableBytes, LM, tf_select, enc);
752
753    if (!shortBlocks && !folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision, effEnd, C, M))
754       has_fold = 0;
755    ec_enc_bit_prob(enc, has_fold>>1, 8192);
756    ec_enc_bit_prob(enc, has_fold&1, (has_fold>>1) ? 32768 : 49152);
757
758    /* Variable bitrate */
759    if (st->vbr_rate_norm>0)
760    {
761      celt_word16 alpha;
762      celt_int32 delta;
763      /* The target rate in 16th bits per frame */
764      celt_int32 vbr_rate;
765      celt_int32 target;
766      celt_int32 vbr_bound, max_allowed;
767
768      vbr_rate = M*st->vbr_rate_norm;
769
770      /* Computes the max bit-rate allowed in VBR more to avoid busting the budget */
771      vbr_bound = vbr_rate;
772      max_allowed = (vbr_rate + vbr_bound - st->vbr_reservoir)>>(BITRES+3);
773      if (max_allowed < 4)
774         max_allowed = 4;
775      if (max_allowed < nbAvailableBytes)
776         nbAvailableBytes = max_allowed;
777      target=vbr_rate;
778
779      /* Shortblocks get a large boost in bitrate, but since they 
780         are uncommon long blocks are not greatly effected */
781      if (shortBlocks)
782        target*=2;
783      else if (M > 1)
784        target-=(target+14)/28;
785
786      /* The average energy is removed from the target and the actual 
787         energy added*/
788      target=target+st->vbr_offset-588+ec_enc_tell(enc, BITRES);
789
790      /* In VBR mode the frame size must not be reduced so much that it would result in the coarse energy busting its budget */
791      target=IMIN(nbAvailableBytes,target);
792      /* Make the adaptation coef (alpha) higher at the beginning */
793      if (st->vbr_count < 990)
794      {
795         st->vbr_count++;
796         alpha = celt_rcp(SHL32(EXTEND32(st->vbr_count+10),16));
797         /*printf ("%d %d\n", st->vbr_count+10, alpha);*/
798      } else
799         alpha = QCONST16(.001f,15);
800
801      /* By how much did we "miss" the target on that frame */
802      delta = (8<<BITRES)*(celt_int32)target - vbr_rate;
803      /* How many bits have we used in excess of what we're allowed */
804      st->vbr_reservoir += delta;
805      /*printf ("%d\n", st->vbr_reservoir);*/
806
807      /* Compute the offset we need to apply in order to reach the target */
808      st->vbr_drift += MULT16_32_Q15(alpha,delta-st->vbr_offset-st->vbr_drift);
809      st->vbr_offset = -st->vbr_drift;
810      /*printf ("%d\n", st->vbr_drift);*/
811
812      /* We could use any multiple of vbr_rate as bound (depending on the delay) */
813      if (st->vbr_reservoir < 0)
814      {
815         /* We're under the min value -- increase rate */
816         int adjust = 1-(st->vbr_reservoir-1)/(8<<BITRES);
817         st->vbr_reservoir += adjust*(8<<BITRES);
818         target += adjust;
819         /*printf ("+%d\n", adjust);*/
820      }
821      if (target < nbAvailableBytes)
822         nbAvailableBytes = target;
823      nbCompressedBytes = nbAvailableBytes + nbFilledBytes;
824
825      /* This moves the raw bits to take into account the new compressed size */
826      ec_byte_shrink(&buf, nbCompressedBytes);
827    }
828
829    /* Bit allocation */
830    ALLOC(fine_quant, st->mode->nbEBands, int);
831    ALLOC(pulses, st->mode->nbEBands, int);
832    ALLOC(offsets, st->mode->nbEBands, int);
833    ALLOC(fine_priority, st->mode->nbEBands, int);
834
835    for (i=0;i<st->mode->nbEBands;i++)
836       offsets[i] = 0;
837    bits = nbCompressedBytes*8 - ec_enc_tell(enc, 0) - 1;
838    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
839
840    quant_fine_energy(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, enc, C);
841
842 #ifdef MEASURE_NORM_MSE
843    float X0[3000];
844    float bandE0[60];
845    for (c=0;c<C;c++)
846       for (i=0;i<N;i++)
847          X0[i+c*N] = X[i+c*N];
848    for (i=0;i<C*st->mode->nbEBands;i++)
849       bandE0[i] = bandE[i];
850 #endif
851
852    /* Residual quantisation */
853    quant_all_bands(1, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, bandE, pulses, shortBlocks, has_fold, tf_res, resynth, nbCompressedBytes*8, enc, LM);
854
855    quant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
856
857    /* Re-synthesis of the coded audio if required */
858    if (resynth)
859    {
860       VARDECL(celt_sig, _out_mem);
861       celt_sig *out_mem[2];
862       celt_sig *overlap_mem[2];
863
864       log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
865
866 #ifdef MEASURE_NORM_MSE
867       measure_norm_mse(st->mode, X, X0, bandE, bandE0, M, N, C);
868 #endif
869
870       if (mdct_weight_shift)
871       {
872          mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
873       }
874
875       /* Synthesis */
876       denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
877
878       for (c=0;c<C;c++)
879          for (i=0;i<M*st->mode->eBands[st->start];i++)
880             freq[c*N+i] = 0;
881       for (c=0;c<C;c++)
882          for (i=M*st->mode->eBands[st->end];i<N;i++)
883             freq[c*N+i] = 0;
884
885       ALLOC(_out_mem, C*N, celt_sig);
886
887       for (c=0;c<C;c++)
888       {
889          overlap_mem[c] = _overlap_mem + c*st->overlap;
890          out_mem[c] = _out_mem+c*N;
891       }
892
893       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
894             transient_shift, out_mem, overlap_mem, C, LM);
895
896       /* De-emphasis and put everything back at the right place 
897          in the synthesis history */
898       if (optional_resynthesis != NULL) {
899          deemphasis(out_mem, optional_resynthesis, N, C, st->mode->preemph, st->preemph_memD);
900
901       }
902    }
903
904    /* If there's any room left (can only happen for very high rates),
905       fill it with zeros */
906    while (ec_enc_tell(enc,0) + 8 <= nbCompressedBytes*8)
907       ec_enc_bits(enc, 0, 8);
908    ec_enc_done(enc);
909    
910    RESTORE_STACK;
911    if (ec_enc_get_error(enc))
912       return CELT_CORRUPTED_DATA;
913    else
914       return nbCompressedBytes;
915 }
916
917 #ifdef FIXED_POINT
918 #ifndef DISABLE_FLOAT_API
919 int celt_encode_with_ec_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
920 {
921    int j, ret, C, N, LM, M;
922    VARDECL(celt_int16, in);
923    SAVE_STACK;
924
925    if (pcm==NULL)
926       return CELT_BAD_ARG;
927
928    for (LM=0;LM<4;LM++)
929       if (st->mode->shortMdctSize<<LM==frame_size)
930          break;
931    if (LM>=MAX_CONFIG_SIZES)
932       return CELT_BAD_ARG;
933    M=1<<LM;
934
935    C = CHANNELS(st->channels);
936    N = M*st->mode->shortMdctSize;
937    ALLOC(in, C*N, celt_int16);
938
939    for (j=0;j<C*N;j++)
940      in[j] = FLOAT2INT16(pcm[j]);
941
942    if (optional_resynthesis != NULL) {
943      ret=celt_encode_with_ec(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
944       for (j=0;j<C*N;j++)
945          optional_resynthesis[j]=in[j]*(1.f/32768.f);
946    } else {
947      ret=celt_encode_with_ec(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
948    }
949    RESTORE_STACK;
950    return ret;
951
952 }
953 #endif /*DISABLE_FLOAT_API*/
954 #else
955 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
956 {
957    int j, ret, C, N, LM, M;
958    VARDECL(celt_sig, in);
959    SAVE_STACK;
960
961    if (pcm==NULL)
962       return CELT_BAD_ARG;
963
964    for (LM=0;LM<4;LM++)
965       if (st->mode->shortMdctSize<<LM==frame_size)
966          break;
967    if (LM>=MAX_CONFIG_SIZES)
968       return CELT_BAD_ARG;
969    M=1<<LM;
970
971    C=CHANNELS(st->channels);
972    N=M*st->mode->shortMdctSize;
973    ALLOC(in, C*N, celt_sig);
974    for (j=0;j<C*N;j++) {
975      in[j] = SCALEOUT(pcm[j]);
976    }
977
978    if (optional_resynthesis != NULL) {
979       ret = celt_encode_with_ec_float(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
980       for (j=0;j<C*N;j++)
981          optional_resynthesis[j] = FLOAT2INT16(in[j]);
982    } else {
983       ret = celt_encode_with_ec_float(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
984    }
985    RESTORE_STACK;
986    return ret;
987 }
988 #endif
989
990 int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
991 {
992    return celt_encode_with_ec(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
993 }
994
995 #ifndef DISABLE_FLOAT_API
996 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
997 {
998    return celt_encode_with_ec_float(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
999 }
1000 #endif /* DISABLE_FLOAT_API */
1001
1002 int celt_encode_resynthesis(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1003 {
1004    return celt_encode_with_ec(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1005 }
1006
1007 #ifndef DISABLE_FLOAT_API
1008 int celt_encode_resynthesis_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1009 {
1010    return celt_encode_with_ec_float(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1011 }
1012 #endif /* DISABLE_FLOAT_API */
1013
1014
1015 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
1016 {
1017    va_list ap;
1018    
1019    va_start(ap, request);
1020    switch (request)
1021    {
1022       case CELT_GET_MODE_REQUEST:
1023       {
1024          const CELTMode ** value = va_arg(ap, const CELTMode**);
1025          if (value==0)
1026             goto bad_arg;
1027          *value=st->mode;
1028       }
1029       break;
1030       case CELT_SET_COMPLEXITY_REQUEST:
1031       {
1032          int value = va_arg(ap, celt_int32);
1033          if (value<0 || value>10)
1034             goto bad_arg;
1035       }
1036       break;
1037       case CELT_SET_START_BAND_REQUEST:
1038       {
1039          celt_int32 value = va_arg(ap, celt_int32);
1040          if (value<0 || value>=st->mode->nbEBands)
1041             goto bad_arg;
1042          st->start = value;
1043       }
1044       break;
1045       case CELT_SET_END_BAND_REQUEST:
1046       {
1047          celt_int32 value = va_arg(ap, celt_int32);
1048          if (value<0 || value>=st->mode->nbEBands)
1049             goto bad_arg;
1050          st->end = value;
1051       }
1052       break;
1053       case CELT_SET_PREDICTION_REQUEST:
1054       {
1055          int value = va_arg(ap, celt_int32);
1056          if (value<0 || value>2)
1057             goto bad_arg;
1058          if (value==0)
1059          {
1060             st->force_intra   = 1;
1061          } else if (value==1) {
1062             st->force_intra   = 0;
1063          } else {
1064             st->force_intra   = 0;
1065          }   
1066       }
1067       break;
1068       case CELT_SET_VBR_RATE_REQUEST:
1069       {
1070          celt_int32 value = va_arg(ap, celt_int32);
1071          int frame_rate;
1072          int N = st->mode->shortMdctSize;
1073          if (value<0)
1074             goto bad_arg;
1075          if (value>3072000)
1076             value = 3072000;
1077          frame_rate = ((st->mode->Fs<<3)+(N>>1))/N;
1078          st->vbr_rate_norm = ((value<<(BITRES+3))+(frame_rate>>1))/frame_rate;
1079       }
1080       break;
1081       case CELT_RESET_STATE:
1082       {
1083          celt_sig *overlap_mem;
1084          celt_word16 *oldBandE;
1085          const CELTMode *mode = st->mode;
1086          int C = st->channels;
1087          overlap_mem = st->in_mem+C*(st->overlap);
1088          oldBandE = (celt_word16*)(st->in_mem+2*C*(st->overlap));
1089
1090          CELT_MEMSET(st->in_mem, 0, st->overlap*C);
1091          CELT_MEMSET(overlap_mem, 0, st->overlap*C);
1092
1093          CELT_MEMSET(oldBandE, 0, C*mode->nbEBands);
1094
1095          CELT_MEMSET(st->preemph_memE, 0, C);
1096          CELT_MEMSET(st->preemph_memD, 0, C);
1097          st->delayedIntra = 1;
1098
1099          st->fold_decision = 1;
1100          st->tonal_average = QCONST16(1.f,8);
1101          st->gain_prod = 0;
1102          st->vbr_reservoir = 0;
1103          st->vbr_drift = 0;
1104          st->vbr_offset = 0;
1105          st->vbr_count = 0;
1106          st->frame_max = 0;
1107       }
1108       break;
1109       default:
1110          goto bad_request;
1111    }
1112    va_end(ap);
1113    return CELT_OK;
1114 bad_arg:
1115    va_end(ap);
1116    return CELT_BAD_ARG;
1117 bad_request:
1118    va_end(ap);
1119    return CELT_UNIMPLEMENTED;
1120 }
1121
1122 /**********************************************************************/
1123 /*                                                                    */
1124 /*                             DECODER                                */
1125 /*                                                                    */
1126 /**********************************************************************/
1127 #define DECODE_BUFFER_SIZE 2048
1128
1129 /** Decoder state 
1130  @brief Decoder state
1131  */
1132 struct CELTDecoder {
1133    const CELTMode *mode;
1134    int overlap;
1135    int channels;
1136
1137    int start, end;
1138    int last_pitch_index;
1139    int loss_count;
1140
1141    celt_sig preemph_memD[2];
1142    
1143    celt_sig _decode_mem[1]; /* Size = channels*(DECODE_BUFFER_SIZE+mode->overlap) */
1144    /* celt_word16 lpc[],  Size = channels*LPC_ORDER */
1145    /* celt_word16 oldEBands[], Size = channels*mode->nbEBands */
1146 };
1147
1148 int celt_decoder_get_size(const CELTMode *mode, int channels)
1149 {
1150    int size = sizeof(struct CELTDecoder)
1151             + (channels*(DECODE_BUFFER_SIZE+mode->overlap)-1)*sizeof(celt_sig)
1152             + channels*LPC_ORDER*sizeof(celt_word16)
1153             + channels*mode->nbEBands*sizeof(celt_word16);
1154    return size;
1155 }
1156
1157 CELTDecoder *celt_decoder_create(const CELTMode *mode, int channels, int *error)
1158 {
1159    return celt_decoder_init(
1160          (CELTDecoder *)celt_alloc(celt_decoder_get_size(mode, channels)),
1161          mode, channels, error);
1162 }
1163
1164 CELTDecoder *celt_decoder_init(CELTDecoder *st, const CELTMode *mode, int channels, int *error)
1165 {
1166    if (channels < 0 || channels > 2)
1167    {
1168       celt_warning("Only mono and stereo supported");
1169       if (error)
1170          *error = CELT_BAD_ARG;
1171       return NULL;
1172    }
1173
1174    CELT_MEMSET((char*)st, 0, celt_decoder_get_size(mode, channels));
1175
1176    if (st==NULL)
1177    {
1178       if (error)
1179          *error = CELT_ALLOC_FAIL;
1180       return NULL;
1181    }
1182
1183    st->mode = mode;
1184    st->overlap = mode->overlap;
1185    st->channels = channels;
1186
1187    st->start = 0;
1188    st->end = st->mode->effEBands;
1189
1190    st->loss_count = 0;
1191
1192    if (error)
1193       *error = CELT_OK;
1194    return st;
1195 }
1196
1197 void celt_decoder_destroy(CELTDecoder *st)
1198 {
1199    celt_free(st);
1200 }
1201
1202 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16 * restrict pcm, int N, int LM)
1203 {
1204    int c;
1205    int pitch_index;
1206    int overlap = st->mode->overlap;
1207    celt_word16 fade = Q15ONE;
1208    int i, len;
1209    const int C = CHANNELS(st->channels);
1210    int offset;
1211    celt_sig *out_mem[2];
1212    celt_sig *decode_mem[2];
1213    celt_sig *overlap_mem[2];
1214    celt_word16 *lpc;
1215    celt_word16 *oldBandE;
1216    SAVE_STACK;
1217    
1218    for (c=0;c<C;c++)
1219    {
1220       decode_mem[c] = st->_decode_mem + c*(DECODE_BUFFER_SIZE+st->overlap);
1221       out_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE-MAX_PERIOD;
1222       overlap_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE;
1223    }
1224    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1225    oldBandE = lpc+C*LPC_ORDER;
1226
1227    len = N+st->mode->overlap;
1228    
1229    if (st->loss_count == 0)
1230    {
1231       celt_word16 pitch_buf[MAX_PERIOD>>1];
1232       celt_word32 tmp=0;
1233       celt_word32 mem0[2]={0,0};
1234       celt_word16 mem1[2]={0,0};
1235       int len2 = len;
1236       /* FIXME: This is a kludge */
1237       if (len2>MAX_PERIOD>>1)
1238          len2 = MAX_PERIOD>>1;
1239       pitch_downsample(out_mem, pitch_buf, MAX_PERIOD, MAX_PERIOD,
1240                        C, mem0, mem1);
1241       pitch_search(st->mode, pitch_buf+((MAX_PERIOD-len2)>>1), pitch_buf, len2,
1242                    MAX_PERIOD-len2-100, &pitch_index, &tmp, 1<<LM);
1243       pitch_index = MAX_PERIOD-len2-pitch_index;
1244       st->last_pitch_index = pitch_index;
1245    } else {
1246       pitch_index = st->last_pitch_index;
1247       if (st->loss_count < 5)
1248          fade = QCONST16(.8f,15);
1249       else
1250          fade = 0;
1251    }
1252
1253    for (c=0;c<C;c++)
1254    {
1255       /* FIXME: This is more memory than necessary */
1256       celt_word32 e[2*MAX_PERIOD];
1257       celt_word16 exc[2*MAX_PERIOD];
1258       celt_word32 ac[LPC_ORDER+1];
1259       celt_word16 decay = 1;
1260       celt_word32 S1=0;
1261       celt_word16 mem[LPC_ORDER]={0};
1262
1263       offset = MAX_PERIOD-pitch_index;
1264       for (i=0;i<MAX_PERIOD;i++)
1265          exc[i] = ROUND16(out_mem[c][i], SIG_SHIFT);
1266
1267       if (st->loss_count == 0)
1268       {
1269          _celt_autocorr(exc, ac, st->mode->window, st->mode->overlap,
1270                         LPC_ORDER, MAX_PERIOD);
1271
1272          /* Noise floor -40 dB */
1273 #ifdef FIXED_POINT
1274          ac[0] += SHR32(ac[0],13);
1275 #else
1276          ac[0] *= 1.0001f;
1277 #endif
1278          /* Lag windowing */
1279          for (i=1;i<=LPC_ORDER;i++)
1280          {
1281             /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
1282 #ifdef FIXED_POINT
1283             ac[i] -= MULT16_32_Q15(2*i*i, ac[i]);
1284 #else
1285             ac[i] -= ac[i]*(.008f*i)*(.008f*i);
1286 #endif
1287          }
1288
1289          _celt_lpc(lpc+c*LPC_ORDER, ac, LPC_ORDER);
1290       }
1291       fir(exc, lpc+c*LPC_ORDER, exc, MAX_PERIOD, LPC_ORDER, mem);
1292       /*for (i=0;i<MAX_PERIOD;i++)printf("%d ", exc[i]); printf("\n");*/
1293       /* Check if the waveform is decaying (and if so how fast) */
1294       {
1295          celt_word32 E1=1, E2=1;
1296          int period;
1297          if (pitch_index <= MAX_PERIOD/2)
1298             period = pitch_index;
1299          else
1300             period = MAX_PERIOD/2;
1301          for (i=0;i<period;i++)
1302          {
1303             E1 += SHR32(MULT16_16(exc[MAX_PERIOD-period+i],exc[MAX_PERIOD-period+i]),8);
1304             E2 += SHR32(MULT16_16(exc[MAX_PERIOD-2*period+i],exc[MAX_PERIOD-2*period+i]),8);
1305          }
1306          if (E1 > E2)
1307             E1 = E2;
1308          decay = celt_sqrt(frac_div32(SHR(E1,1),E2));
1309       }
1310
1311       /* Copy excitation, taking decay into account */
1312       for (i=0;i<len+st->mode->overlap;i++)
1313       {
1314          if (offset+i >= MAX_PERIOD)
1315          {
1316             offset -= pitch_index;
1317             decay = MULT16_16_Q15(decay, decay);
1318          }
1319          e[i] = SHL32(EXTEND32(MULT16_16_Q15(decay, exc[offset+i])), SIG_SHIFT);
1320          S1 += SHR32(MULT16_16(out_mem[c][offset+i],out_mem[c][offset+i]),8);
1321       }
1322
1323       iir(e, lpc+c*LPC_ORDER, e, len+st->mode->overlap, LPC_ORDER, mem);
1324
1325       {
1326          celt_word32 S2=0;
1327          for (i=0;i<len+overlap;i++)
1328             S2 += SHR32(MULT16_16(e[i],e[i]),8);
1329          /* This checks for an "explosion" in the synthesis */
1330 #ifdef FIXED_POINT
1331          if (!(S1 > SHR32(S2,2)))
1332 #else
1333          /* Float test is written this way to catch NaNs at the same time */
1334          if (!(S1 > 0.2f*S2))
1335 #endif
1336          {
1337             for (i=0;i<len+overlap;i++)
1338                e[i] = 0;
1339          } else if (S1 < S2)
1340          {
1341             celt_word16 ratio = celt_sqrt(frac_div32(SHR32(S1,1)+1,S2+1));
1342             for (i=0;i<len+overlap;i++)
1343                e[i] = MULT16_16_Q15(ratio, e[i]);
1344          }
1345       }
1346
1347       for (i=0;i<MAX_PERIOD+st->mode->overlap-N;i++)
1348          out_mem[c][i] = out_mem[c][N+i];
1349
1350       /* Apply TDAC to the concealed audio so that it blends with the
1351          previous and next frames */
1352       for (i=0;i<overlap/2;i++)
1353       {
1354          celt_word32 tmp1, tmp2;
1355          tmp1 = MULT16_32_Q15(st->mode->window[i          ], e[i          ]) -
1356                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[overlap-i-1]);
1357          tmp2 = MULT16_32_Q15(st->mode->window[i],           e[N+overlap-1-i]) +
1358                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[N+i          ]);
1359          tmp1 = MULT16_32_Q15(fade, tmp1);
1360          tmp2 = MULT16_32_Q15(fade, tmp2);
1361          out_mem[c][MAX_PERIOD+i] = MULT16_32_Q15(st->mode->window[overlap-i-1], tmp2);
1362          out_mem[c][MAX_PERIOD+overlap-i-1] = MULT16_32_Q15(st->mode->window[i], tmp2);
1363          out_mem[c][MAX_PERIOD-N+i] += MULT16_32_Q15(st->mode->window[i], tmp1);
1364          out_mem[c][MAX_PERIOD-N+overlap-i-1] -= MULT16_32_Q15(st->mode->window[overlap-i-1], tmp1);
1365       }
1366       for (i=0;i<N-overlap;i++)
1367          out_mem[c][MAX_PERIOD-N+overlap+i] = MULT16_32_Q15(fade, e[overlap+i]);
1368    }
1369
1370    deemphasis(out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1371    
1372    st->loss_count++;
1373
1374    RESTORE_STACK;
1375 }
1376
1377 #ifdef FIXED_POINT
1378 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1379 {
1380 #else
1381 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig * restrict pcm, int frame_size, ec_dec *dec)
1382 {
1383 #endif
1384    int c, i, N;
1385    int has_fold;
1386    int bits;
1387    ec_dec _dec;
1388    ec_byte_buffer buf;
1389    VARDECL(celt_sig, freq);
1390    VARDECL(celt_norm, X);
1391    VARDECL(celt_ener, bandE);
1392    VARDECL(int, fine_quant);
1393    VARDECL(int, pulses);
1394    VARDECL(int, offsets);
1395    VARDECL(int, fine_priority);
1396    VARDECL(int, tf_res);
1397    celt_sig *out_mem[2];
1398    celt_sig *decode_mem[2];
1399    celt_sig *overlap_mem[2];
1400    celt_sig *out_syn[2];
1401    celt_word16 *lpc;
1402    celt_word16 *oldBandE;
1403
1404    int shortBlocks;
1405    int isTransient;
1406    int intra_ener;
1407    int transient_time;
1408    int transient_shift;
1409    int mdct_weight_shift=0;
1410    const int C = CHANNELS(st->channels);
1411    int mdct_weight_pos=0;
1412    int LM, M;
1413    int nbFilledBytes, nbAvailableBytes;
1414    int effEnd;
1415    SAVE_STACK;
1416
1417    if (pcm==NULL)
1418       return CELT_BAD_ARG;
1419
1420    for (LM=0;LM<4;LM++)
1421       if (st->mode->shortMdctSize<<LM==frame_size)
1422          break;
1423    if (LM>=MAX_CONFIG_SIZES)
1424       return CELT_BAD_ARG;
1425    M=1<<LM;
1426
1427    for (c=0;c<C;c++)
1428    {
1429       decode_mem[c] = st->_decode_mem + c*(DECODE_BUFFER_SIZE+st->overlap);
1430       out_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE-MAX_PERIOD;
1431       overlap_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE;
1432    }
1433    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1434    oldBandE = lpc+C*LPC_ORDER;
1435
1436    N = M*st->mode->shortMdctSize;
1437
1438    effEnd = st->end;
1439    if (effEnd > st->mode->effEBands)
1440       effEnd = st->mode->effEBands;
1441
1442    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
1443    ALLOC(X, C*N, celt_norm);   /**< Interleaved normalised MDCTs */
1444    ALLOC(bandE, st->mode->nbEBands*C, celt_ener);
1445    for (c=0;c<C;c++)
1446       for (i=0;i<M*st->mode->eBands[st->start];i++)
1447          X[c*N+i] = 0;
1448    for (c=0;c<C;c++)
1449       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1450          X[c*N+i] = 0;
1451
1452    if (data == NULL)
1453    {
1454       celt_decode_lost(st, pcm, N, LM);
1455       RESTORE_STACK;
1456       return CELT_OK;
1457    }
1458    if (len<0) {
1459      RESTORE_STACK;
1460      return CELT_BAD_ARG;
1461    }
1462    
1463    if (dec == NULL)
1464    {
1465       ec_byte_readinit(&buf,(unsigned char*)data,len);
1466       ec_dec_init(&_dec,&buf);
1467       dec = &_dec;
1468       nbFilledBytes = 0;
1469    } else {
1470       nbFilledBytes = (ec_dec_tell(dec, 0)+4)>>3;
1471    }
1472    nbAvailableBytes = len-nbFilledBytes;
1473
1474    /* Decode the global flags (first symbols in the stream) */
1475    intra_ener = ec_dec_bit_prob(dec, 8192);
1476    /* Get band energies */
1477    unquant_coarse_energy(st->mode, st->start, st->end, bandE, oldBandE,
1478          intra_ener, st->mode->prob, dec, C, LM);
1479
1480    isTransient = ec_dec_bit_prob(dec, 8192);
1481
1482    if (isTransient)
1483       shortBlocks = M;
1484    else
1485       shortBlocks = 0;
1486
1487    if (isTransient)
1488    {
1489       transient_shift = ec_dec_uint(dec, 4);
1490       if (transient_shift == 3)
1491       {
1492          int transient_time_quant;
1493          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
1494          transient_time_quant = ec_dec_uint(dec, max_time);
1495          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
1496       } else {
1497          mdct_weight_shift = transient_shift;
1498          if (mdct_weight_shift && M>2)
1499             mdct_weight_pos = ec_dec_uint(dec, M-1);
1500          transient_shift = 0;
1501          transient_time = 0;
1502       }
1503    } else {
1504       transient_time = -1;
1505       transient_shift = 0;
1506    }
1507
1508    ALLOC(tf_res, st->mode->nbEBands, int);
1509    tf_decode(st->start, st->end, C, isTransient, tf_res, nbAvailableBytes, LM, dec);
1510
1511    has_fold = ec_dec_bit_prob(dec, 8192)<<1;
1512    has_fold |= ec_dec_bit_prob(dec, (has_fold>>1) ? 32768 : 49152);
1513
1514    ALLOC(pulses, st->mode->nbEBands, int);
1515    ALLOC(offsets, st->mode->nbEBands, int);
1516    ALLOC(fine_priority, st->mode->nbEBands, int);
1517
1518    for (i=0;i<st->mode->nbEBands;i++)
1519       offsets[i] = 0;
1520
1521    bits = len*8 - ec_dec_tell(dec, 0) - 1;
1522    ALLOC(fine_quant, st->mode->nbEBands, int);
1523    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
1524    /*bits = ec_dec_tell(dec, 0);
1525    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(dec, 0)-bits))/C);*/
1526    
1527    unquant_fine_energy(st->mode, st->start, st->end, bandE, oldBandE, fine_quant, dec, C);
1528
1529    /* Decode fixed codebook */
1530    quant_all_bands(0, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, NULL, pulses, shortBlocks, has_fold, tf_res, 1, len*8, dec, LM);
1531
1532    unquant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE,
1533          fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
1534
1535    log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
1536
1537    if (mdct_weight_shift)
1538    {
1539       mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
1540    }
1541
1542    /* Synthesis */
1543    denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
1544
1545    CELT_MOVE(decode_mem[0], decode_mem[0]+N, DECODE_BUFFER_SIZE-N);
1546    if (C==2)
1547       CELT_MOVE(decode_mem[1], decode_mem[1]+N, DECODE_BUFFER_SIZE-N);
1548
1549    for (c=0;c<C;c++)
1550       for (i=0;i<M*st->mode->eBands[st->start];i++)
1551          freq[c*N+i] = 0;
1552    for (c=0;c<C;c++)
1553       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1554          freq[c*N+i] = 0;
1555
1556    out_syn[0] = out_mem[0]+MAX_PERIOD-N;
1557    if (C==2)
1558       out_syn[1] = out_mem[1]+MAX_PERIOD-N;
1559
1560    /* Compute inverse MDCTs */
1561    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
1562          transient_shift, out_syn, overlap_mem, C, LM);
1563
1564    deemphasis(out_syn, pcm, N, C, st->mode->preemph, st->preemph_memD);
1565    st->loss_count = 0;
1566    RESTORE_STACK;
1567    if (ec_dec_get_error(dec))
1568       return CELT_CORRUPTED_DATA;
1569    else
1570       return CELT_OK;
1571 }
1572
1573 #ifdef FIXED_POINT
1574 #ifndef DISABLE_FLOAT_API
1575 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size, ec_dec *dec)
1576 {
1577    int j, ret, C, N, LM, M;
1578    VARDECL(celt_int16, out);
1579    SAVE_STACK;
1580
1581    if (pcm==NULL)
1582       return CELT_BAD_ARG;
1583
1584    for (LM=0;LM<4;LM++)
1585       if (st->mode->shortMdctSize<<LM==frame_size)
1586          break;
1587    if (LM>=MAX_CONFIG_SIZES)
1588       return CELT_BAD_ARG;
1589    M=1<<LM;
1590
1591    C = CHANNELS(st->channels);
1592    N = M*st->mode->shortMdctSize;
1593    
1594    ALLOC(out, C*N, celt_int16);
1595    ret=celt_decode_with_ec(st, data, len, out, frame_size, dec);
1596    if (ret==0)
1597       for (j=0;j<C*N;j++)
1598          pcm[j]=out[j]*(1.f/32768.f);
1599      
1600    RESTORE_STACK;
1601    return ret;
1602 }
1603 #endif /*DISABLE_FLOAT_API*/
1604 #else
1605 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1606 {
1607    int j, ret, C, N, LM, M;
1608    VARDECL(celt_sig, out);
1609    SAVE_STACK;
1610
1611    if (pcm==NULL)
1612       return CELT_BAD_ARG;
1613
1614    for (LM=0;LM<4;LM++)
1615       if (st->mode->shortMdctSize<<LM==frame_size)
1616          break;
1617    if (LM>=MAX_CONFIG_SIZES)
1618       return CELT_BAD_ARG;
1619    M=1<<LM;
1620
1621    C = CHANNELS(st->channels);
1622    N = M*st->mode->shortMdctSize;
1623    ALLOC(out, C*N, celt_sig);
1624
1625    ret=celt_decode_with_ec_float(st, data, len, out, frame_size, dec);
1626
1627    if (ret==0)
1628       for (j=0;j<C*N;j++)
1629          pcm[j] = FLOAT2INT16 (out[j]);
1630    
1631    RESTORE_STACK;
1632    return ret;
1633 }
1634 #endif
1635
1636 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size)
1637 {
1638    return celt_decode_with_ec(st, data, len, pcm, frame_size, NULL);
1639 }
1640
1641 #ifndef DISABLE_FLOAT_API
1642 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size)
1643 {
1644    return celt_decode_with_ec_float(st, data, len, pcm, frame_size, NULL);
1645 }
1646 #endif /* DISABLE_FLOAT_API */
1647
1648 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
1649 {
1650    va_list ap;
1651
1652    va_start(ap, request);
1653    switch (request)
1654    {
1655       case CELT_GET_MODE_REQUEST:
1656       {
1657          const CELTMode ** value = va_arg(ap, const CELTMode**);
1658          if (value==0)
1659             goto bad_arg;
1660          *value=st->mode;
1661       }
1662       break;
1663       case CELT_SET_START_BAND_REQUEST:
1664       {
1665          celt_int32 value = va_arg(ap, celt_int32);
1666          if (value<0 || value>=st->mode->nbEBands)
1667             goto bad_arg;
1668          st->start = value;
1669       }
1670       break;
1671       case CELT_SET_END_BAND_REQUEST:
1672       {
1673          celt_int32 value = va_arg(ap, celt_int32);
1674          if (value<0 || value>=st->mode->nbEBands)
1675             goto bad_arg;
1676          st->end = value;
1677       }
1678       break;
1679       case CELT_RESET_STATE:
1680       {
1681          const CELTMode *mode = st->mode;
1682          int C = st->channels;
1683          celt_word16 *lpc;
1684          celt_word16 *oldBandE;
1685
1686          lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1687          oldBandE = lpc+C*LPC_ORDER;
1688
1689          CELT_MEMSET(st->_decode_mem, 0, (DECODE_BUFFER_SIZE+st->overlap)*C);
1690          CELT_MEMSET(oldBandE, 0, C*mode->nbEBands);
1691
1692          CELT_MEMSET(st->preemph_memD, 0, C);
1693
1694          st->loss_count = 0;
1695
1696          CELT_MEMSET(lpc, 0, C*LPC_ORDER);
1697       }
1698       break;
1699       default:
1700          goto bad_request;
1701    }
1702    va_end(ap);
1703    return CELT_OK;
1704 bad_arg:
1705    va_end(ap);
1706    return CELT_BAD_ARG;
1707 bad_request:
1708       va_end(ap);
1709   return CELT_UNIMPLEMENTED;
1710 }
1711
1712 const char *celt_strerror(int error)
1713 {
1714    static const char *error_strings[8] = {
1715       "success",
1716       "invalid argument",
1717       "invalid mode",
1718       "internal error",
1719       "corrupted stream",
1720       "request not implemented",
1721       "invalid state",
1722       "memory allocation failed"
1723    };
1724    if (error > 0 || error < -7)
1725       return "unknown error";
1726    else 
1727       return error_strings[-error];
1728 }
1729