Fixes some MSVC warnings
[opus.git] / libcelt / celt.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2010 Xiph.Org Foundation
3    Copyright (c) 2008 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #define CELT_C
39
40 #include "os_support.h"
41 #include "mdct.h"
42 #include <math.h>
43 #include "celt.h"
44 #include "pitch.h"
45 #include "bands.h"
46 #include "modes.h"
47 #include "entcode.h"
48 #include "quant_bands.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54 #include "plc.h"
55
56 #ifdef FIXED_POINT
57 static const celt_word16 transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135f, 0.0337639f, 0.0748914f, 0.1304955f,
63    0.1986827f, 0.2771308f, 0.3631685f, 0.4538658f,
64    0.5461342f, 0.6368315f, 0.7228692f, 0.8013173f,
65    0.8695045f, 0.9251086f, 0.9662361f, 0.9914865f};
66 #endif
67
68 /** Encoder state 
69  @brief Encoder state
70  */
71 struct CELTEncoder {
72    const CELTMode *mode;     /**< Mode used by the encoder */
73    int overlap;
74    int channels;
75    
76    int force_intra;
77    int start, end;
78
79    celt_int32 vbr_rate_norm; /* Target number of 16th bits per frame */
80
81    /* Everything beyond this point gets cleared on a reset */
82 #define ENCODER_RESET_START frame_max
83
84    celt_word32 frame_max;
85    int fold_decision;
86    int delayedIntra;
87    celt_word16 tonal_average;
88
89    /* VBR-related parameters */
90    celt_int32 vbr_reservoir;
91    celt_int32 vbr_drift;
92    celt_int32 vbr_offset;
93    celt_int32 vbr_count;
94
95    celt_word32 preemph_memE[2];
96    celt_word32 preemph_memD[2];
97
98    celt_sig in_mem[1]; /* Size = channels*mode->overlap */
99    /* celt_sig overlap_mem[],  Size = channels*mode->overlap */
100    /* celt_word16 oldEBands[], Size = channels*mode->nbEBands */
101 };
102
103 int celt_encoder_get_size(const CELTMode *mode, int channels)
104 {
105    int size = sizeof(struct CELTEncoder)
106          + (2*channels*mode->overlap-1)*sizeof(celt_sig)
107          + channels*mode->nbEBands*sizeof(celt_word16);
108    return size;
109 }
110
111 CELTEncoder *celt_encoder_create(const CELTMode *mode, int channels, int *error)
112 {
113    return celt_encoder_init(
114          (CELTEncoder *)celt_alloc(celt_encoder_get_size(mode, channels)),
115          mode, channels, error);
116 }
117
118 CELTEncoder *celt_encoder_init(CELTEncoder *st, const CELTMode *mode, int channels, int *error)
119 {
120    if (channels < 0 || channels > 2)
121    {
122       if (error)
123          *error = CELT_BAD_ARG;
124       return NULL;
125    }
126
127    if (st==NULL)
128    {
129       if (error)
130          *error = CELT_ALLOC_FAIL;
131       return NULL;
132    }
133
134    CELT_MEMSET((char*)st, 0, celt_encoder_get_size(mode, channels));
135    
136    st->mode = mode;
137    st->overlap = mode->overlap;
138    st->channels = channels;
139
140    st->start = 0;
141    st->end = st->mode->effEBands;
142
143    st->vbr_rate_norm = 0;
144    st->force_intra  = 0;
145    st->delayedIntra = 1;
146    st->tonal_average = QCONST16(1.f,8);
147    st->fold_decision = 1;
148
149    if (error)
150       *error = CELT_OK;
151    return st;
152 }
153
154 void celt_encoder_destroy(CELTEncoder *st)
155 {
156    celt_free(st);
157 }
158
159 static inline celt_int16 FLOAT2INT16(float x)
160 {
161    x = x*CELT_SIG_SCALE;
162    x = MAX32(x, -32768);
163    x = MIN32(x, 32767);
164    return (celt_int16)float2int(x);
165 }
166
167 static inline celt_word16 SIG2WORD16(celt_sig x)
168 {
169 #ifdef FIXED_POINT
170    x = PSHR32(x, SIG_SHIFT);
171    x = MAX32(x, -32768);
172    x = MIN32(x, 32767);
173    return EXTRACT16(x);
174 #else
175    return (celt_word16)x;
176 #endif
177 }
178
179 static int transient_analysis(const celt_word32 * restrict in, int len, int C,
180                               int *transient_time, int *transient_shift,
181                               celt_word32 *frame_max, int overlap)
182 {
183    int i, n;
184    celt_word32 ratio;
185    celt_word32 threshold;
186    VARDECL(celt_word32, begin);
187    SAVE_STACK;
188    ALLOC(begin, len+1, celt_word32);
189    begin[0] = 0;
190    if (C==1)
191    {
192       for (i=0;i<len;i++)
193          begin[i+1] = MAX32(begin[i], ABS32(in[i]));
194    } else {
195       for (i=0;i<len;i++)
196          begin[i+1] = MAX32(begin[i], MAX32(ABS32(in[C*i]),
197                                             ABS32(in[C*i+1])));
198    }
199    n = -1;
200
201    threshold = MULT16_32_Q15(QCONST16(.4f,15),begin[len]);
202    /* If the following condition isn't met, there's just no way
203       we'll have a transient*/
204    if (*frame_max < threshold)
205    {
206       /* It's likely we have a transient, now find it */
207       for (i=8;i<len-8;i++)
208       {
209          if (begin[i+1] < threshold)
210             n=i;
211       }
212    }
213    if (n<32)
214    {
215       n = -1;
216       ratio = 0;
217    } else {
218       ratio = DIV32(begin[len],1+MAX32(*frame_max, begin[n-16]));
219    }
220
221    if (ratio > 45)
222       *transient_shift = 3;
223    else
224       *transient_shift = 0;
225    
226    *transient_time = n;
227    *frame_max = begin[len-overlap];
228
229    RESTORE_STACK;
230    return ratio > 0;
231 }
232
233 /** Apply window and compute the MDCT for all sub-frames and 
234     all channels in a frame */
235 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * restrict in, celt_sig * restrict out, int _C, int LM)
236 {
237    const int C = CHANNELS(_C);
238    if (C==1 && !shortBlocks)
239    {
240       const int overlap = OVERLAP(mode);
241       clt_mdct_forward(&mode->mdct, in, out, mode->window, overlap, mode->maxLM-LM);
242    } else {
243       const int overlap = OVERLAP(mode);
244       int N = mode->shortMdctSize<<LM;
245       int B = 1;
246       int b, c;
247       VARDECL(celt_word32, x);
248       VARDECL(celt_word32, tmp);
249       SAVE_STACK;
250       if (shortBlocks)
251       {
252          /*lookup = &mode->mdct[0];*/
253          N = mode->shortMdctSize;
254          B = shortBlocks;
255       }
256       ALLOC(x, N+overlap, celt_word32);
257       ALLOC(tmp, N, celt_word32);
258       for (c=0;c<C;c++)
259       {
260          for (b=0;b<B;b++)
261          {
262             int j;
263             for (j=0;j<N+overlap;j++)
264                x[j] = in[C*(b*N+j)+c];
265             clt_mdct_forward(&mode->mdct, x, tmp, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
266             /* Interleaving the sub-frames */
267             for (j=0;j<N;j++)
268                out[(j*B+b)+c*N*B] = tmp[j];
269          }
270       }
271       RESTORE_STACK;
272    }
273 }
274
275 /** Compute the IMDCT and apply window for all sub-frames and 
276     all channels in a frame */
277 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X,
278       int transient_time, int transient_shift, celt_sig * restrict out_mem[],
279       celt_sig * restrict overlap_mem[], int _C, int LM)
280 {
281    int c;
282    const int C = CHANNELS(_C);
283    const int N = mode->shortMdctSize<<LM;
284    const int overlap = OVERLAP(mode);
285    for (c=0;c<C;c++)
286    {
287       int j;
288          VARDECL(celt_word32, x);
289          VARDECL(celt_word32, tmp);
290          int b;
291          int N2 = N;
292          int B = 1;
293          SAVE_STACK;
294          
295          ALLOC(x, N+overlap, celt_word32);
296          ALLOC(tmp, N, celt_word32);
297
298          if (shortBlocks)
299          {
300             N2 = mode->shortMdctSize;
301             B = shortBlocks;
302          }
303          /* Prevents problems from the imdct doing the overlap-add */
304          CELT_MEMSET(x, 0, overlap);
305
306          for (b=0;b<B;b++)
307          {
308             /* De-interleaving the sub-frames */
309             for (j=0;j<N2;j++)
310                tmp[j] = X[(j*B+b)+c*N2*B];
311             clt_mdct_backward(&mode->mdct, tmp, x+N2*b, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
312          }
313
314          if (transient_shift > 0)
315          {
316 #ifdef FIXED_POINT
317             for (j=0;j<16;j++)
318                x[transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[transient_time+j-16],transient_shift));
319             for (j=transient_time;j<N+overlap;j++)
320                x[j] = SHL32(x[j], transient_shift);
321 #else
322             for (j=0;j<16;j++)
323                x[transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
324             for (j=transient_time;j<N+overlap;j++)
325                x[j] *= 1<<transient_shift;
326 #endif
327          }
328          for (j=0;j<overlap;j++)
329             out_mem[c][j] = x[j] + overlap_mem[c][j];
330          for (;j<N;j++)
331             out_mem[c][j] = x[j];
332          for (j=0;j<overlap;j++)
333             overlap_mem[c][j] = x[N+j];
334          RESTORE_STACK;
335    }
336 }
337
338 static void deemphasis(celt_sig *in[], celt_word16 *pcm, int N, int _C, const celt_word16 *coef, celt_sig *mem)
339 {
340    const int C = CHANNELS(_C);
341    int c;
342    for (c=0;c<C;c++)
343    {
344       int j;
345       celt_sig * restrict x;
346       celt_word16  * restrict y;
347       celt_sig m = mem[c];
348       x =in[c];
349       y = pcm+c;
350       for (j=0;j<N;j++)
351       {
352          celt_sig tmp = *x + m;
353          m = MULT16_32_Q15(coef[0], tmp)
354            - MULT16_32_Q15(coef[1], *x);
355          tmp = SHL32(MULT16_32_Q15(coef[3], tmp), 2);
356          *y = SCALEOUT(SIG2WORD16(tmp));
357          x++;
358          y+=C;
359       }
360       mem[c] = m;
361    }
362 }
363
364 static void mdct_shape(const CELTMode *mode, celt_norm *X, int start,
365                        int end, int N,
366                        int mdct_weight_shift, int end_band, int _C, int renorm, int M)
367 {
368    int m, i, c;
369    const int C = CHANNELS(_C);
370    for (c=0;c<C;c++)
371       for (m=start;m<end;m++)
372          for (i=m+c*N;i<(c+1)*N;i+=M)
373 #ifdef FIXED_POINT
374             X[i] = SHR16(X[i], mdct_weight_shift);
375 #else
376             X[i] = (1.f/(1<<mdct_weight_shift))*X[i];
377 #endif
378    if (renorm)
379       renormalise_bands(mode, X, end_band, C, M);
380 }
381
382 static signed char tf_select_table[4][8] = {
383       {0, -1, 0, -1,    0,-1, 0,-1},
384       {0, -1, 0, -2,    1, 0, 1 -1},
385       {0, -2, 0, -3,    2, 0, 1 -1},
386       {0, -2, 0, -3,    2, 0, 1 -1},
387 };
388
389 static int tf_analysis(celt_word16 *bandLogE, celt_word16 *oldBandE, int len, int C, int isTransient, int *tf_res, int nbCompressedBytes)
390 {
391    int i;
392    celt_word16 threshold;
393    VARDECL(celt_word16, metric);
394    celt_word32 average=0;
395    celt_word32 cost0;
396    celt_word32 cost1;
397    VARDECL(int, path0);
398    VARDECL(int, path1);
399    celt_word16 lambda;
400    int tf_select=0;
401    SAVE_STACK;
402
403    /* FIXME: Should check number of bytes *left* */
404    if (nbCompressedBytes<15*C)
405    {
406       for (i=0;i<len;i++)
407          tf_res[i] = 0;
408       return 0;
409    }
410    if (nbCompressedBytes<40)
411       lambda = QCONST16(5.f, DB_SHIFT);
412    else if (nbCompressedBytes<60)
413       lambda = QCONST16(2.f, DB_SHIFT);
414    else if (nbCompressedBytes<100)
415       lambda = QCONST16(1.f, DB_SHIFT);
416    else
417       lambda = QCONST16(.5f, DB_SHIFT);
418
419    ALLOC(metric, len, celt_word16);
420    ALLOC(path0, len, int);
421    ALLOC(path1, len, int);
422    for (i=0;i<len;i++)
423    {
424       metric[i] = SUB16(bandLogE[i], oldBandE[i]);
425       average += metric[i];
426    }
427    if (C==2)
428    {
429       average = 0;
430       for (i=0;i<len;i++)
431       {
432          metric[i] = HALF32(metric[i]) + HALF32(SUB16(bandLogE[i+len], oldBandE[i+len]));
433          average += metric[i];
434       }
435    }
436    average = DIV32(average, len);
437    /*if (!isTransient)
438       printf ("%f\n", average);*/
439    if (isTransient)
440    {
441       threshold = QCONST16(1.f,DB_SHIFT);
442       tf_select = average > QCONST16(3.f,DB_SHIFT);
443    } else {
444       threshold = QCONST16(.5f,DB_SHIFT);
445       tf_select = average > QCONST16(1.f,DB_SHIFT);
446    }
447    cost0 = 0;
448    cost1 = lambda;
449    /* Viterbi forward pass */
450    for (i=1;i<len;i++)
451    {
452       celt_word32 curr0, curr1;
453       celt_word32 from0, from1;
454
455       from0 = cost0;
456       from1 = cost1 + lambda;
457       if (from0 < from1)
458       {
459          curr0 = from0;
460          path0[i]= 0;
461       } else {
462          curr0 = from1;
463          path0[i]= 1;
464       }
465
466       from0 = cost0 + lambda;
467       from1 = cost1;
468       if (from0 < from1)
469       {
470          curr1 = from0;
471          path1[i]= 0;
472       } else {
473          curr1 = from1;
474          path1[i]= 1;
475       }
476       cost0 = curr0 + (metric[i]-threshold);
477       cost1 = curr1;
478    }
479    tf_res[len-1] = cost0 < cost1 ? 0 : 1;
480    /* Viterbi backward pass to check the decisions */
481    for (i=len-2;i>=0;i--)
482    {
483       if (tf_res[i+1] == 1)
484          tf_res[i] = path1[i+1];
485       else
486          tf_res[i] = path0[i+1];
487    }
488    RESTORE_STACK;
489    return tf_select;
490 }
491
492 static void tf_encode(int start, int end, int isTransient, int *tf_res, int nbCompressedBytes, int LM, int tf_select, ec_enc *enc)
493 {
494    int curr, i;
495    if (8*nbCompressedBytes - ec_enc_tell(enc, 0) < 100)
496    {
497       for (i=start;i<end;i++)
498          tf_res[i] = isTransient;
499    } else {
500       ec_enc_bit_prob(enc, tf_res[start], isTransient ? 16384 : 4096);
501       curr = tf_res[start];
502       for (i=start+1;i<end;i++)
503       {
504          ec_enc_bit_prob(enc, tf_res[i] ^ curr, isTransient ? 4096 : 2048);
505          curr = tf_res[i];
506       }
507    }
508    ec_enc_bits(enc, tf_select, 1);
509    for (i=start;i<end;i++)
510       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
511 }
512
513 static void tf_decode(int start, int end, int C, int isTransient, int *tf_res, int nbCompressedBytes, int LM, ec_dec *dec)
514 {
515    int i, curr, tf_select;
516    if (8*nbCompressedBytes - ec_dec_tell(dec, 0) < 100)
517    {
518       for (i=start;i<end;i++)
519          tf_res[i] = isTransient;
520    } else {
521       tf_res[start] = ec_dec_bit_prob(dec, isTransient ? 16384 : 4096);
522       curr = tf_res[start];
523       for (i=start+1;i<end;i++)
524       {
525          tf_res[i] = ec_dec_bit_prob(dec, isTransient ? 4096 : 2048) ^ curr;
526          curr = tf_res[i];
527       }
528    }
529    tf_select = ec_dec_bits(dec, 1);
530    for (i=start;i<end;i++)
531       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
532 }
533
534 #ifdef FIXED_POINT
535 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
536 {
537 #else
538 int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
539 {
540 #endif
541    int i, c, N, NN;
542    int bits;
543    int has_fold=1;
544    ec_byte_buffer buf;
545    ec_enc         _enc;
546    VARDECL(celt_sig, in);
547    VARDECL(celt_sig, freq);
548    VARDECL(celt_norm, X);
549    VARDECL(celt_ener, bandE);
550    VARDECL(celt_word16, bandLogE);
551    VARDECL(int, fine_quant);
552    VARDECL(celt_word16, error);
553    VARDECL(int, pulses);
554    VARDECL(int, offsets);
555    VARDECL(int, fine_priority);
556    VARDECL(int, tf_res);
557    celt_sig *_overlap_mem;
558    celt_word16 *oldBandE;
559    int shortBlocks=0;
560    int isTransient=0;
561    int transient_time, transient_time_quant;
562    int transient_shift;
563    int resynth;
564    const int C = CHANNELS(st->channels);
565    int mdct_weight_shift = 0;
566    int mdct_weight_pos=0;
567    int LM, M;
568    int tf_select;
569    int nbFilledBytes, nbAvailableBytes;
570    int effEnd;
571    SAVE_STACK;
572
573    if (nbCompressedBytes<0 || pcm==NULL)
574      return CELT_BAD_ARG;
575
576    for (LM=0;LM<4;LM++)
577       if (st->mode->shortMdctSize<<LM==frame_size)
578          break;
579    if (LM>=MAX_CONFIG_SIZES)
580       return CELT_BAD_ARG;
581    M=1<<LM;
582
583    _overlap_mem = st->in_mem+C*(st->overlap);
584    oldBandE = (celt_word16*)(st->in_mem+2*C*(st->overlap));
585
586    if (enc==NULL)
587    {
588       ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
589       ec_enc_init(&_enc,&buf);
590       enc = &_enc;
591       nbFilledBytes=0;
592    } else {
593       nbFilledBytes=(ec_enc_tell(enc, 0)+4)>>3;
594    }
595    nbAvailableBytes = nbCompressedBytes - nbFilledBytes;
596
597    effEnd = st->end;
598    if (effEnd > st->mode->effEBands)
599       effEnd = st->mode->effEBands;
600
601    N = M*st->mode->shortMdctSize;
602    ALLOC(in, C*(N+st->overlap), celt_sig);
603
604    CELT_COPY(in, st->in_mem, C*st->overlap);
605    for (c=0;c<C;c++)
606    {
607       const celt_word16 * restrict pcmp = pcm+c;
608       celt_sig * restrict inp = in+C*st->overlap+c;
609       for (i=0;i<N;i++)
610       {
611          /* Apply pre-emphasis */
612          celt_sig tmp = MULT16_16(st->mode->preemph[2], SCALEIN(*pcmp));
613          *inp = tmp + st->preemph_memE[c];
614          st->preemph_memE[c] = MULT16_32_Q15(st->mode->preemph[1], *inp)
615                              - MULT16_32_Q15(st->mode->preemph[0], tmp);
616          inp += C;
617          pcmp += C;
618       }
619    }
620    CELT_COPY(st->in_mem, in+C*N, C*st->overlap);
621
622    /* Transient handling */
623    transient_time = -1;
624    transient_time_quant = -1;
625    transient_shift = 0;
626    isTransient = 0;
627
628    resynth = optional_resynthesis!=NULL;
629
630    if (M > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift, &st->frame_max, st->overlap))
631    {
632 #ifndef FIXED_POINT
633       float gain_1;
634 #endif
635       /* Apply the inverse shaping window */
636       if (transient_shift)
637       {
638          transient_time_quant = transient_time*(celt_int32)8000/st->mode->Fs;
639          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
640 #ifdef FIXED_POINT
641          for (c=0;c<C;c++)
642             for (i=0;i<16;i++)
643                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
644          for (c=0;c<C;c++)
645             for (i=transient_time;i<N+st->overlap;i++)
646                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
647 #else
648          for (c=0;c<C;c++)
649             for (i=0;i<16;i++)
650                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
651          gain_1 = 1.f/(1<<transient_shift);
652          for (c=0;c<C;c++)
653             for (i=transient_time;i<N+st->overlap;i++)
654                in[C*i+c] *= gain_1;
655 #endif
656       }
657       isTransient = 1;
658       has_fold = 1;
659    }
660
661    if (isTransient)
662       shortBlocks = M;
663    else
664       shortBlocks = 0;
665
666    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
667    ALLOC(bandE,st->mode->nbEBands*C, celt_ener);
668    ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16);
669    /* Compute MDCTs */
670    compute_mdcts(st->mode, shortBlocks, in, freq, C, LM);
671
672    ALLOC(X, C*N, celt_norm);         /**< Interleaved normalised MDCTs */
673
674    compute_band_energies(st->mode, freq, bandE, effEnd, C, M);
675
676    amp2Log2(st->mode, effEnd, st->end, bandE, bandLogE, C);
677
678    /* Band normalisation */
679    normalise_bands(st->mode, freq, X, bandE, effEnd, C, M);
680
681    NN = M*st->mode->eBands[effEnd];
682    if (shortBlocks && !transient_shift)
683    {
684       celt_word32 sum[8]={1,1,1,1,1,1,1,1};
685       int m;
686       for (c=0;c<C;c++)
687       {
688          m=0;
689          do {
690             celt_word32 tmp=0;
691             for (i=m+c*N;i<c*N+NN;i+=M)
692                tmp += ABS32(X[i]);
693             sum[m++] += tmp;
694          } while (m<M);
695       }
696       m=0;
697 #ifdef FIXED_POINT
698       do {
699          if (SHR32(sum[m+1],3) > sum[m])
700          {
701             mdct_weight_shift=2;
702             mdct_weight_pos = m;
703          } else if (SHR32(sum[m+1],1) > sum[m] && mdct_weight_shift < 2)
704          {
705             mdct_weight_shift=1;
706             mdct_weight_pos = m;
707          }
708          m++;
709       } while (m<M-1);
710 #else
711       do {
712          if (sum[m+1] > 8*sum[m])
713          {
714             mdct_weight_shift=2;
715             mdct_weight_pos = m;
716          } else if (sum[m+1] > 2*sum[m] && mdct_weight_shift < 2)
717          {
718             mdct_weight_shift=1;
719             mdct_weight_pos = m;
720          }
721          m++;
722       } while (m<M-1);
723 #endif
724       if (mdct_weight_shift)
725          mdct_shape(st->mode, X, mdct_weight_pos+1, M, N, mdct_weight_shift, effEnd, C, 0, M);
726    }
727
728    ALLOC(tf_res, st->mode->nbEBands, int);
729    /* Needs to be before coarse energy quantization because otherwise the energy gets modified */
730    tf_select = tf_analysis(bandLogE, oldBandE, effEnd, C, isTransient, tf_res, nbAvailableBytes);
731    for (i=effEnd;i<st->end;i++)
732       tf_res[i] = tf_res[effEnd-1];
733
734    ALLOC(error, C*st->mode->nbEBands, celt_word16);
735    quant_coarse_energy(st->mode, st->start, st->end, effEnd, bandLogE,
736          oldBandE, nbCompressedBytes*8, st->mode->prob,
737          error, enc, C, LM, nbAvailableBytes, st->force_intra, &st->delayedIntra);
738
739    ec_enc_bit_prob(enc, shortBlocks!=0, 8192);
740
741    if (shortBlocks)
742    {
743       if (transient_shift)
744       {
745          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
746          ec_enc_uint(enc, transient_shift, 4);
747          ec_enc_uint(enc, transient_time_quant, max_time);
748       } else {
749          ec_enc_uint(enc, mdct_weight_shift, 4);
750          if (mdct_weight_shift && M!=2)
751             ec_enc_uint(enc, mdct_weight_pos, M-1);
752       }
753    }
754
755    tf_encode(st->start, st->end, isTransient, tf_res, nbAvailableBytes, LM, tf_select, enc);
756
757    if (!shortBlocks && !folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision, effEnd, C, M))
758       has_fold = 0;
759    ec_enc_bit_prob(enc, has_fold>>1, 8192);
760    ec_enc_bit_prob(enc, has_fold&1, (has_fold>>1) ? 32768 : 49152);
761
762    /* Variable bitrate */
763    if (st->vbr_rate_norm>0)
764    {
765      celt_word16 alpha;
766      celt_int32 delta;
767      /* The target rate in 16th bits per frame */
768      celt_int32 vbr_rate;
769      celt_int32 target;
770      celt_int32 vbr_bound, max_allowed;
771
772      vbr_rate = M*st->vbr_rate_norm;
773
774      /* Computes the max bit-rate allowed in VBR more to avoid busting the budget */
775      vbr_bound = vbr_rate;
776      max_allowed = (vbr_rate + vbr_bound - st->vbr_reservoir)>>(BITRES+3);
777      if (max_allowed < 4)
778         max_allowed = 4;
779      if (max_allowed < nbAvailableBytes)
780         nbAvailableBytes = max_allowed;
781      target=vbr_rate;
782
783      /* Shortblocks get a large boost in bitrate, but since they 
784         are uncommon long blocks are not greatly effected */
785      if (shortBlocks)
786        target*=2;
787      else if (M > 1)
788        target-=(target+14)/28;
789
790      /* The average energy is removed from the target and the actual 
791         energy added*/
792      target=target+st->vbr_offset-588+ec_enc_tell(enc, BITRES);
793
794      /* In VBR mode the frame size must not be reduced so much that it would result in the coarse energy busting its budget */
795      target=IMIN(nbAvailableBytes,target);
796      /* Make the adaptation coef (alpha) higher at the beginning */
797      if (st->vbr_count < 990)
798      {
799         st->vbr_count++;
800         alpha = celt_rcp(SHL32(EXTEND32(st->vbr_count+10),16));
801         /*printf ("%d %d\n", st->vbr_count+10, alpha);*/
802      } else
803         alpha = QCONST16(.001f,15);
804
805      /* By how much did we "miss" the target on that frame */
806      delta = (8<<BITRES)*(celt_int32)target - vbr_rate;
807      /* How many bits have we used in excess of what we're allowed */
808      st->vbr_reservoir += delta;
809      /*printf ("%d\n", st->vbr_reservoir);*/
810
811      /* Compute the offset we need to apply in order to reach the target */
812      st->vbr_drift += (celt_int32)MULT16_32_Q15(alpha,delta-st->vbr_offset-st->vbr_drift);
813      st->vbr_offset = -st->vbr_drift;
814      /*printf ("%d\n", st->vbr_drift);*/
815
816      /* We could use any multiple of vbr_rate as bound (depending on the delay) */
817      if (st->vbr_reservoir < 0)
818      {
819         /* We're under the min value -- increase rate */
820         int adjust = 1-(st->vbr_reservoir-1)/(8<<BITRES);
821         st->vbr_reservoir += adjust*(8<<BITRES);
822         target += adjust;
823         /*printf ("+%d\n", adjust);*/
824      }
825      if (target < nbAvailableBytes)
826         nbAvailableBytes = target;
827      nbCompressedBytes = nbAvailableBytes + nbFilledBytes;
828
829      /* This moves the raw bits to take into account the new compressed size */
830      ec_byte_shrink(&buf, nbCompressedBytes);
831    }
832
833    /* Bit allocation */
834    ALLOC(fine_quant, st->mode->nbEBands, int);
835    ALLOC(pulses, st->mode->nbEBands, int);
836    ALLOC(offsets, st->mode->nbEBands, int);
837    ALLOC(fine_priority, st->mode->nbEBands, int);
838
839    for (i=0;i<st->mode->nbEBands;i++)
840       offsets[i] = 0;
841    bits = nbCompressedBytes*8 - ec_enc_tell(enc, 0) - 1;
842    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, LM);
843
844    quant_fine_energy(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, enc, C);
845
846 #ifdef MEASURE_NORM_MSE
847    float X0[3000];
848    float bandE0[60];
849    for (c=0;c<C;c++)
850       for (i=0;i<N;i++)
851          X0[i+c*N] = X[i+c*N];
852    for (i=0;i<C*st->mode->nbEBands;i++)
853       bandE0[i] = bandE[i];
854 #endif
855
856    /* Residual quantisation */
857    quant_all_bands(1, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, bandE, pulses, shortBlocks, has_fold, tf_res, resynth, nbCompressedBytes*8, enc, LM);
858
859    quant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
860
861    /* Re-synthesis of the coded audio if required */
862    if (resynth)
863    {
864       VARDECL(celt_sig, _out_mem);
865       celt_sig *out_mem[2];
866       celt_sig *overlap_mem[2];
867
868       log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
869
870 #ifdef MEASURE_NORM_MSE
871       measure_norm_mse(st->mode, X, X0, bandE, bandE0, M, N, C);
872 #endif
873
874       if (mdct_weight_shift)
875       {
876          mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
877       }
878
879       /* Synthesis */
880       denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
881
882       for (c=0;c<C;c++)
883          for (i=0;i<M*st->mode->eBands[st->start];i++)
884             freq[c*N+i] = 0;
885       for (c=0;c<C;c++)
886          for (i=M*st->mode->eBands[st->end];i<N;i++)
887             freq[c*N+i] = 0;
888
889       ALLOC(_out_mem, C*N, celt_sig);
890
891       for (c=0;c<C;c++)
892       {
893          overlap_mem[c] = _overlap_mem + c*st->overlap;
894          out_mem[c] = _out_mem+c*N;
895       }
896
897       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
898             transient_shift, out_mem, overlap_mem, C, LM);
899
900       /* De-emphasis and put everything back at the right place 
901          in the synthesis history */
902       if (optional_resynthesis != NULL) {
903          deemphasis(out_mem, optional_resynthesis, N, C, st->mode->preemph, st->preemph_memD);
904
905       }
906    }
907
908    /* If there's any room left (can only happen for very high rates),
909       fill it with zeros */
910    while (ec_enc_tell(enc,0) + 8 <= nbCompressedBytes*8)
911       ec_enc_bits(enc, 0, 8);
912    ec_enc_done(enc);
913    
914    RESTORE_STACK;
915    if (ec_enc_get_error(enc))
916       return CELT_CORRUPTED_DATA;
917    else
918       return nbCompressedBytes;
919 }
920
921 #ifdef FIXED_POINT
922 #ifndef DISABLE_FLOAT_API
923 int celt_encode_with_ec_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
924 {
925    int j, ret, C, N, LM, M;
926    VARDECL(celt_int16, in);
927    SAVE_STACK;
928
929    if (pcm==NULL)
930       return CELT_BAD_ARG;
931
932    for (LM=0;LM<4;LM++)
933       if (st->mode->shortMdctSize<<LM==frame_size)
934          break;
935    if (LM>=MAX_CONFIG_SIZES)
936       return CELT_BAD_ARG;
937    M=1<<LM;
938
939    C = CHANNELS(st->channels);
940    N = M*st->mode->shortMdctSize;
941    ALLOC(in, C*N, celt_int16);
942
943    for (j=0;j<C*N;j++)
944      in[j] = FLOAT2INT16(pcm[j]);
945
946    if (optional_resynthesis != NULL) {
947      ret=celt_encode_with_ec(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
948       for (j=0;j<C*N;j++)
949          optional_resynthesis[j]=in[j]*(1.f/32768.f);
950    } else {
951      ret=celt_encode_with_ec(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
952    }
953    RESTORE_STACK;
954    return ret;
955
956 }
957 #endif /*DISABLE_FLOAT_API*/
958 #else
959 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
960 {
961    int j, ret, C, N, LM, M;
962    VARDECL(celt_sig, in);
963    SAVE_STACK;
964
965    if (pcm==NULL)
966       return CELT_BAD_ARG;
967
968    for (LM=0;LM<4;LM++)
969       if (st->mode->shortMdctSize<<LM==frame_size)
970          break;
971    if (LM>=MAX_CONFIG_SIZES)
972       return CELT_BAD_ARG;
973    M=1<<LM;
974
975    C=CHANNELS(st->channels);
976    N=M*st->mode->shortMdctSize;
977    ALLOC(in, C*N, celt_sig);
978    for (j=0;j<C*N;j++) {
979      in[j] = SCALEOUT(pcm[j]);
980    }
981
982    if (optional_resynthesis != NULL) {
983       ret = celt_encode_with_ec_float(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
984       for (j=0;j<C*N;j++)
985          optional_resynthesis[j] = FLOAT2INT16(in[j]);
986    } else {
987       ret = celt_encode_with_ec_float(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
988    }
989    RESTORE_STACK;
990    return ret;
991 }
992 #endif
993
994 int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
995 {
996    return celt_encode_with_ec(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
997 }
998
999 #ifndef DISABLE_FLOAT_API
1000 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1001 {
1002    return celt_encode_with_ec_float(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1003 }
1004 #endif /* DISABLE_FLOAT_API */
1005
1006 int celt_encode_resynthesis(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1007 {
1008    return celt_encode_with_ec(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1009 }
1010
1011 #ifndef DISABLE_FLOAT_API
1012 int celt_encode_resynthesis_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1013 {
1014    return celt_encode_with_ec_float(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1015 }
1016 #endif /* DISABLE_FLOAT_API */
1017
1018
1019 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
1020 {
1021    va_list ap;
1022    
1023    va_start(ap, request);
1024    switch (request)
1025    {
1026       case CELT_GET_MODE_REQUEST:
1027       {
1028          const CELTMode ** value = va_arg(ap, const CELTMode**);
1029          if (value==0)
1030             goto bad_arg;
1031          *value=st->mode;
1032       }
1033       break;
1034       case CELT_SET_COMPLEXITY_REQUEST:
1035       {
1036          int value = va_arg(ap, celt_int32);
1037          if (value<0 || value>10)
1038             goto bad_arg;
1039       }
1040       break;
1041       case CELT_SET_START_BAND_REQUEST:
1042       {
1043          celt_int32 value = va_arg(ap, celt_int32);
1044          if (value<0 || value>=st->mode->nbEBands)
1045             goto bad_arg;
1046          st->start = value;
1047       }
1048       break;
1049       case CELT_SET_END_BAND_REQUEST:
1050       {
1051          celt_int32 value = va_arg(ap, celt_int32);
1052          if (value<0 || value>=st->mode->nbEBands)
1053             goto bad_arg;
1054          st->end = value;
1055       }
1056       break;
1057       case CELT_SET_PREDICTION_REQUEST:
1058       {
1059          int value = va_arg(ap, celt_int32);
1060          if (value<0 || value>2)
1061             goto bad_arg;
1062          if (value==0)
1063          {
1064             st->force_intra   = 1;
1065          } else if (value==1) {
1066             st->force_intra   = 0;
1067          } else {
1068             st->force_intra   = 0;
1069          }   
1070       }
1071       break;
1072       case CELT_SET_VBR_RATE_REQUEST:
1073       {
1074          celt_int32 value = va_arg(ap, celt_int32);
1075          int frame_rate;
1076          int N = st->mode->shortMdctSize;
1077          if (value<0)
1078             goto bad_arg;
1079          if (value>3072000)
1080             value = 3072000;
1081          frame_rate = ((st->mode->Fs<<3)+(N>>1))/N;
1082          st->vbr_rate_norm = ((value<<(BITRES+3))+(frame_rate>>1))/frame_rate;
1083       }
1084       break;
1085       case CELT_RESET_STATE:
1086       {
1087          CELT_MEMSET((char*)&st->ENCODER_RESET_START, 0,
1088                celt_encoder_get_size(st->mode, st->channels)-
1089                ((char*)&st->ENCODER_RESET_START - (char*)st));
1090          st->delayedIntra = 1;
1091          st->fold_decision = 1;
1092          st->tonal_average = QCONST16(1.f,8);
1093       }
1094       break;
1095       default:
1096          goto bad_request;
1097    }
1098    va_end(ap);
1099    return CELT_OK;
1100 bad_arg:
1101    va_end(ap);
1102    return CELT_BAD_ARG;
1103 bad_request:
1104    va_end(ap);
1105    return CELT_UNIMPLEMENTED;
1106 }
1107
1108 /**********************************************************************/
1109 /*                                                                    */
1110 /*                             DECODER                                */
1111 /*                                                                    */
1112 /**********************************************************************/
1113 #define DECODE_BUFFER_SIZE 2048
1114
1115 /** Decoder state 
1116  @brief Decoder state
1117  */
1118 struct CELTDecoder {
1119    const CELTMode *mode;
1120    int overlap;
1121    int channels;
1122
1123    int start, end;
1124
1125    /* Everything beyond this point gets cleared on a reset */
1126 #define DECODER_RESET_START last_pitch_index
1127
1128    int last_pitch_index;
1129    int loss_count;
1130
1131    celt_sig preemph_memD[2];
1132    
1133    celt_sig _decode_mem[1]; /* Size = channels*(DECODE_BUFFER_SIZE+mode->overlap) */
1134    /* celt_word16 lpc[],  Size = channels*LPC_ORDER */
1135    /* celt_word16 oldEBands[], Size = channels*mode->nbEBands */
1136 };
1137
1138 int celt_decoder_get_size(const CELTMode *mode, int channels)
1139 {
1140    int size = sizeof(struct CELTDecoder)
1141             + (channels*(DECODE_BUFFER_SIZE+mode->overlap)-1)*sizeof(celt_sig)
1142             + channels*LPC_ORDER*sizeof(celt_word16)
1143             + channels*mode->nbEBands*sizeof(celt_word16);
1144    return size;
1145 }
1146
1147 CELTDecoder *celt_decoder_create(const CELTMode *mode, int channels, int *error)
1148 {
1149    return celt_decoder_init(
1150          (CELTDecoder *)celt_alloc(celt_decoder_get_size(mode, channels)),
1151          mode, channels, error);
1152 }
1153
1154 CELTDecoder *celt_decoder_init(CELTDecoder *st, const CELTMode *mode, int channels, int *error)
1155 {
1156    if (channels < 0 || channels > 2)
1157    {
1158       if (error)
1159          *error = CELT_BAD_ARG;
1160       return NULL;
1161    }
1162
1163    if (st==NULL)
1164    {
1165       if (error)
1166          *error = CELT_ALLOC_FAIL;
1167       return NULL;
1168    }
1169
1170    CELT_MEMSET((char*)st, 0, celt_decoder_get_size(mode, channels));
1171
1172    st->mode = mode;
1173    st->overlap = mode->overlap;
1174    st->channels = channels;
1175
1176    st->start = 0;
1177    st->end = st->mode->effEBands;
1178
1179    st->loss_count = 0;
1180
1181    if (error)
1182       *error = CELT_OK;
1183    return st;
1184 }
1185
1186 void celt_decoder_destroy(CELTDecoder *st)
1187 {
1188    celt_free(st);
1189 }
1190
1191 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16 * restrict pcm, int N, int LM)
1192 {
1193    int c;
1194    int pitch_index;
1195    int overlap = st->mode->overlap;
1196    celt_word16 fade = Q15ONE;
1197    int i, len;
1198    const int C = CHANNELS(st->channels);
1199    int offset;
1200    celt_sig *out_mem[2];
1201    celt_sig *decode_mem[2];
1202    celt_sig *overlap_mem[2];
1203    celt_word16 *lpc;
1204    celt_word16 *oldBandE;
1205    SAVE_STACK;
1206    
1207    for (c=0;c<C;c++)
1208    {
1209       decode_mem[c] = st->_decode_mem + c*(DECODE_BUFFER_SIZE+st->overlap);
1210       out_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE-MAX_PERIOD;
1211       overlap_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE;
1212    }
1213    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1214    oldBandE = lpc+C*LPC_ORDER;
1215
1216    len = N+st->mode->overlap;
1217    
1218    if (st->loss_count == 0)
1219    {
1220       celt_word16 pitch_buf[MAX_PERIOD>>1];
1221       celt_word32 tmp=0;
1222       celt_word32 mem0[2]={0,0};
1223       celt_word16 mem1[2]={0,0};
1224       int len2 = len;
1225       /* FIXME: This is a kludge */
1226       if (len2>MAX_PERIOD>>1)
1227          len2 = MAX_PERIOD>>1;
1228       pitch_downsample(out_mem, pitch_buf, MAX_PERIOD, MAX_PERIOD,
1229                        C, mem0, mem1);
1230       pitch_search(st->mode, pitch_buf+((MAX_PERIOD-len2)>>1), pitch_buf, len2,
1231                    MAX_PERIOD-len2-100, &pitch_index, &tmp, 1<<LM);
1232       pitch_index = MAX_PERIOD-len2-pitch_index;
1233       st->last_pitch_index = pitch_index;
1234    } else {
1235       pitch_index = st->last_pitch_index;
1236       if (st->loss_count < 5)
1237          fade = QCONST16(.8f,15);
1238       else
1239          fade = 0;
1240    }
1241
1242    for (c=0;c<C;c++)
1243    {
1244       /* FIXME: This is more memory than necessary */
1245       celt_word32 e[2*MAX_PERIOD];
1246       celt_word16 exc[2*MAX_PERIOD];
1247       celt_word32 ac[LPC_ORDER+1];
1248       celt_word16 decay = 1;
1249       celt_word32 S1=0;
1250       celt_word16 mem[LPC_ORDER]={0};
1251
1252       offset = MAX_PERIOD-pitch_index;
1253       for (i=0;i<MAX_PERIOD;i++)
1254          exc[i] = ROUND16(out_mem[c][i], SIG_SHIFT);
1255
1256       if (st->loss_count == 0)
1257       {
1258          _celt_autocorr(exc, ac, st->mode->window, st->mode->overlap,
1259                         LPC_ORDER, MAX_PERIOD);
1260
1261          /* Noise floor -40 dB */
1262 #ifdef FIXED_POINT
1263          ac[0] += SHR32(ac[0],13);
1264 #else
1265          ac[0] *= 1.0001f;
1266 #endif
1267          /* Lag windowing */
1268          for (i=1;i<=LPC_ORDER;i++)
1269          {
1270             /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
1271 #ifdef FIXED_POINT
1272             ac[i] -= MULT16_32_Q15(2*i*i, ac[i]);
1273 #else
1274             ac[i] -= ac[i]*(.008f*i)*(.008f*i);
1275 #endif
1276          }
1277
1278          _celt_lpc(lpc+c*LPC_ORDER, ac, LPC_ORDER);
1279       }
1280       fir(exc, lpc+c*LPC_ORDER, exc, MAX_PERIOD, LPC_ORDER, mem);
1281       /*for (i=0;i<MAX_PERIOD;i++)printf("%d ", exc[i]); printf("\n");*/
1282       /* Check if the waveform is decaying (and if so how fast) */
1283       {
1284          celt_word32 E1=1, E2=1;
1285          int period;
1286          if (pitch_index <= MAX_PERIOD/2)
1287             period = pitch_index;
1288          else
1289             period = MAX_PERIOD/2;
1290          for (i=0;i<period;i++)
1291          {
1292             E1 += SHR32(MULT16_16(exc[MAX_PERIOD-period+i],exc[MAX_PERIOD-period+i]),8);
1293             E2 += SHR32(MULT16_16(exc[MAX_PERIOD-2*period+i],exc[MAX_PERIOD-2*period+i]),8);
1294          }
1295          if (E1 > E2)
1296             E1 = E2;
1297          decay = celt_sqrt(frac_div32(SHR(E1,1),E2));
1298       }
1299
1300       /* Copy excitation, taking decay into account */
1301       for (i=0;i<len+st->mode->overlap;i++)
1302       {
1303          if (offset+i >= MAX_PERIOD)
1304          {
1305             offset -= pitch_index;
1306             decay = MULT16_16_Q15(decay, decay);
1307          }
1308          e[i] = SHL32(EXTEND32(MULT16_16_Q15(decay, exc[offset+i])), SIG_SHIFT);
1309          S1 += SHR32(MULT16_16(out_mem[c][offset+i],out_mem[c][offset+i]),8);
1310       }
1311
1312       iir(e, lpc+c*LPC_ORDER, e, len+st->mode->overlap, LPC_ORDER, mem);
1313
1314       {
1315          celt_word32 S2=0;
1316          for (i=0;i<len+overlap;i++)
1317             S2 += SHR32(MULT16_16(e[i],e[i]),8);
1318          /* This checks for an "explosion" in the synthesis */
1319 #ifdef FIXED_POINT
1320          if (!(S1 > SHR32(S2,2)))
1321 #else
1322          /* Float test is written this way to catch NaNs at the same time */
1323          if (!(S1 > 0.2f*S2))
1324 #endif
1325          {
1326             for (i=0;i<len+overlap;i++)
1327                e[i] = 0;
1328          } else if (S1 < S2)
1329          {
1330             celt_word16 ratio = celt_sqrt(frac_div32(SHR32(S1,1)+1,S2+1));
1331             for (i=0;i<len+overlap;i++)
1332                e[i] = MULT16_16_Q15(ratio, e[i]);
1333          }
1334       }
1335
1336       for (i=0;i<MAX_PERIOD+st->mode->overlap-N;i++)
1337          out_mem[c][i] = out_mem[c][N+i];
1338
1339       /* Apply TDAC to the concealed audio so that it blends with the
1340          previous and next frames */
1341       for (i=0;i<overlap/2;i++)
1342       {
1343          celt_word32 tmp1, tmp2;
1344          tmp1 = MULT16_32_Q15(st->mode->window[i          ], e[i          ]) -
1345                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[overlap-i-1]);
1346          tmp2 = MULT16_32_Q15(st->mode->window[i],           e[N+overlap-1-i]) +
1347                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[N+i          ]);
1348          tmp1 = MULT16_32_Q15(fade, tmp1);
1349          tmp2 = MULT16_32_Q15(fade, tmp2);
1350          out_mem[c][MAX_PERIOD+i] = MULT16_32_Q15(st->mode->window[overlap-i-1], tmp2);
1351          out_mem[c][MAX_PERIOD+overlap-i-1] = MULT16_32_Q15(st->mode->window[i], tmp2);
1352          out_mem[c][MAX_PERIOD-N+i] += MULT16_32_Q15(st->mode->window[i], tmp1);
1353          out_mem[c][MAX_PERIOD-N+overlap-i-1] -= MULT16_32_Q15(st->mode->window[overlap-i-1], tmp1);
1354       }
1355       for (i=0;i<N-overlap;i++)
1356          out_mem[c][MAX_PERIOD-N+overlap+i] = MULT16_32_Q15(fade, e[overlap+i]);
1357    }
1358
1359    deemphasis(out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1360    
1361    st->loss_count++;
1362
1363    RESTORE_STACK;
1364 }
1365
1366 #ifdef FIXED_POINT
1367 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1368 {
1369 #else
1370 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig * restrict pcm, int frame_size, ec_dec *dec)
1371 {
1372 #endif
1373    int c, i, N;
1374    int has_fold;
1375    int bits;
1376    ec_dec _dec;
1377    ec_byte_buffer buf;
1378    VARDECL(celt_sig, freq);
1379    VARDECL(celt_norm, X);
1380    VARDECL(celt_ener, bandE);
1381    VARDECL(int, fine_quant);
1382    VARDECL(int, pulses);
1383    VARDECL(int, offsets);
1384    VARDECL(int, fine_priority);
1385    VARDECL(int, tf_res);
1386    celt_sig *out_mem[2];
1387    celt_sig *decode_mem[2];
1388    celt_sig *overlap_mem[2];
1389    celt_sig *out_syn[2];
1390    celt_word16 *lpc;
1391    celt_word16 *oldBandE;
1392
1393    int shortBlocks;
1394    int isTransient;
1395    int intra_ener;
1396    int transient_time;
1397    int transient_shift;
1398    int mdct_weight_shift=0;
1399    const int C = CHANNELS(st->channels);
1400    int mdct_weight_pos=0;
1401    int LM, M;
1402    int nbFilledBytes, nbAvailableBytes;
1403    int effEnd;
1404    SAVE_STACK;
1405
1406    if (pcm==NULL)
1407       return CELT_BAD_ARG;
1408
1409    for (LM=0;LM<4;LM++)
1410       if (st->mode->shortMdctSize<<LM==frame_size)
1411          break;
1412    if (LM>=MAX_CONFIG_SIZES)
1413       return CELT_BAD_ARG;
1414    M=1<<LM;
1415
1416    for (c=0;c<C;c++)
1417    {
1418       decode_mem[c] = st->_decode_mem + c*(DECODE_BUFFER_SIZE+st->overlap);
1419       out_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE-MAX_PERIOD;
1420       overlap_mem[c] = decode_mem[c]+DECODE_BUFFER_SIZE;
1421    }
1422    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
1423    oldBandE = lpc+C*LPC_ORDER;
1424
1425    N = M*st->mode->shortMdctSize;
1426
1427    effEnd = st->end;
1428    if (effEnd > st->mode->effEBands)
1429       effEnd = st->mode->effEBands;
1430
1431    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
1432    ALLOC(X, C*N, celt_norm);   /**< Interleaved normalised MDCTs */
1433    ALLOC(bandE, st->mode->nbEBands*C, celt_ener);
1434    for (c=0;c<C;c++)
1435       for (i=0;i<M*st->mode->eBands[st->start];i++)
1436          X[c*N+i] = 0;
1437    for (c=0;c<C;c++)
1438       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1439          X[c*N+i] = 0;
1440
1441    if (data == NULL)
1442    {
1443       celt_decode_lost(st, pcm, N, LM);
1444       RESTORE_STACK;
1445       return CELT_OK;
1446    }
1447    if (len<0) {
1448      RESTORE_STACK;
1449      return CELT_BAD_ARG;
1450    }
1451    
1452    if (dec == NULL)
1453    {
1454       ec_byte_readinit(&buf,(unsigned char*)data,len);
1455       ec_dec_init(&_dec,&buf);
1456       dec = &_dec;
1457       nbFilledBytes = 0;
1458    } else {
1459       nbFilledBytes = (ec_dec_tell(dec, 0)+4)>>3;
1460    }
1461    nbAvailableBytes = len-nbFilledBytes;
1462
1463    /* Decode the global flags (first symbols in the stream) */
1464    intra_ener = ec_dec_bit_prob(dec, 8192);
1465    /* Get band energies */
1466    unquant_coarse_energy(st->mode, st->start, st->end, bandE, oldBandE,
1467          intra_ener, st->mode->prob, dec, C, LM);
1468
1469    isTransient = ec_dec_bit_prob(dec, 8192);
1470
1471    if (isTransient)
1472       shortBlocks = M;
1473    else
1474       shortBlocks = 0;
1475
1476    if (isTransient)
1477    {
1478       transient_shift = ec_dec_uint(dec, 4);
1479       if (transient_shift == 3)
1480       {
1481          int transient_time_quant;
1482          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
1483          transient_time_quant = ec_dec_uint(dec, max_time);
1484          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
1485       } else {
1486          mdct_weight_shift = transient_shift;
1487          if (mdct_weight_shift && M>2)
1488             mdct_weight_pos = ec_dec_uint(dec, M-1);
1489          transient_shift = 0;
1490          transient_time = 0;
1491       }
1492    } else {
1493       transient_time = -1;
1494       transient_shift = 0;
1495    }
1496
1497    ALLOC(tf_res, st->mode->nbEBands, int);
1498    tf_decode(st->start, st->end, C, isTransient, tf_res, nbAvailableBytes, LM, dec);
1499
1500    has_fold = ec_dec_bit_prob(dec, 8192)<<1;
1501    has_fold |= ec_dec_bit_prob(dec, (has_fold>>1) ? 32768 : 49152);
1502
1503    ALLOC(pulses, st->mode->nbEBands, int);
1504    ALLOC(offsets, st->mode->nbEBands, int);
1505    ALLOC(fine_priority, st->mode->nbEBands, int);
1506
1507    for (i=0;i<st->mode->nbEBands;i++)
1508       offsets[i] = 0;
1509
1510    bits = len*8 - ec_dec_tell(dec, 0) - 1;
1511    ALLOC(fine_quant, st->mode->nbEBands, int);
1512    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, LM);
1513    /*bits = ec_dec_tell(dec, 0);
1514    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(dec, 0)-bits))/C);*/
1515    
1516    unquant_fine_energy(st->mode, st->start, st->end, bandE, oldBandE, fine_quant, dec, C);
1517
1518    /* Decode fixed codebook */
1519    quant_all_bands(0, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, NULL, pulses, shortBlocks, has_fold, tf_res, 1, len*8, dec, LM);
1520
1521    unquant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE,
1522          fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
1523
1524    log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
1525
1526    if (mdct_weight_shift)
1527    {
1528       mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
1529    }
1530
1531    /* Synthesis */
1532    denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
1533
1534    CELT_MOVE(decode_mem[0], decode_mem[0]+N, DECODE_BUFFER_SIZE-N);
1535    if (C==2)
1536       CELT_MOVE(decode_mem[1], decode_mem[1]+N, DECODE_BUFFER_SIZE-N);
1537
1538    for (c=0;c<C;c++)
1539       for (i=0;i<M*st->mode->eBands[st->start];i++)
1540          freq[c*N+i] = 0;
1541    for (c=0;c<C;c++)
1542       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1543          freq[c*N+i] = 0;
1544
1545    out_syn[0] = out_mem[0]+MAX_PERIOD-N;
1546    if (C==2)
1547       out_syn[1] = out_mem[1]+MAX_PERIOD-N;
1548
1549    /* Compute inverse MDCTs */
1550    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time,
1551          transient_shift, out_syn, overlap_mem, C, LM);
1552
1553    deemphasis(out_syn, pcm, N, C, st->mode->preemph, st->preemph_memD);
1554    st->loss_count = 0;
1555    RESTORE_STACK;
1556    if (ec_dec_get_error(dec))
1557       return CELT_CORRUPTED_DATA;
1558    else
1559       return CELT_OK;
1560 }
1561
1562 #ifdef FIXED_POINT
1563 #ifndef DISABLE_FLOAT_API
1564 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size, ec_dec *dec)
1565 {
1566    int j, ret, C, N, LM, M;
1567    VARDECL(celt_int16, out);
1568    SAVE_STACK;
1569
1570    if (pcm==NULL)
1571       return CELT_BAD_ARG;
1572
1573    for (LM=0;LM<4;LM++)
1574       if (st->mode->shortMdctSize<<LM==frame_size)
1575          break;
1576    if (LM>=MAX_CONFIG_SIZES)
1577       return CELT_BAD_ARG;
1578    M=1<<LM;
1579
1580    C = CHANNELS(st->channels);
1581    N = M*st->mode->shortMdctSize;
1582    
1583    ALLOC(out, C*N, celt_int16);
1584    ret=celt_decode_with_ec(st, data, len, out, frame_size, dec);
1585    if (ret==0)
1586       for (j=0;j<C*N;j++)
1587          pcm[j]=out[j]*(1.f/32768.f);
1588      
1589    RESTORE_STACK;
1590    return ret;
1591 }
1592 #endif /*DISABLE_FLOAT_API*/
1593 #else
1594 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1595 {
1596    int j, ret, C, N, LM, M;
1597    VARDECL(celt_sig, out);
1598    SAVE_STACK;
1599
1600    if (pcm==NULL)
1601       return CELT_BAD_ARG;
1602
1603    for (LM=0;LM<4;LM++)
1604       if (st->mode->shortMdctSize<<LM==frame_size)
1605          break;
1606    if (LM>=MAX_CONFIG_SIZES)
1607       return CELT_BAD_ARG;
1608    M=1<<LM;
1609
1610    C = CHANNELS(st->channels);
1611    N = M*st->mode->shortMdctSize;
1612    ALLOC(out, C*N, celt_sig);
1613
1614    ret=celt_decode_with_ec_float(st, data, len, out, frame_size, dec);
1615
1616    if (ret==0)
1617       for (j=0;j<C*N;j++)
1618          pcm[j] = FLOAT2INT16 (out[j]);
1619    
1620    RESTORE_STACK;
1621    return ret;
1622 }
1623 #endif
1624
1625 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size)
1626 {
1627    return celt_decode_with_ec(st, data, len, pcm, frame_size, NULL);
1628 }
1629
1630 #ifndef DISABLE_FLOAT_API
1631 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size)
1632 {
1633    return celt_decode_with_ec_float(st, data, len, pcm, frame_size, NULL);
1634 }
1635 #endif /* DISABLE_FLOAT_API */
1636
1637 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
1638 {
1639    va_list ap;
1640
1641    va_start(ap, request);
1642    switch (request)
1643    {
1644       case CELT_GET_MODE_REQUEST:
1645       {
1646          const CELTMode ** value = va_arg(ap, const CELTMode**);
1647          if (value==0)
1648             goto bad_arg;
1649          *value=st->mode;
1650       }
1651       break;
1652       case CELT_SET_START_BAND_REQUEST:
1653       {
1654          celt_int32 value = va_arg(ap, celt_int32);
1655          if (value<0 || value>=st->mode->nbEBands)
1656             goto bad_arg;
1657          st->start = value;
1658       }
1659       break;
1660       case CELT_SET_END_BAND_REQUEST:
1661       {
1662          celt_int32 value = va_arg(ap, celt_int32);
1663          if (value<0 || value>=st->mode->nbEBands)
1664             goto bad_arg;
1665          st->end = value;
1666       }
1667       break;
1668       case CELT_RESET_STATE:
1669       {
1670          CELT_MEMSET((char*)&st->DECODER_RESET_START, 0,
1671                celt_decoder_get_size(st->mode, st->channels)-
1672                ((char*)&st->DECODER_RESET_START - (char*)st));
1673       }
1674       break;
1675       default:
1676          goto bad_request;
1677    }
1678    va_end(ap);
1679    return CELT_OK;
1680 bad_arg:
1681    va_end(ap);
1682    return CELT_BAD_ARG;
1683 bad_request:
1684       va_end(ap);
1685   return CELT_UNIMPLEMENTED;
1686 }
1687
1688 const char *celt_strerror(int error)
1689 {
1690    static const char *error_strings[8] = {
1691       "success",
1692       "invalid argument",
1693       "invalid mode",
1694       "internal error",
1695       "corrupted stream",
1696       "request not implemented",
1697       "invalid state",
1698       "memory allocation failed"
1699    };
1700    if (error > 0 || error < -7)
1701       return "unknown error";
1702    else 
1703       return error_strings[-error];
1704 }
1705