Implemented variable spreading amount in the decoder
[opus.git] / libcelt / celt.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Copyright (c) 2008 Gregory Maxwell 
4    Written by Jean-Marc Valin and Gregory Maxwell */
5 /*
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9    
10    - Redistributions of source code must retain the above copyright
11    notice, this list of conditions and the following disclaimer.
12    
13    - Redistributions in binary form must reproduce the above copyright
14    notice, this list of conditions and the following disclaimer in the
15    documentation and/or other materials provided with the distribution.
16    
17    - Neither the name of the Xiph.org Foundation nor the names of its
18    contributors may be used to endorse or promote products derived from
19    this software without specific prior written permission.
20    
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #ifdef HAVE_CONFIG_H
35 #include "config.h"
36 #endif
37
38 #define CELT_C
39
40 #include "os_support.h"
41 #include "mdct.h"
42 #include <math.h>
43 #include "celt.h"
44 #include "pitch.h"
45 #include "bands.h"
46 #include "modes.h"
47 #include "entcode.h"
48 #include "quant_bands.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54 #include "plc.h"
55
56 #ifdef FIXED_POINT
57 static const celt_word16 transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135f, 0.0337639f, 0.0748914f, 0.1304955f,
63    0.1986827f, 0.2771308f, 0.3631685f, 0.4538658f,
64    0.5461342f, 0.6368315f, 0.7228692f, 0.8013173f,
65    0.8695045f, 0.9251086f, 0.9662361f, 0.9914865f};
66 #endif
67
68 #define ENCODERVALID   0x4c434554
69 #define ENCODERPARTIAL 0x5445434c
70 #define ENCODERFREED   0x4c004500
71    
72 /** Encoder state 
73  @brief Encoder state
74  */
75 struct CELTEncoder {
76    celt_uint32 marker;
77    const CELTMode *mode;     /**< Mode used by the encoder */
78    int overlap;
79    int channels;
80    
81    int force_intra;
82    int delayedIntra;
83    celt_word16 tonal_average;
84    int fold_decision;
85    celt_word16 gain_prod;
86    celt_word32 frame_max;
87    int start, end;
88
89    /* VBR-related parameters */
90    celt_int32 vbr_reservoir;
91    celt_int32 vbr_drift;
92    celt_int32 vbr_offset;
93    celt_int32 vbr_count;
94
95    celt_int32 vbr_rate_norm; /* Target number of 16th bits per frame */
96    celt_word32 preemph_memE[2];
97    celt_word32 preemph_memD[2];
98
99    celt_sig *in_mem;
100    celt_sig *out_mem;
101
102    celt_word16 *oldBandE;
103 };
104
105 static int check_encoder(const CELTEncoder *st) 
106 {
107    if (st==NULL)
108    {
109       celt_warning("NULL passed as an encoder structure");  
110       return CELT_INVALID_STATE;
111    }
112    if (st->marker == ENCODERVALID)
113       return CELT_OK;
114    if (st->marker == ENCODERFREED)
115       celt_warning("Referencing an encoder that has already been freed");
116    else
117       celt_warning("This is not a valid CELT encoder structure");
118    return CELT_INVALID_STATE;
119 }
120
121 CELTEncoder *celt_encoder_create(const CELTMode *mode, int channels, int *error)
122 {
123    int C;
124    CELTEncoder *st;
125
126    if (check_mode(mode) != CELT_OK)
127    {
128       if (error)
129          *error = CELT_INVALID_MODE;
130       return NULL;
131    }
132
133    if (channels < 0 || channels > 2)
134    {
135       celt_warning("Only mono and stereo supported");
136       if (error)
137          *error = CELT_BAD_ARG;
138       return NULL;
139    }
140
141    C = channels;
142    st = celt_alloc(sizeof(CELTEncoder));
143    
144    if (st==NULL)
145    {
146       if (error)
147          *error = CELT_ALLOC_FAIL;
148       return NULL;
149    }
150    st->marker = ENCODERPARTIAL;
151    st->mode = mode;
152    st->overlap = mode->overlap;
153    st->channels = channels;
154
155    st->start = 0;
156    st->end = st->mode->effEBands;
157
158    st->vbr_rate_norm = 0;
159    st->force_intra  = 0;
160    st->delayedIntra = 1;
161    st->tonal_average = QCONST16(1.f,8);
162    st->fold_decision = 1;
163
164    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig));
165    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig));
166
167    st->oldBandE = (celt_word16*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16));
168
169    if ((st->in_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL))
170    {
171       if (error)
172          *error = CELT_OK;
173       st->marker   = ENCODERVALID;
174       return st;
175    }
176    /* If the setup fails for some reason deallocate it. */
177    celt_encoder_destroy(st);  
178    if (error)
179       *error = CELT_ALLOC_FAIL;
180    return NULL;
181 }
182
183 void celt_encoder_destroy(CELTEncoder *st)
184 {
185    if (st == NULL)
186    {
187       celt_warning("NULL passed to celt_encoder_destroy");
188       return;
189    }
190
191    if (st->marker == ENCODERFREED)
192    {
193       celt_warning("Freeing an encoder which has already been freed"); 
194       return;
195    }
196
197    if (st->marker != ENCODERVALID && st->marker != ENCODERPARTIAL)
198    {
199       celt_warning("This is not a valid CELT encoder structure");
200       return;
201    }
202    /*Check_mode is non-fatal here because we can still free
203     the encoder memory even if the mode is bad, although calling
204     the free functions in this order is a violation of the API.*/
205    check_mode(st->mode);
206    
207    celt_free(st->in_mem);
208    celt_free(st->out_mem);
209    celt_free(st->oldBandE);
210    
211    st->marker = ENCODERFREED;
212    
213    celt_free(st);
214 }
215
216 static inline celt_int16 FLOAT2INT16(float x)
217 {
218    x = x*CELT_SIG_SCALE;
219    x = MAX32(x, -32768);
220    x = MIN32(x, 32767);
221    return (celt_int16)float2int(x);
222 }
223
224 static inline celt_word16 SIG2WORD16(celt_sig x)
225 {
226 #ifdef FIXED_POINT
227    x = PSHR32(x, SIG_SHIFT);
228    x = MAX32(x, -32768);
229    x = MIN32(x, 32767);
230    return EXTRACT16(x);
231 #else
232    return (celt_word16)x;
233 #endif
234 }
235
236 static int transient_analysis(const celt_word32 * restrict in, int len, int C,
237                               int *transient_time, int *transient_shift,
238                               celt_word32 *frame_max, int overlap)
239 {
240    int i, n;
241    celt_word32 ratio;
242    celt_word32 threshold;
243    VARDECL(celt_word32, begin);
244    SAVE_STACK;
245    ALLOC(begin, len+1, celt_word32);
246    begin[0] = 0;
247    if (C==1)
248    {
249       for (i=0;i<len;i++)
250          begin[i+1] = MAX32(begin[i], ABS32(in[i]));
251    } else {
252       for (i=0;i<len;i++)
253          begin[i+1] = MAX32(begin[i], MAX32(ABS32(in[C*i]),
254                                             ABS32(in[C*i+1])));
255    }
256    n = -1;
257
258    threshold = MULT16_32_Q15(QCONST16(.4f,15),begin[len]);
259    /* If the following condition isn't met, there's just no way
260       we'll have a transient*/
261    if (*frame_max < threshold)
262    {
263       /* It's likely we have a transient, now find it */
264       for (i=8;i<len-8;i++)
265       {
266          if (begin[i+1] < threshold)
267             n=i;
268       }
269    }
270    if (n<32)
271    {
272       n = -1;
273       ratio = 0;
274    } else {
275       ratio = DIV32(begin[len],1+MAX32(*frame_max, begin[n-16]));
276    }
277
278    if (ratio > 45)
279       *transient_shift = 3;
280    else
281       *transient_shift = 0;
282    
283    *transient_time = n;
284    *frame_max = begin[len-overlap];
285
286    RESTORE_STACK;
287    return ratio > 0;
288 }
289
290 /** Apply window and compute the MDCT for all sub-frames and 
291     all channels in a frame */
292 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * restrict in, celt_sig * restrict out, int _C, int LM)
293 {
294    const int C = CHANNELS(_C);
295    if (C==1 && !shortBlocks)
296    {
297       const int overlap = OVERLAP(mode);
298       clt_mdct_forward(&mode->mdct, in, out, mode->window, overlap, mode->maxLM-LM);
299    } else {
300       const int overlap = OVERLAP(mode);
301       int N = mode->shortMdctSize<<LM;
302       int B = 1;
303       int b, c;
304       VARDECL(celt_word32, x);
305       VARDECL(celt_word32, tmp);
306       SAVE_STACK;
307       if (shortBlocks)
308       {
309          /*lookup = &mode->mdct[0];*/
310          N = mode->shortMdctSize;
311          B = shortBlocks;
312       }
313       ALLOC(x, N+overlap, celt_word32);
314       ALLOC(tmp, N, celt_word32);
315       for (c=0;c<C;c++)
316       {
317          for (b=0;b<B;b++)
318          {
319             int j;
320             for (j=0;j<N+overlap;j++)
321                x[j] = in[C*(b*N+j)+c];
322             clt_mdct_forward(&mode->mdct, x, tmp, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
323             /* Interleaving the sub-frames */
324             for (j=0;j<N;j++)
325                out[(j*B+b)+c*N*B] = tmp[j];
326          }
327       }
328       RESTORE_STACK;
329    }
330 }
331
332 /** Compute the IMDCT and apply window for all sub-frames and 
333     all channels in a frame */
334 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X, int transient_time, int transient_shift, celt_sig * restrict out_mem, int _C, int LM)
335 {
336    int c, N4;
337    const int C = CHANNELS(_C);
338    const int N = mode->shortMdctSize<<LM;
339    const int overlap = OVERLAP(mode);
340    N4 = (N-overlap)>>1;
341    for (c=0;c<C;c++)
342    {
343       int j;
344       if (transient_shift==0 && C==1 && !shortBlocks) {
345          clt_mdct_backward(&mode->mdct, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap, mode->maxLM-LM);
346       } else {
347          VARDECL(celt_word32, x);
348          VARDECL(celt_word32, tmp);
349          int b;
350          int N2 = N;
351          int B = 1;
352          int n4offset=0;
353          SAVE_STACK;
354          
355          ALLOC(x, 2*N, celt_word32);
356          ALLOC(tmp, N, celt_word32);
357
358          if (shortBlocks)
359          {
360             /*lookup = &mode->mdct[0];*/
361             N2 = mode->shortMdctSize;
362             B = shortBlocks;
363             n4offset = N4;
364          }
365          /* Prevents problems from the imdct doing the overlap-add */
366          CELT_MEMSET(x+N4, 0, N2);
367
368          for (b=0;b<B;b++)
369          {
370             /* De-interleaving the sub-frames */
371             for (j=0;j<N2;j++)
372                tmp[j] = X[(j*B+b)+c*N2*B];
373             clt_mdct_backward(&mode->mdct, tmp, x+n4offset+N2*b, mode->window, overlap, shortBlocks ? mode->maxLM : mode->maxLM-LM);
374          }
375
376          if (transient_shift > 0)
377          {
378 #ifdef FIXED_POINT
379             for (j=0;j<16;j++)
380                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
381             for (j=transient_time;j<N+overlap;j++)
382                x[N4+j] = SHL32(x[N4+j], transient_shift);
383 #else
384             for (j=0;j<16;j++)
385                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
386             for (j=transient_time;j<N+overlap;j++)
387                x[N4+j] *= 1<<transient_shift;
388 #endif
389          }
390          /* The first and last part would need to be set to zero 
391             if we actually wanted to use them. */
392          for (j=0;j<overlap;j++)
393             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
394          for (j=0;j<overlap;j++)
395             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
396          for (j=0;j<2*N4;j++)
397             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
398          RESTORE_STACK;
399       }
400    }
401 }
402
403 static void deemphasis(celt_sig *in, celt_word16 *pcm, int N, int _C, const celt_word16 *coef, celt_sig *mem)
404 {
405    const int C = CHANNELS(_C);
406    int c;
407    for (c=0;c<C;c++)
408    {
409       int j;
410       celt_sig * restrict x;
411       celt_word16  * restrict y;
412       celt_sig m = mem[c];
413       x = &in[C*(MAX_PERIOD-N)+c];
414       y = pcm+c;
415       for (j=0;j<N;j++)
416       {
417          celt_sig tmp = *x + m;
418          m = MULT16_32_Q15(coef[0], tmp)
419            - MULT16_32_Q15(coef[1], *x);
420          tmp = SHL32(MULT16_32_Q15(coef[3], tmp), 2);
421          *y = SCALEOUT(SIG2WORD16(tmp));
422          x+=C;
423          y+=C;
424       }
425       mem[c] = m;
426    }
427 }
428
429 static void mdct_shape(const CELTMode *mode, celt_norm *X, int start,
430                        int end, int N,
431                        int mdct_weight_shift, int end_band, int _C, int renorm, int M)
432 {
433    int m, i, c;
434    const int C = CHANNELS(_C);
435    for (c=0;c<C;c++)
436       for (m=start;m<end;m++)
437          for (i=m+c*N;i<(c+1)*N;i+=M)
438 #ifdef FIXED_POINT
439             X[i] = SHR16(X[i], mdct_weight_shift);
440 #else
441             X[i] = (1.f/(1<<mdct_weight_shift))*X[i];
442 #endif
443    if (renorm)
444       renormalise_bands(mode, X, end_band, C, M);
445 }
446
447 static signed char tf_select_table[4][8] = {
448       {0, -1, 0, -1,    0,-1, 0,-1},
449       {0, -1, 0, -2,    1, 0, 1 -1},
450       {0, -2, 0, -3,    2, 0, 1 -1},
451       {0, -2, 0, -3,    2, 0, 1 -1},
452 };
453
454 static int tf_analysis(celt_word16 *bandLogE, celt_word16 *oldBandE, int len, int C, int isTransient, int *tf_res, int nbCompressedBytes)
455 {
456    int i;
457    celt_word16 threshold;
458    VARDECL(celt_word16, metric);
459    celt_word32 average=0;
460    celt_word32 cost0;
461    celt_word32 cost1;
462    VARDECL(int, path0);
463    VARDECL(int, path1);
464    celt_word16 lambda;
465    int tf_select=0;
466    SAVE_STACK;
467
468    /* FIXME: Should check number of bytes *left* */
469    if (nbCompressedBytes<15*C)
470    {
471       for (i=0;i<len;i++)
472          tf_res[i] = 0;
473       return 0;
474    }
475    if (nbCompressedBytes<40)
476       lambda = QCONST16(5.f, DB_SHIFT);
477    else if (nbCompressedBytes<60)
478       lambda = QCONST16(2.f, DB_SHIFT);
479    else if (nbCompressedBytes<100)
480       lambda = QCONST16(1.f, DB_SHIFT);
481    else
482       lambda = QCONST16(.5f, DB_SHIFT);
483
484    ALLOC(metric, len, celt_word16);
485    ALLOC(path0, len, int);
486    ALLOC(path1, len, int);
487    for (i=0;i<len;i++)
488    {
489       metric[i] = SUB16(bandLogE[i], oldBandE[i]);
490       average += metric[i];
491    }
492    if (C==2)
493    {
494       average = 0;
495       for (i=0;i<len;i++)
496       {
497          metric[i] = HALF32(metric[i]) + HALF32(SUB16(bandLogE[i+len], oldBandE[i+len]));
498          average += metric[i];
499       }
500    }
501    average = DIV32(average, len);
502    /*if (!isTransient)
503       printf ("%f\n", average);*/
504    if (isTransient)
505    {
506       threshold = QCONST16(1.f,DB_SHIFT);
507       tf_select = average > QCONST16(3.f,DB_SHIFT);
508    } else {
509       threshold = QCONST16(.5f,DB_SHIFT);
510       tf_select = average > QCONST16(1.f,DB_SHIFT);
511    }
512    cost0 = 0;
513    cost1 = lambda;
514    /* Viterbi forward pass */
515    for (i=1;i<len;i++)
516    {
517       celt_word32 curr0, curr1;
518       celt_word32 from0, from1;
519
520       from0 = cost0;
521       from1 = cost1 + lambda;
522       if (from0 < from1)
523       {
524          curr0 = from0;
525          path0[i]= 0;
526       } else {
527          curr0 = from1;
528          path0[i]= 1;
529       }
530
531       from0 = cost0 + lambda;
532       from1 = cost1;
533       if (from0 < from1)
534       {
535          curr1 = from0;
536          path1[i]= 0;
537       } else {
538          curr1 = from1;
539          path1[i]= 1;
540       }
541       cost0 = curr0 + (metric[i]-threshold);
542       cost1 = curr1;
543    }
544    tf_res[len-1] = cost0 < cost1 ? 0 : 1;
545    /* Viterbi backward pass to check the decisions */
546    for (i=len-2;i>=0;i--)
547    {
548       if (tf_res[i+1] == 1)
549          tf_res[i] = path1[i+1];
550       else
551          tf_res[i] = path0[i+1];
552    }
553    RESTORE_STACK;
554    return tf_select;
555 }
556
557 static void tf_encode(int start, int end, int isTransient, int *tf_res, int nbCompressedBytes, int LM, int tf_select, ec_enc *enc)
558 {
559    int curr, i;
560    if (8*nbCompressedBytes - ec_enc_tell(enc, 0) < 100)
561    {
562       for (i=start;i<end;i++)
563          tf_res[i] = isTransient;
564    } else {
565       ec_enc_bit_prob(enc, tf_res[start], isTransient ? 16384 : 4096);
566       curr = tf_res[start];
567       for (i=start+1;i<end;i++)
568       {
569          ec_enc_bit_prob(enc, tf_res[i] ^ curr, isTransient ? 4096 : 2048);
570          curr = tf_res[i];
571       }
572    }
573    ec_enc_bits(enc, tf_select, 1);
574    for (i=start;i<end;i++)
575       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
576 }
577
578 static void tf_decode(int start, int end, int C, int isTransient, int *tf_res, int nbCompressedBytes, int LM, ec_dec *dec)
579 {
580    int i, curr, tf_select;
581    if (8*nbCompressedBytes - ec_dec_tell(dec, 0) < 100)
582    {
583       for (i=start;i<end;i++)
584          tf_res[i] = isTransient;
585    } else {
586       tf_res[start] = ec_dec_bit_prob(dec, isTransient ? 16384 : 4096);
587       curr = tf_res[start];
588       for (i=start+1;i<end;i++)
589       {
590          tf_res[i] = ec_dec_bit_prob(dec, isTransient ? 4096 : 2048) ^ curr;
591          curr = tf_res[i];
592       }
593    }
594    tf_select = ec_dec_bits(dec, 1);
595    for (i=start;i<end;i++)
596       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
597 }
598
599 #ifdef FIXED_POINT
600 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
601 {
602 #else
603 int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
604 {
605 #endif
606    int i, c, N, NN, N4;
607    int bits;
608    int has_fold=1;
609    ec_byte_buffer buf;
610    ec_enc         _enc;
611    VARDECL(celt_sig, in);
612    VARDECL(celt_sig, freq);
613    VARDECL(celt_norm, X);
614    VARDECL(celt_ener, bandE);
615    VARDECL(celt_word16, bandLogE);
616    VARDECL(int, fine_quant);
617    VARDECL(celt_word16, error);
618    VARDECL(int, pulses);
619    VARDECL(int, offsets);
620    VARDECL(int, fine_priority);
621    VARDECL(int, tf_res);
622    int intra_ener = 0;
623    int shortBlocks=0;
624    int isTransient=0;
625    int transient_time, transient_time_quant;
626    int transient_shift;
627    int resynth;
628    const int C = CHANNELS(st->channels);
629    int mdct_weight_shift = 0;
630    int mdct_weight_pos=0;
631    int LM, M;
632    int tf_select;
633    celt_int32 vbr_rate=0;
634    celt_word16 max_decay;
635    int nbFilledBytes, nbAvailableBytes;
636    int effEnd;
637    SAVE_STACK;
638
639    if (check_encoder(st) != CELT_OK)
640       return CELT_INVALID_STATE;
641
642    if (check_mode(st->mode) != CELT_OK)
643       return CELT_INVALID_MODE;
644
645    if (nbCompressedBytes<0 || pcm==NULL)
646      return CELT_BAD_ARG;
647
648    for (LM=0;LM<4;LM++)
649       if (st->mode->shortMdctSize<<LM==frame_size)
650          break;
651    if (LM>=MAX_CONFIG_SIZES)
652       return CELT_BAD_ARG;
653    M=1<<LM;
654
655    if (enc==NULL)
656    {
657       ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
658       ec_enc_init(&_enc,&buf);
659       enc = &_enc;
660       nbFilledBytes=0;
661    } else {
662       nbFilledBytes=(ec_enc_tell(enc, 0)+4)>>3;
663    }
664    nbAvailableBytes = nbCompressedBytes - nbFilledBytes;
665
666    effEnd = st->end;
667    if (effEnd > st->mode->effEBands)
668       effEnd = st->mode->effEBands;
669
670    N = M*st->mode->shortMdctSize;
671    N4 = (N-st->overlap)>>1;
672    ALLOC(in, 2*C*N-2*C*N4, celt_sig);
673
674    CELT_COPY(in, st->in_mem, C*st->overlap);
675    for (c=0;c<C;c++)
676    {
677       const celt_word16 * restrict pcmp = pcm+c;
678       celt_sig * restrict inp = in+C*st->overlap+c;
679       for (i=0;i<N;i++)
680       {
681          /* Apply pre-emphasis */
682          celt_sig tmp = MULT16_16(st->mode->preemph[2], SCALEIN(*pcmp));
683          *inp = tmp + st->preemph_memE[c];
684          st->preemph_memE[c] = MULT16_32_Q15(st->mode->preemph[1], *inp)
685                              - MULT16_32_Q15(st->mode->preemph[0], tmp);
686          inp += C;
687          pcmp += C;
688       }
689    }
690    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
691
692    /* Transient handling */
693    transient_time = -1;
694    transient_time_quant = -1;
695    transient_shift = 0;
696    isTransient = 0;
697
698    resynth = optional_resynthesis!=NULL;
699
700    if (M > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift, &st->frame_max, st->overlap))
701    {
702 #ifndef FIXED_POINT
703       float gain_1;
704 #endif
705       /* Apply the inverse shaping window */
706       if (transient_shift)
707       {
708          transient_time_quant = transient_time*(celt_int32)8000/st->mode->Fs;
709          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
710 #ifdef FIXED_POINT
711          for (c=0;c<C;c++)
712             for (i=0;i<16;i++)
713                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
714          for (c=0;c<C;c++)
715             for (i=transient_time;i<N+st->overlap;i++)
716                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
717 #else
718          for (c=0;c<C;c++)
719             for (i=0;i<16;i++)
720                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
721          gain_1 = 1.f/(1<<transient_shift);
722          for (c=0;c<C;c++)
723             for (i=transient_time;i<N+st->overlap;i++)
724                in[C*i+c] *= gain_1;
725 #endif
726       }
727       isTransient = 1;
728       has_fold = 1;
729    }
730
731    if (isTransient)
732       shortBlocks = M;
733    else
734       shortBlocks = 0;
735
736    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
737    ALLOC(bandE,st->mode->nbEBands*C, celt_ener);
738    ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16);
739    /* Compute MDCTs */
740    compute_mdcts(st->mode, shortBlocks, in, freq, C, LM);
741
742    ALLOC(X, C*N, celt_norm);         /**< Interleaved normalised MDCTs */
743
744    compute_band_energies(st->mode, freq, bandE, effEnd, C, M);
745
746    amp2Log2(st->mode, effEnd, st->end, bandE, bandLogE, C);
747
748    /* Band normalisation */
749    normalise_bands(st->mode, freq, X, bandE, effEnd, C, M);
750    if (!shortBlocks && !folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision, effEnd, C, M))
751       has_fold = 0;
752
753    /* Don't use intra energy when we're operating at low bit-rate */
754    intra_ener = st->force_intra || (st->delayedIntra && nbAvailableBytes > st->end);
755    if (shortBlocks || intra_decision(bandLogE, st->oldBandE, st->start, effEnd, st->mode->nbEBands, C))
756       st->delayedIntra = 1;
757    else
758       st->delayedIntra = 0;
759
760    NN = M*st->mode->eBands[effEnd];
761    if (shortBlocks && !transient_shift)
762    {
763       celt_word32 sum[8]={1,1,1,1,1,1,1,1};
764       int m;
765       for (c=0;c<C;c++)
766       {
767          m=0;
768          do {
769             celt_word32 tmp=0;
770             for (i=m+c*N;i<c*N+NN;i+=M)
771                tmp += ABS32(X[i]);
772             sum[m++] += tmp;
773          } while (m<M);
774       }
775       m=0;
776 #ifdef FIXED_POINT
777       do {
778          if (SHR32(sum[m+1],3) > sum[m])
779          {
780             mdct_weight_shift=2;
781             mdct_weight_pos = m;
782          } else if (SHR32(sum[m+1],1) > sum[m] && mdct_weight_shift < 2)
783          {
784             mdct_weight_shift=1;
785             mdct_weight_pos = m;
786          }
787          m++;
788       } while (m<M-1);
789 #else
790       do {
791          if (sum[m+1] > 8*sum[m])
792          {
793             mdct_weight_shift=2;
794             mdct_weight_pos = m;
795          } else if (sum[m+1] > 2*sum[m] && mdct_weight_shift < 2)
796          {
797             mdct_weight_shift=1;
798             mdct_weight_pos = m;
799          }
800          m++;
801       } while (m<M-1);
802 #endif
803       if (mdct_weight_shift)
804          mdct_shape(st->mode, X, mdct_weight_pos+1, M, N, mdct_weight_shift, effEnd, C, 0, M);
805    }
806
807    /* Encode the global flags using a simple probability model
808       (first symbols in the stream) */
809    ec_enc_bit_prob(enc, intra_ener, 8192);
810    ec_enc_bit_prob(enc, shortBlocks!=0, 8192);
811    ec_enc_bit_prob(enc, has_fold>>1, 8192);
812    ec_enc_bit_prob(enc, has_fold&1, (has_fold>>1) ? 32768 : 49152);
813
814    if (shortBlocks)
815    {
816       if (transient_shift)
817       {
818          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
819          ec_enc_uint(enc, transient_shift, 4);
820          ec_enc_uint(enc, transient_time_quant, max_time);
821       } else {
822          ec_enc_uint(enc, mdct_weight_shift, 4);
823          if (mdct_weight_shift && M!=2)
824             ec_enc_uint(enc, mdct_weight_pos, M-1);
825       }
826    }
827
828    ALLOC(fine_quant, st->mode->nbEBands, int);
829    ALLOC(pulses, st->mode->nbEBands, int);
830
831    vbr_rate = M*st->vbr_rate_norm;
832    /* Computes the max bit-rate allowed in VBR more to avoid busting the budget */
833    if (st->vbr_rate_norm>0)
834    {
835       celt_int32 vbr_bound, max_allowed;
836
837       vbr_bound = vbr_rate;
838       max_allowed = (vbr_rate + vbr_bound - st->vbr_reservoir)>>(BITRES+3);
839       if (max_allowed < 4)
840          max_allowed = 4;
841       if (max_allowed < nbAvailableBytes)
842          nbAvailableBytes = max_allowed;
843    }
844
845    ALLOC(tf_res, st->mode->nbEBands, int);
846    tf_select = tf_analysis(bandLogE, st->oldBandE, effEnd, C, isTransient, tf_res, nbAvailableBytes);
847    for (i=effEnd;i<st->end;i++)
848       tf_res[i] = tf_res[effEnd-1];
849
850    /* Bit allocation */
851    ALLOC(error, C*st->mode->nbEBands, celt_word16);
852
853 #ifdef FIXED_POINT
854       max_decay = MIN32(QCONST16(16,DB_SHIFT), SHL32(EXTEND32(nbAvailableBytes),DB_SHIFT-3));
855 #else
856    max_decay = MIN32(16.f, .125f*nbAvailableBytes);
857 #endif
858    quant_coarse_energy(st->mode, st->start, st->end, bandLogE, st->oldBandE, nbCompressedBytes*8, intra_ener, st->mode->prob, error, enc, C, LM, max_decay);
859    /* Variable bitrate */
860    if (vbr_rate>0)
861    {
862      celt_word16 alpha;
863      celt_int32 delta;
864      /* The target rate in 16th bits per frame */
865      celt_int32 target=vbr_rate;
866    
867      /* Shortblocks get a large boost in bitrate, but since they 
868         are uncommon long blocks are not greatly effected */
869      if (shortBlocks)
870        target*=2;
871      else if (M > 1)
872        target-=(target+14)/28;
873
874      /* The average energy is removed from the target and the actual 
875         energy added*/
876      target=target+st->vbr_offset-588+ec_enc_tell(enc, BITRES);
877
878      /* In VBR mode the frame size must not be reduced so much that it would result in the coarse energy busting its budget */
879      target=IMIN(nbAvailableBytes,target);
880      /* Make the adaptation coef (alpha) higher at the beginning */
881      if (st->vbr_count < 990)
882      {
883         st->vbr_count++;
884         alpha = celt_rcp(SHL32(EXTEND32(st->vbr_count+10),16));
885         /*printf ("%d %d\n", st->vbr_count+10, alpha);*/
886      } else
887         alpha = QCONST16(.001f,15);
888
889      /* By how much did we "miss" the target on that frame */
890      delta = (8<<BITRES)*(celt_int32)target - vbr_rate;
891      /* How many bits have we used in excess of what we're allowed */
892      st->vbr_reservoir += delta;
893      /*printf ("%d\n", st->vbr_reservoir);*/
894
895      /* Compute the offset we need to apply in order to reach the target */
896      st->vbr_drift += MULT16_32_Q15(alpha,delta-st->vbr_offset-st->vbr_drift);
897      st->vbr_offset = -st->vbr_drift;
898      /*printf ("%d\n", st->vbr_drift);*/
899
900      /* We could use any multiple of vbr_rate as bound (depending on the delay) */
901      if (st->vbr_reservoir < 0)
902      {
903         /* We're under the min value -- increase rate */
904         int adjust = 1-(st->vbr_reservoir-1)/(8<<BITRES);
905         st->vbr_reservoir += adjust*(8<<BITRES);
906         target += adjust;
907         /*printf ("+%d\n", adjust);*/
908      }
909      if (target < nbAvailableBytes)
910         nbAvailableBytes = target;
911      nbCompressedBytes = nbAvailableBytes + nbFilledBytes;
912
913      /* This moves the raw bits to take into account the new compressed size */
914      ec_byte_shrink(&buf, nbCompressedBytes);
915    }
916
917    tf_encode(st->start, st->end, isTransient, tf_res, nbAvailableBytes, LM, tf_select, enc);
918
919    ALLOC(offsets, st->mode->nbEBands, int);
920    ALLOC(fine_priority, st->mode->nbEBands, int);
921
922    for (i=0;i<st->mode->nbEBands;i++)
923       offsets[i] = 0;
924    bits = nbCompressedBytes*8 - ec_enc_tell(enc, 0) - 1;
925    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
926
927    quant_fine_energy(st->mode, st->start, st->end, bandE, st->oldBandE, error, fine_quant, enc, C);
928
929 #ifdef MEASURE_NORM_MSE
930    float X0[3000];
931    float bandE0[60];
932    for (c=0;c<C;c++)
933       for (i=0;i<N;i++)
934          X0[i+c*N] = X[i+c*N];
935    for (i=0;i<C*st->mode->nbEBands;i++)
936       bandE0[i] = bandE[i];
937 #endif
938
939    /* Residual quantisation */
940    quant_all_bands(1, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, bandE, pulses, shortBlocks, has_fold, tf_res, resynth, nbCompressedBytes*8, enc, LM);
941
942    quant_energy_finalise(st->mode, st->start, st->end, bandE, st->oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
943
944    /* Re-synthesis of the coded audio if required */
945    if (resynth)
946    {
947       log2Amp(st->mode, st->start, st->end, bandE, st->oldBandE, C);
948
949 #ifdef MEASURE_NORM_MSE
950       measure_norm_mse(st->mode, X, X0, bandE, bandE0, M, N, C);
951 #endif
952
953       if (mdct_weight_shift)
954       {
955          mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
956       }
957
958       /* Synthesis */
959       denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
960
961       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
962
963       for (c=0;c<C;c++)
964          for (i=0;i<M*st->mode->eBands[st->start];i++)
965             freq[c*N+i] = 0;
966       for (c=0;c<C;c++)
967          for (i=M*st->mode->eBands[st->end];i<N;i++)
968             freq[c*N+i] = 0;
969
970       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem, C, LM);
971
972       /* De-emphasis and put everything back at the right place 
973          in the synthesis history */
974       if (optional_resynthesis != NULL) {
975          deemphasis(st->out_mem, optional_resynthesis, N, C, st->mode->preemph, st->preemph_memD);
976
977       }
978    }
979
980    /* If there's any room left (can only happen for very high rates),
981       fill it with zeros */
982    while (ec_enc_tell(enc,0) + 8 <= nbCompressedBytes*8)
983       ec_enc_bits(enc, 0, 8);
984    ec_enc_done(enc);
985    
986    RESTORE_STACK;
987    if (ec_enc_get_error(enc))
988       return CELT_CORRUPTED_DATA;
989    else
990       return nbCompressedBytes;
991 }
992
993 #ifdef FIXED_POINT
994 #ifndef DISABLE_FLOAT_API
995 int celt_encode_with_ec_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
996 {
997    int j, ret, C, N, LM, M;
998    VARDECL(celt_int16, in);
999    SAVE_STACK;
1000
1001    if (check_encoder(st) != CELT_OK)
1002       return CELT_INVALID_STATE;
1003
1004    if (check_mode(st->mode) != CELT_OK)
1005       return CELT_INVALID_MODE;
1006
1007    if (pcm==NULL)
1008       return CELT_BAD_ARG;
1009
1010    for (LM=0;LM<4;LM++)
1011       if (st->mode->shortMdctSize<<LM==frame_size)
1012          break;
1013    if (LM>=MAX_CONFIG_SIZES)
1014       return CELT_BAD_ARG;
1015    M=1<<LM;
1016
1017    C = CHANNELS(st->channels);
1018    N = M*st->mode->shortMdctSize;
1019    ALLOC(in, C*N, celt_int16);
1020
1021    for (j=0;j<C*N;j++)
1022      in[j] = FLOAT2INT16(pcm[j]);
1023
1024    if (optional_resynthesis != NULL) {
1025      ret=celt_encode_with_ec(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
1026       for (j=0;j<C*N;j++)
1027          optional_resynthesis[j]=in[j]*(1.f/32768.f);
1028    } else {
1029      ret=celt_encode_with_ec(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
1030    }
1031    RESTORE_STACK;
1032    return ret;
1033
1034 }
1035 #endif /*DISABLE_FLOAT_API*/
1036 #else
1037 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
1038 {
1039    int j, ret, C, N, LM, M;
1040    VARDECL(celt_sig, in);
1041    SAVE_STACK;
1042
1043    if (check_encoder(st) != CELT_OK)
1044       return CELT_INVALID_STATE;
1045
1046    if (check_mode(st->mode) != CELT_OK)
1047       return CELT_INVALID_MODE;
1048
1049    if (pcm==NULL)
1050       return CELT_BAD_ARG;
1051
1052    for (LM=0;LM<4;LM++)
1053       if (st->mode->shortMdctSize<<LM==frame_size)
1054          break;
1055    if (LM>=MAX_CONFIG_SIZES)
1056       return CELT_BAD_ARG;
1057    M=1<<LM;
1058
1059    C=CHANNELS(st->channels);
1060    N=M*st->mode->shortMdctSize;
1061    ALLOC(in, C*N, celt_sig);
1062    for (j=0;j<C*N;j++) {
1063      in[j] = SCALEOUT(pcm[j]);
1064    }
1065
1066    if (optional_resynthesis != NULL) {
1067       ret = celt_encode_with_ec_float(st,in,in,frame_size,compressed,nbCompressedBytes, enc);
1068       for (j=0;j<C*N;j++)
1069          optional_resynthesis[j] = FLOAT2INT16(in[j]);
1070    } else {
1071       ret = celt_encode_with_ec_float(st,in,NULL,frame_size,compressed,nbCompressedBytes, enc);
1072    }
1073    RESTORE_STACK;
1074    return ret;
1075 }
1076 #endif
1077
1078 int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1079 {
1080    return celt_encode_with_ec(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1081 }
1082
1083 #ifndef DISABLE_FLOAT_API
1084 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1085 {
1086    return celt_encode_with_ec_float(st, pcm, NULL, frame_size, compressed, nbCompressedBytes, NULL);
1087 }
1088 #endif /* DISABLE_FLOAT_API */
1089
1090 int celt_encode_resynthesis(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1091 {
1092    return celt_encode_with_ec(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1093 }
1094
1095 #ifndef DISABLE_FLOAT_API
1096 int celt_encode_resynthesis_float(CELTEncoder * restrict st, const float * pcm, float * optional_resynthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes)
1097 {
1098    return celt_encode_with_ec_float(st, pcm, optional_resynthesis, frame_size, compressed, nbCompressedBytes, NULL);
1099 }
1100 #endif /* DISABLE_FLOAT_API */
1101
1102
1103 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
1104 {
1105    va_list ap;
1106    
1107    if (check_encoder(st) != CELT_OK)
1108       return CELT_INVALID_STATE;
1109
1110    va_start(ap, request);
1111    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
1112      goto bad_mode;
1113    switch (request)
1114    {
1115       case CELT_GET_MODE_REQUEST:
1116       {
1117          const CELTMode ** value = va_arg(ap, const CELTMode**);
1118          if (value==0)
1119             goto bad_arg;
1120          *value=st->mode;
1121       }
1122       break;
1123       case CELT_SET_COMPLEXITY_REQUEST:
1124       {
1125          int value = va_arg(ap, celt_int32);
1126          if (value<0 || value>10)
1127             goto bad_arg;
1128       }
1129       break;
1130       case CELT_SET_START_BAND_REQUEST:
1131       {
1132          celt_int32 value = va_arg(ap, celt_int32);
1133          if (value<0 || value>=st->mode->nbEBands)
1134             goto bad_arg;
1135          st->start = value;
1136       }
1137       break;
1138       case CELT_SET_END_BAND_REQUEST:
1139       {
1140          celt_int32 value = va_arg(ap, celt_int32);
1141          if (value<0 || value>=st->mode->nbEBands)
1142             goto bad_arg;
1143          st->end = value;
1144       }
1145       break;
1146       case CELT_SET_PREDICTION_REQUEST:
1147       {
1148          int value = va_arg(ap, celt_int32);
1149          if (value<0 || value>2)
1150             goto bad_arg;
1151          if (value==0)
1152          {
1153             st->force_intra   = 1;
1154          } else if (value==1) {
1155             st->force_intra   = 0;
1156          } else {
1157             st->force_intra   = 0;
1158          }   
1159       }
1160       break;
1161       case CELT_SET_VBR_RATE_REQUEST:
1162       {
1163          celt_int32 value = va_arg(ap, celt_int32);
1164          int frame_rate;
1165          int N = st->mode->shortMdctSize;
1166          if (value<0)
1167             goto bad_arg;
1168          if (value>3072000)
1169             value = 3072000;
1170          frame_rate = ((st->mode->Fs<<3)+(N>>1))/N;
1171          st->vbr_rate_norm = ((value<<(BITRES+3))+(frame_rate>>1))/frame_rate;
1172       }
1173       break;
1174       case CELT_RESET_STATE:
1175       {
1176          const CELTMode *mode = st->mode;
1177          int C = st->channels;
1178
1179          CELT_MEMSET(st->in_mem, 0, st->overlap*C);
1180          CELT_MEMSET(st->out_mem, 0, (MAX_PERIOD+st->overlap)*C);
1181
1182          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
1183
1184          CELT_MEMSET(st->preemph_memE, 0, C);
1185          CELT_MEMSET(st->preemph_memD, 0, C);
1186          st->delayedIntra = 1;
1187
1188          st->fold_decision = 1;
1189          st->tonal_average = QCONST16(1.f,8);
1190          st->gain_prod = 0;
1191          st->vbr_reservoir = 0;
1192          st->vbr_drift = 0;
1193          st->vbr_offset = 0;
1194          st->vbr_count = 0;
1195          st->frame_max = 0;
1196       }
1197       break;
1198       default:
1199          goto bad_request;
1200    }
1201    va_end(ap);
1202    return CELT_OK;
1203 bad_mode:
1204   va_end(ap);
1205   return CELT_INVALID_MODE;
1206 bad_arg:
1207    va_end(ap);
1208    return CELT_BAD_ARG;
1209 bad_request:
1210    va_end(ap);
1211    return CELT_UNIMPLEMENTED;
1212 }
1213
1214 /**********************************************************************/
1215 /*                                                                    */
1216 /*                             DECODER                                */
1217 /*                                                                    */
1218 /**********************************************************************/
1219 #define DECODE_BUFFER_SIZE 2048
1220
1221 #define DECODERVALID   0x4c434454
1222 #define DECODERPARTIAL 0x5444434c
1223 #define DECODERFREED   0x4c004400
1224
1225 /** Decoder state 
1226  @brief Decoder state
1227  */
1228 struct CELTDecoder {
1229    celt_uint32 marker;
1230    const CELTMode *mode;
1231    int overlap;
1232    int channels;
1233
1234    int start, end;
1235
1236    celt_sig preemph_memD[2];
1237
1238    celt_sig *out_mem;
1239    celt_word32 *decode_mem;
1240
1241    celt_word16 *oldBandE;
1242    
1243    celt_word16 *lpc;
1244
1245    int last_pitch_index;
1246    int loss_count;
1247 };
1248
1249 int check_decoder(const CELTDecoder *st) 
1250 {
1251    if (st==NULL)
1252    {
1253       celt_warning("NULL passed a decoder structure");  
1254       return CELT_INVALID_STATE;
1255    }
1256    if (st->marker == DECODERVALID)
1257       return CELT_OK;
1258    if (st->marker == DECODERFREED)
1259       celt_warning("Referencing a decoder that has already been freed");
1260    else
1261       celt_warning("This is not a valid CELT decoder structure");
1262    return CELT_INVALID_STATE;
1263 }
1264
1265 CELTDecoder *celt_decoder_create(const CELTMode *mode, int channels, int *error)
1266 {
1267    int C;
1268    CELTDecoder *st;
1269
1270    if (check_mode(mode) != CELT_OK)
1271    {
1272       if (error)
1273          *error = CELT_INVALID_MODE;
1274       return NULL;
1275    }
1276
1277    if (channels < 0 || channels > 2)
1278    {
1279       celt_warning("Only mono and stereo supported");
1280       if (error)
1281          *error = CELT_BAD_ARG;
1282       return NULL;
1283    }
1284
1285    C = CHANNELS(channels);
1286    st = celt_alloc(sizeof(CELTDecoder));
1287
1288    if (st==NULL)
1289    {
1290       if (error)
1291          *error = CELT_ALLOC_FAIL;
1292       return NULL;
1293    }
1294
1295    st->marker = DECODERPARTIAL;
1296    st->mode = mode;
1297    st->overlap = mode->overlap;
1298    st->channels = channels;
1299
1300    st->start = 0;
1301    st->end = st->mode->effEBands;
1302
1303    st->decode_mem = (celt_sig*)celt_alloc((DECODE_BUFFER_SIZE+st->overlap)*C*sizeof(celt_sig));
1304    st->out_mem = st->decode_mem+DECODE_BUFFER_SIZE-MAX_PERIOD;
1305    
1306    st->oldBandE = (celt_word16*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16));
1307    
1308    st->lpc = (celt_word16*)celt_alloc(C*LPC_ORDER*sizeof(celt_word16));
1309
1310    st->loss_count = 0;
1311
1312    if ((st->decode_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL) &&
1313          (st->lpc!=NULL))
1314    {
1315       if (error)
1316          *error = CELT_OK;
1317       st->marker = DECODERVALID;
1318       return st;
1319    }
1320    /* If the setup fails for some reason deallocate it. */
1321    celt_decoder_destroy(st);
1322    if (error)
1323       *error = CELT_ALLOC_FAIL;
1324    return NULL;
1325 }
1326
1327 void celt_decoder_destroy(CELTDecoder *st)
1328 {
1329    if (st == NULL)
1330    {
1331       celt_warning("NULL passed to celt_decoder_destroy");
1332       return;
1333    }
1334
1335    if (st->marker == DECODERFREED) 
1336    {
1337       celt_warning("Freeing a decoder which has already been freed"); 
1338       return;
1339    }
1340    
1341    if (st->marker != DECODERVALID && st->marker != DECODERPARTIAL)
1342    {
1343       celt_warning("This is not a valid CELT decoder structure");
1344       return;
1345    }
1346    
1347    /*Check_mode is non-fatal here because we can still free
1348      the encoder memory even if the mode is bad, although calling
1349      the free functions in this order is a violation of the API.*/
1350    check_mode(st->mode);
1351    
1352    celt_free(st->decode_mem);
1353    celt_free(st->oldBandE);
1354    celt_free(st->lpc);
1355    
1356    st->marker = DECODERFREED;
1357    
1358    celt_free(st);
1359 }
1360
1361 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16 * restrict pcm, int N, int LM)
1362 {
1363    int c;
1364    int pitch_index;
1365    int overlap = st->mode->overlap;
1366    celt_word16 fade = Q15ONE;
1367    int i, len;
1368    const int C = CHANNELS(st->channels);
1369    int offset;
1370    SAVE_STACK;
1371    
1372    len = N+st->mode->overlap;
1373    
1374    if (st->loss_count == 0)
1375    {
1376       celt_word16 pitch_buf[MAX_PERIOD>>1];
1377       celt_word32 tmp=0;
1378       celt_word32 mem0[2]={0,0};
1379       celt_word16 mem1[2]={0,0};
1380       int len2 = len;
1381       /* FIXME: This is a kludge */
1382       if (len2>MAX_PERIOD>>1)
1383          len2 = MAX_PERIOD>>1;
1384       pitch_downsample(st->out_mem, pitch_buf, MAX_PERIOD, MAX_PERIOD,
1385                        C, mem0, mem1);
1386       pitch_search(st->mode, pitch_buf+((MAX_PERIOD-len2)>>1), pitch_buf, len2,
1387                    MAX_PERIOD-len2-100, &pitch_index, &tmp, 1<<LM);
1388       pitch_index = MAX_PERIOD-len2-pitch_index;
1389       st->last_pitch_index = pitch_index;
1390    } else {
1391       pitch_index = st->last_pitch_index;
1392       if (st->loss_count < 5)
1393          fade = QCONST16(.8f,15);
1394       else
1395          fade = 0;
1396    }
1397
1398    for (c=0;c<C;c++)
1399    {
1400       /* FIXME: This is more memory than necessary */
1401       celt_word32 e[2*MAX_PERIOD];
1402       celt_word16 exc[2*MAX_PERIOD];
1403       celt_word32 ac[LPC_ORDER+1];
1404       celt_word16 decay = 1;
1405       celt_word32 S1=0;
1406       celt_word16 mem[LPC_ORDER]={0};
1407
1408       offset = MAX_PERIOD-pitch_index;
1409       for (i=0;i<MAX_PERIOD;i++)
1410          exc[i] = ROUND16(st->out_mem[i*C+c], SIG_SHIFT);
1411
1412       if (st->loss_count == 0)
1413       {
1414          _celt_autocorr(exc, ac, st->mode->window, st->mode->overlap,
1415                         LPC_ORDER, MAX_PERIOD);
1416
1417          /* Noise floor -40 dB */
1418 #ifdef FIXED_POINT
1419          ac[0] += SHR32(ac[0],13);
1420 #else
1421          ac[0] *= 1.0001f;
1422 #endif
1423          /* Lag windowing */
1424          for (i=1;i<=LPC_ORDER;i++)
1425          {
1426             /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
1427 #ifdef FIXED_POINT
1428             ac[i] -= MULT16_32_Q15(2*i*i, ac[i]);
1429 #else
1430             ac[i] -= ac[i]*(.008f*i)*(.008f*i);
1431 #endif
1432          }
1433
1434          _celt_lpc(st->lpc+c*LPC_ORDER, ac, LPC_ORDER);
1435       }
1436       fir(exc, st->lpc+c*LPC_ORDER, exc, MAX_PERIOD, LPC_ORDER, mem);
1437       /*for (i=0;i<MAX_PERIOD;i++)printf("%d ", exc[i]); printf("\n");*/
1438       /* Check if the waveform is decaying (and if so how fast) */
1439       {
1440          celt_word32 E1=1, E2=1;
1441          int period;
1442          if (pitch_index <= MAX_PERIOD/2)
1443             period = pitch_index;
1444          else
1445             period = MAX_PERIOD/2;
1446          for (i=0;i<period;i++)
1447          {
1448             E1 += SHR32(MULT16_16(exc[MAX_PERIOD-period+i],exc[MAX_PERIOD-period+i]),8);
1449             E2 += SHR32(MULT16_16(exc[MAX_PERIOD-2*period+i],exc[MAX_PERIOD-2*period+i]),8);
1450          }
1451          if (E1 > E2)
1452             E1 = E2;
1453          decay = celt_sqrt(frac_div32(SHR(E1,1),E2));
1454       }
1455
1456       /* Copy excitation, taking decay into account */
1457       for (i=0;i<len+st->mode->overlap;i++)
1458       {
1459          if (offset+i >= MAX_PERIOD)
1460          {
1461             offset -= pitch_index;
1462             decay = MULT16_16_Q15(decay, decay);
1463          }
1464          e[i] = SHL32(EXTEND32(MULT16_16_Q15(decay, exc[offset+i])), SIG_SHIFT);
1465          S1 += SHR32(MULT16_16(st->out_mem[offset+i],st->out_mem[offset+i]),8);
1466       }
1467
1468       iir(e, st->lpc+c*LPC_ORDER, e, len+st->mode->overlap, LPC_ORDER, mem);
1469
1470       {
1471          celt_word32 S2=0;
1472          for (i=0;i<len+overlap;i++)
1473             S2 += SHR32(MULT16_16(e[i],e[i]),8);
1474          /* This checks for an "explosion" in the synthesis */
1475 #ifdef FIXED_POINT
1476          if (!(S1 > SHR32(S2,2)))
1477 #else
1478          /* Float test is written this way to catch NaNs at the same time */
1479          if (!(S1 > 0.2f*S2))
1480 #endif
1481          {
1482             for (i=0;i<len+overlap;i++)
1483                e[i] = 0;
1484          } else if (S1 < S2)
1485          {
1486             celt_word16 ratio = celt_sqrt(frac_div32(SHR32(S1,1)+1,S2+1));
1487             for (i=0;i<len+overlap;i++)
1488                e[i] = MULT16_16_Q15(ratio, e[i]);
1489          }
1490       }
1491
1492       for (i=0;i<MAX_PERIOD+st->mode->overlap-N;i++)
1493          st->out_mem[C*i+c] = st->out_mem[C*(N+i)+c];
1494
1495       /* Apply TDAC to the concealed audio so that it blends with the
1496          previous and next frames */
1497       for (i=0;i<overlap/2;i++)
1498       {
1499          celt_word32 tmp1, tmp2;
1500          tmp1 = MULT16_32_Q15(st->mode->window[i          ], e[i          ]) -
1501                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[overlap-i-1]);
1502          tmp2 = MULT16_32_Q15(st->mode->window[i],           e[N+overlap-1-i]) +
1503                 MULT16_32_Q15(st->mode->window[overlap-i-1], e[N+i          ]);
1504          tmp1 = MULT16_32_Q15(fade, tmp1);
1505          tmp2 = MULT16_32_Q15(fade, tmp2);
1506          st->out_mem[C*(MAX_PERIOD+i)+c] = MULT16_32_Q15(st->mode->window[overlap-i-1], tmp2);
1507          st->out_mem[C*(MAX_PERIOD+overlap-i-1)+c] = MULT16_32_Q15(st->mode->window[i], tmp2);
1508          st->out_mem[C*(MAX_PERIOD-N+i)+c] += MULT16_32_Q15(st->mode->window[i], tmp1);
1509          st->out_mem[C*(MAX_PERIOD-N+overlap-i-1)+c] -= MULT16_32_Q15(st->mode->window[overlap-i-1], tmp1);
1510       }
1511       for (i=0;i<N-overlap;i++)
1512          st->out_mem[C*(MAX_PERIOD-N+overlap+i)+c] = MULT16_32_Q15(fade, e[overlap+i]);
1513    }
1514
1515    deemphasis(st->out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1516    
1517    st->loss_count++;
1518
1519    RESTORE_STACK;
1520 }
1521
1522 #ifdef FIXED_POINT
1523 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1524 {
1525 #else
1526 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig * restrict pcm, int frame_size, ec_dec *dec)
1527 {
1528 #endif
1529    int c, i, N, N4;
1530    int has_fold;
1531    int bits;
1532    ec_dec _dec;
1533    ec_byte_buffer buf;
1534    VARDECL(celt_sig, freq);
1535    VARDECL(celt_norm, X);
1536    VARDECL(celt_ener, bandE);
1537    VARDECL(int, fine_quant);
1538    VARDECL(int, pulses);
1539    VARDECL(int, offsets);
1540    VARDECL(int, fine_priority);
1541    VARDECL(int, tf_res);
1542
1543    int shortBlocks;
1544    int isTransient;
1545    int intra_ener;
1546    int transient_time;
1547    int transient_shift;
1548    int mdct_weight_shift=0;
1549    const int C = CHANNELS(st->channels);
1550    int mdct_weight_pos=0;
1551    int LM, M;
1552    int nbFilledBytes, nbAvailableBytes;
1553    int effEnd;
1554    SAVE_STACK;
1555
1556    if (check_decoder(st) != CELT_OK)
1557       return CELT_INVALID_STATE;
1558
1559    if (check_mode(st->mode) != CELT_OK)
1560       return CELT_INVALID_MODE;
1561
1562    if (pcm==NULL)
1563       return CELT_BAD_ARG;
1564
1565    for (LM=0;LM<4;LM++)
1566       if (st->mode->shortMdctSize<<LM==frame_size)
1567          break;
1568    if (LM>=MAX_CONFIG_SIZES)
1569       return CELT_BAD_ARG;
1570    M=1<<LM;
1571
1572    N = M*st->mode->shortMdctSize;
1573    N4 = (N-st->overlap)>>1;
1574
1575    effEnd = st->end;
1576    if (effEnd > st->mode->effEBands)
1577       effEnd = st->mode->effEBands;
1578
1579    ALLOC(freq, C*N, celt_sig); /**< Interleaved signal MDCTs */
1580    ALLOC(X, C*N, celt_norm);   /**< Interleaved normalised MDCTs */
1581    ALLOC(bandE, st->mode->nbEBands*C, celt_ener);
1582    for (c=0;c<C;c++)
1583       for (i=0;i<M*st->mode->eBands[st->start];i++)
1584          X[c*N+i] = 0;
1585    for (c=0;c<C;c++)
1586       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1587          X[c*N+i] = 0;
1588
1589    if (data == NULL)
1590    {
1591       celt_decode_lost(st, pcm, N, LM);
1592       RESTORE_STACK;
1593       return CELT_OK;
1594    }
1595    if (len<0) {
1596      RESTORE_STACK;
1597      return CELT_BAD_ARG;
1598    }
1599    
1600    if (dec == NULL)
1601    {
1602       ec_byte_readinit(&buf,(unsigned char*)data,len);
1603       ec_dec_init(&_dec,&buf);
1604       dec = &_dec;
1605       nbFilledBytes = 0;
1606    } else {
1607       nbFilledBytes = (ec_dec_tell(dec, 0)+4)>>3;
1608    }
1609    nbAvailableBytes = len-nbFilledBytes;
1610
1611    /* Decode the global flags (first symbols in the stream) */
1612    intra_ener = ec_dec_bit_prob(dec, 8192);
1613    isTransient = ec_dec_bit_prob(dec, 8192);
1614    has_fold = ec_dec_bit_prob(dec, 8192)<<1;
1615    has_fold |= ec_dec_bit_prob(dec, (has_fold>>1) ? 32768 : 49152);
1616
1617    if (isTransient)
1618       shortBlocks = M;
1619    else
1620       shortBlocks = 0;
1621
1622    if (isTransient)
1623    {
1624       transient_shift = ec_dec_uint(dec, 4);
1625       if (transient_shift == 3)
1626       {
1627          int transient_time_quant;
1628          int max_time = (N+st->mode->overlap)*(celt_int32)8000/st->mode->Fs;
1629          transient_time_quant = ec_dec_uint(dec, max_time);
1630          transient_time = transient_time_quant*(celt_int32)st->mode->Fs/8000;
1631       } else {
1632          mdct_weight_shift = transient_shift;
1633          if (mdct_weight_shift && M>2)
1634             mdct_weight_pos = ec_dec_uint(dec, M-1);
1635          transient_shift = 0;
1636          transient_time = 0;
1637       }
1638    } else {
1639       transient_time = -1;
1640       transient_shift = 0;
1641    }
1642    
1643    ALLOC(fine_quant, st->mode->nbEBands, int);
1644    /* Get band energies */
1645    unquant_coarse_energy(st->mode, st->start, st->end, bandE, st->oldBandE, intra_ener, st->mode->prob, dec, C, LM);
1646
1647    ALLOC(tf_res, st->mode->nbEBands, int);
1648    tf_decode(st->start, st->end, C, isTransient, tf_res, nbAvailableBytes, LM, dec);
1649
1650    ALLOC(pulses, st->mode->nbEBands, int);
1651    ALLOC(offsets, st->mode->nbEBands, int);
1652    ALLOC(fine_priority, st->mode->nbEBands, int);
1653
1654    for (i=0;i<st->mode->nbEBands;i++)
1655       offsets[i] = 0;
1656
1657    bits = len*8 - ec_dec_tell(dec, 0) - 1;
1658    compute_allocation(st->mode, st->start, st->end, offsets, bits, pulses, fine_quant, fine_priority, C, M);
1659    /*bits = ec_dec_tell(dec, 0);
1660    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(dec, 0)-bits))/C);*/
1661    
1662    unquant_fine_energy(st->mode, st->start, st->end, bandE, st->oldBandE, fine_quant, dec, C);
1663
1664    /* Decode fixed codebook */
1665    quant_all_bands(0, st->mode, st->start, st->end, X, C==2 ? X+N : NULL, NULL, pulses, shortBlocks, has_fold, tf_res, 1, len*8, dec, LM);
1666
1667    unquant_energy_finalise(st->mode, st->start, st->end, bandE, st->oldBandE, fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
1668
1669    log2Amp(st->mode, st->start, st->end, bandE, st->oldBandE, C);
1670
1671    if (mdct_weight_shift)
1672    {
1673       mdct_shape(st->mode, X, 0, mdct_weight_pos+1, N, mdct_weight_shift, effEnd, C, 1, M);
1674    }
1675
1676    /* Synthesis */
1677    denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
1678
1679
1680    CELT_MOVE(st->decode_mem, st->decode_mem+C*N, C*(DECODE_BUFFER_SIZE+st->overlap-N));
1681
1682    for (c=0;c<C;c++)
1683       for (i=0;i<M*st->mode->eBands[st->start];i++)
1684          freq[c*N+i] = 0;
1685    for (c=0;c<C;c++)
1686       for (i=M*st->mode->eBands[effEnd];i<N;i++)
1687          freq[c*N+i] = 0;
1688
1689    /* Compute inverse MDCTs */
1690    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem, C, LM);
1691
1692    deemphasis(st->out_mem, pcm, N, C, st->mode->preemph, st->preemph_memD);
1693    st->loss_count = 0;
1694    RESTORE_STACK;
1695    if (ec_dec_get_error(dec))
1696       return CELT_CORRUPTED_DATA;
1697    else
1698       return CELT_OK;
1699 }
1700
1701 #ifdef FIXED_POINT
1702 #ifndef DISABLE_FLOAT_API
1703 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size, ec_dec *dec)
1704 {
1705    int j, ret, C, N, LM, M;
1706    VARDECL(celt_int16, out);
1707    SAVE_STACK;
1708
1709    if (check_decoder(st) != CELT_OK)
1710       return CELT_INVALID_STATE;
1711
1712    if (check_mode(st->mode) != CELT_OK)
1713       return CELT_INVALID_MODE;
1714
1715    if (pcm==NULL)
1716       return CELT_BAD_ARG;
1717
1718    for (LM=0;LM<4;LM++)
1719       if (st->mode->shortMdctSize<<LM==frame_size)
1720          break;
1721    if (LM>=MAX_CONFIG_SIZES)
1722       return CELT_BAD_ARG;
1723    M=1<<LM;
1724
1725    C = CHANNELS(st->channels);
1726    N = M*st->mode->shortMdctSize;
1727    
1728    ALLOC(out, C*N, celt_int16);
1729    ret=celt_decode_with_ec(st, data, len, out, frame_size, dec);
1730    if (ret==0)
1731       for (j=0;j<C*N;j++)
1732          pcm[j]=out[j]*(1.f/32768.f);
1733      
1734    RESTORE_STACK;
1735    return ret;
1736 }
1737 #endif /*DISABLE_FLOAT_API*/
1738 #else
1739 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
1740 {
1741    int j, ret, C, N, LM, M;
1742    VARDECL(celt_sig, out);
1743    SAVE_STACK;
1744
1745    if (check_decoder(st) != CELT_OK)
1746       return CELT_INVALID_STATE;
1747
1748    if (check_mode(st->mode) != CELT_OK)
1749       return CELT_INVALID_MODE;
1750
1751    if (pcm==NULL)
1752       return CELT_BAD_ARG;
1753
1754    for (LM=0;LM<4;LM++)
1755       if (st->mode->shortMdctSize<<LM==frame_size)
1756          break;
1757    if (LM>=MAX_CONFIG_SIZES)
1758       return CELT_BAD_ARG;
1759    M=1<<LM;
1760
1761    C = CHANNELS(st->channels);
1762    N = M*st->mode->shortMdctSize;
1763    ALLOC(out, C*N, celt_sig);
1764
1765    ret=celt_decode_with_ec_float(st, data, len, out, frame_size, dec);
1766
1767    if (ret==0)
1768       for (j=0;j<C*N;j++)
1769          pcm[j] = FLOAT2INT16 (out[j]);
1770    
1771    RESTORE_STACK;
1772    return ret;
1773 }
1774 #endif
1775
1776 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size)
1777 {
1778    return celt_decode_with_ec(st, data, len, pcm, frame_size, NULL);
1779 }
1780
1781 #ifndef DISABLE_FLOAT_API
1782 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size)
1783 {
1784    return celt_decode_with_ec_float(st, data, len, pcm, frame_size, NULL);
1785 }
1786 #endif /* DISABLE_FLOAT_API */
1787
1788 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
1789 {
1790    va_list ap;
1791
1792    if (check_decoder(st) != CELT_OK)
1793       return CELT_INVALID_STATE;
1794
1795    va_start(ap, request);
1796    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
1797      goto bad_mode;
1798    switch (request)
1799    {
1800       case CELT_GET_MODE_REQUEST:
1801       {
1802          const CELTMode ** value = va_arg(ap, const CELTMode**);
1803          if (value==0)
1804             goto bad_arg;
1805          *value=st->mode;
1806       }
1807       break;
1808       case CELT_SET_START_BAND_REQUEST:
1809       {
1810          celt_int32 value = va_arg(ap, celt_int32);
1811          if (value<0 || value>=st->mode->nbEBands)
1812             goto bad_arg;
1813          st->start = value;
1814       }
1815       break;
1816       case CELT_SET_END_BAND_REQUEST:
1817       {
1818          celt_int32 value = va_arg(ap, celt_int32);
1819          if (value<0 || value>=st->mode->nbEBands)
1820             goto bad_arg;
1821          st->end = value;
1822       }
1823       break;
1824       case CELT_RESET_STATE:
1825       {
1826          const CELTMode *mode = st->mode;
1827          int C = st->channels;
1828
1829          CELT_MEMSET(st->decode_mem, 0, (DECODE_BUFFER_SIZE+st->overlap)*C);
1830          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
1831
1832          CELT_MEMSET(st->preemph_memD, 0, C);
1833
1834          st->loss_count = 0;
1835
1836          CELT_MEMSET(st->lpc, 0, C*LPC_ORDER);
1837       }
1838       break;
1839       default:
1840          goto bad_request;
1841    }
1842    va_end(ap);
1843    return CELT_OK;
1844 bad_mode:
1845   va_end(ap);
1846   return CELT_INVALID_MODE;
1847 bad_arg:
1848    va_end(ap);
1849    return CELT_BAD_ARG;
1850 bad_request:
1851       va_end(ap);
1852   return CELT_UNIMPLEMENTED;
1853 }
1854
1855 const char *celt_strerror(int error)
1856 {
1857    static const char *error_strings[8] = {
1858       "success",
1859       "invalid argument",
1860       "invalid mode",
1861       "internal error",
1862       "corrupted stream",
1863       "request not implemented",
1864       "invalid state",
1865       "memory allocation failed"
1866    };
1867    if (error > 0 || error < -7)
1868       return "unknown error";
1869    else 
1870       return error_strings[-error];
1871 }
1872