Demoved the divisions from the inner pitch prediction loops, bumped the version
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2    (C) 2008 Gregory Maxwell */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_bands.h"
48 #include "psy.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54
55 static const celt_word16_t preemph = QCONST16(0.8f,15);
56
57 #ifdef FIXED_POINT
58 static const celt_word16_t transientWindow[16] = {
59      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
60    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
61 #else
62 static const float transientWindow[16] = {
63    0.0085135, 0.0337639, 0.0748914, 0.1304955, 
64    0.1986827, 0.2771308, 0.3631685, 0.4538658,
65    0.5461342, 0.6368315, 0.7228692, 0.8013173, 
66    0.8695045, 0.9251086, 0.9662361, 0.9914865};
67 #endif
68
69 #define ENCODERVALID   0x4c434554
70 #define ENCODERPARTIAL 0x5445434c
71 #define ENCODERFREED   0x4c004500
72    
73 /** Encoder state 
74  @brief Encoder state
75  */
76 struct CELTEncoder {
77    celt_uint32_t marker;
78    const CELTMode *mode;     /**< Mode used by the encoder */
79    int frame_size;
80    int block_size;
81    int overlap;
82    int channels;
83    
84    int pitch_enabled;       /* Complexity level is allowed to use pitch */
85    int pitch_permitted;     /*  Use of the LTP is permitted by the user */
86    int pitch_available;     /*  Amount of pitch buffer available */
87    int force_intra;
88    int delayedIntra;
89    celt_word16_t tonal_average;
90    int fold_decision;
91
92    int VBR_rate; /* Target number of 16th bits per frame */
93    celt_word16_t * restrict preemph_memE; 
94    celt_sig_t    * restrict preemph_memD;
95
96    celt_sig_t *in_mem;
97    celt_sig_t *out_mem;
98
99    celt_word16_t *oldBandE;
100 #ifdef EXP_PSY
101    celt_word16_t *psy_mem;
102    struct PsyDecay psy;
103 #endif
104 };
105
106 int check_encoder(const CELTEncoder *st) 
107 {
108    if (st==NULL)
109    {
110       celt_warning("NULL passed as an encoder structure");  
111       return CELT_INVALID_STATE;
112    }
113    if (st->marker == ENCODERVALID)
114       return CELT_OK;
115    if (st->marker == ENCODERFREED)
116       celt_warning("Referencing an encoder that has already been freed");
117    else
118       celt_warning("This is not a valid CELT encoder structure");
119    return CELT_INVALID_STATE;
120 }
121
122 CELTEncoder *celt_encoder_create(const CELTMode *mode)
123 {
124    int N, C;
125    CELTEncoder *st;
126
127    if (check_mode(mode) != CELT_OK)
128       return NULL;
129
130    N = mode->mdctSize;
131    C = mode->nbChannels;
132    st = celt_alloc(sizeof(CELTEncoder));
133    
134    if (st==NULL) 
135       return NULL;   
136    st->marker = ENCODERPARTIAL;
137    st->mode = mode;
138    st->frame_size = N;
139    st->block_size = N;
140    st->overlap = mode->overlap;
141
142    st->VBR_rate = 0;
143    st->pitch_enabled = 1;
144    st->pitch_permitted = 1;
145    st->pitch_available = 1;
146    st->force_intra  = 0;
147    st->delayedIntra = 1;
148    st->tonal_average = QCONST16(1.,8);
149    st->fold_decision = 1;
150
151    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
152    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
153
154    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
155
156    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
157    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
158
159 #ifdef EXP_PSY
160    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
161    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
162 #endif
163
164    if ((st->in_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL) 
165 #ifdef EXP_PSY
166        && (st->psy_mem!=NULL) 
167 #endif   
168        && (st->preemph_memE!=NULL) && (st->preemph_memD!=NULL))
169    {
170       st->marker   = ENCODERVALID;
171       return st;
172    }
173    /* If the setup fails for some reason deallocate it. */
174    celt_encoder_destroy(st);  
175    return NULL;
176 }
177
178 void celt_encoder_destroy(CELTEncoder *st)
179 {
180    if (st == NULL)
181    {
182       celt_warning("NULL passed to celt_encoder_destroy");
183       return;
184    }
185
186    if (st->marker == ENCODERFREED)
187    {
188       celt_warning("Freeing an encoder which has already been freed"); 
189       return;
190    }
191
192    if (st->marker != ENCODERVALID && st->marker != ENCODERPARTIAL)
193    {
194       celt_warning("This is not a valid CELT encoder structure");
195       return;
196    }
197    /*Check_mode is non-fatal here because we can still free
198     the encoder memory even if the mode is bad, although calling
199     the free functions in this order is a violation of the API.*/
200    check_mode(st->mode);
201    
202    celt_free(st->in_mem);
203    celt_free(st->out_mem);
204    
205    celt_free(st->oldBandE);
206    
207    celt_free(st->preemph_memE);
208    celt_free(st->preemph_memD);
209    
210 #ifdef EXP_PSY
211    celt_free (st->psy_mem);
212    psydecay_clear(&st->psy);
213 #endif
214    st->marker = ENCODERFREED;
215    
216    celt_free(st);
217 }
218
219 static inline celt_int16_t FLOAT2INT16(float x)
220 {
221    x = x*CELT_SIG_SCALE;
222    x = MAX32(x, -32768);
223    x = MIN32(x, 32767);
224    return (celt_int16_t)float2int(x);
225 }
226
227 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
228 {
229 #ifdef FIXED_POINT
230    x = PSHR32(x, SIG_SHIFT);
231    x = MAX32(x, -32768);
232    x = MIN32(x, 32767);
233    return EXTRACT16(x);
234 #else
235    return (celt_word16_t)x;
236 #endif
237 }
238
239 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
240 {
241    int c, i, n;
242    celt_word32_t ratio;
243    VARDECL(celt_word32_t, begin);
244    SAVE_STACK;
245    ALLOC(begin, len, celt_word32_t);
246    for (i=0;i<len;i++)
247       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
248    for (c=1;c<C;c++)
249    {
250       for (i=0;i<len;i++)
251          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
252    }
253    for (i=1;i<len;i++)
254       begin[i] = MAX32(begin[i-1],begin[i]);
255    n = -1;
256    for (i=8;i<len-8;i++)
257    {
258       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
259          n=i;
260    }
261    if (n<32)
262    {
263       n = -1;
264       ratio = 0;
265    } else {
266       ratio = DIV32(begin[len-1],1+begin[n-16]);
267    }
268    if (ratio < 0)
269       ratio = 0;
270    if (ratio > 1000)
271       ratio = 1000;
272    ratio *= ratio;
273    
274    if (ratio > 2048)
275       *transient_shift = 3;
276    else
277       *transient_shift = 0;
278    
279    *transient_time = n;
280    
281    RESTORE_STACK;
282    return ratio > 20;
283 }
284
285 /** Apply window and compute the MDCT for all sub-frames and 
286     all channels in a frame */
287 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
288 {
289    const int C = CHANNELS(mode);
290    if (C==1 && !shortBlocks)
291    {
292       const mdct_lookup *lookup = MDCT(mode);
293       const int overlap = OVERLAP(mode);
294       mdct_forward(lookup, in, out, mode->window, overlap);
295    } else if (!shortBlocks) {
296       const mdct_lookup *lookup = MDCT(mode);
297       const int overlap = OVERLAP(mode);
298       const int N = FRAMESIZE(mode);
299       int c;
300       VARDECL(celt_word32_t, x);
301       VARDECL(celt_word32_t, tmp);
302       SAVE_STACK;
303       ALLOC(x, N+overlap, celt_word32_t);
304       ALLOC(tmp, N, celt_word32_t);
305       for (c=0;c<C;c++)
306       {
307          int j;
308          for (j=0;j<N+overlap;j++)
309             x[j] = in[C*j+c];
310          mdct_forward(lookup, x, tmp, mode->window, overlap);
311          /* Interleaving the sub-frames */
312          for (j=0;j<N;j++)
313             out[j+c*N] = tmp[j];
314       }
315       RESTORE_STACK;
316    } else {
317       const mdct_lookup *lookup = &mode->shortMdct;
318       const int overlap = mode->overlap;
319       const int N = mode->shortMdctSize;
320       int b, c;
321       VARDECL(celt_word32_t, x);
322       VARDECL(celt_word32_t, tmp);
323       SAVE_STACK;
324       ALLOC(x, N+overlap, celt_word32_t);
325       ALLOC(tmp, N, celt_word32_t);
326       for (c=0;c<C;c++)
327       {
328          int B = mode->nbShortMdcts;
329          for (b=0;b<B;b++)
330          {
331             int j;
332             for (j=0;j<N+overlap;j++)
333                x[j] = in[C*(b*N+j)+c];
334             mdct_forward(lookup, x, tmp, mode->window, overlap);
335             /* Interleaving the sub-frames */
336             for (j=0;j<N;j++)
337                out[(j*B+b)+c*N*B] = tmp[j];
338          }
339       }
340       RESTORE_STACK;
341    }
342 }
343
344 /** Compute the IMDCT and apply window for all sub-frames and 
345     all channels in a frame */
346 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
347 {
348    int c, N4;
349    const int C = CHANNELS(mode);
350    const int N = FRAMESIZE(mode);
351    const int overlap = OVERLAP(mode);
352    N4 = (N-overlap)>>1;
353    for (c=0;c<C;c++)
354    {
355       int j;
356       if (transient_shift==0 && C==1 && !shortBlocks) {
357          const mdct_lookup *lookup = MDCT(mode);
358          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
359       } else if (!shortBlocks) {
360          const mdct_lookup *lookup = MDCT(mode);
361          VARDECL(celt_word32_t, x);
362          VARDECL(celt_word32_t, tmp);
363          SAVE_STACK;
364          ALLOC(x, 2*N, celt_word32_t);
365          ALLOC(tmp, N, celt_word32_t);
366          /* De-interleaving the sub-frames */
367          for (j=0;j<N;j++)
368             tmp[j] = X[j+c*N];
369          /* Prevents problems from the imdct doing the overlap-add */
370          CELT_MEMSET(x+N4, 0, N);
371          mdct_backward(lookup, tmp, x, mode->window, overlap);
372          celt_assert(transient_shift == 0);
373          /* The first and last part would need to be set to zero if we actually
374             wanted to use them. */
375          for (j=0;j<overlap;j++)
376             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
377          for (j=0;j<overlap;j++)
378             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
379          for (j=0;j<2*N4;j++)
380             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
381          RESTORE_STACK;
382       } else {
383          int b;
384          const int N2 = mode->shortMdctSize;
385          const int B = mode->nbShortMdcts;
386          const mdct_lookup *lookup = &mode->shortMdct;
387          VARDECL(celt_word32_t, x);
388          VARDECL(celt_word32_t, tmp);
389          SAVE_STACK;
390          ALLOC(x, 2*N, celt_word32_t);
391          ALLOC(tmp, N, celt_word32_t);
392          /* Prevents problems from the imdct doing the overlap-add */
393          CELT_MEMSET(x+N4, 0, N2);
394          for (b=0;b<B;b++)
395          {
396             /* De-interleaving the sub-frames */
397             for (j=0;j<N2;j++)
398                tmp[j] = X[(j*B+b)+c*N2*B];
399             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
400          }
401          if (transient_shift > 0)
402          {
403 #ifdef FIXED_POINT
404             for (j=0;j<16;j++)
405                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
406             for (j=transient_time;j<N+overlap;j++)
407                x[N4+j] = SHL32(x[N4+j], transient_shift);
408 #else
409             for (j=0;j<16;j++)
410                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
411             for (j=transient_time;j<N+overlap;j++)
412                x[N4+j] *= 1<<transient_shift;
413 #endif
414          }
415          /* The first and last part would need to be set to zero 
416             if we actually wanted to use them. */
417          for (j=0;j<overlap;j++)
418             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
419          for (j=0;j<overlap;j++)
420             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
421          for (j=0;j<2*N4;j++)
422             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
423          RESTORE_STACK;
424       }
425    }
426 }
427
428 #define FLAG_NONE        0
429 #define FLAG_INTRA       (1U<<13)
430 #define FLAG_PITCH       (1U<<12)
431 #define FLAG_SHORT       (1U<<11)
432 #define FLAG_FOLD        (1U<<10)
433 #define FLAG_MASK        (FLAG_INTRA|FLAG_PITCH|FLAG_SHORT|FLAG_FOLD)
434
435 int flaglist[8] = {
436       0 /*00  */ | FLAG_FOLD,
437       1 /*01  */ | FLAG_PITCH|FLAG_FOLD,
438       8 /*1000*/ | FLAG_NONE,
439       9 /*1001*/ | FLAG_SHORT|FLAG_FOLD,
440      10 /*1010*/ | FLAG_PITCH,
441      11 /*1011*/ | FLAG_INTRA,
442       6 /*110 */ | FLAG_INTRA|FLAG_FOLD,
443       7 /*111 */ | FLAG_INTRA|FLAG_SHORT|FLAG_FOLD
444 };
445
446 void encode_flags(ec_enc *enc, int intra_ener, int has_pitch, int shortBlocks, int has_fold)
447 {
448    int i;
449    int flags=FLAG_NONE;
450    int flag_bits;
451    flags |= intra_ener   ? FLAG_INTRA : 0;
452    flags |= has_pitch    ? FLAG_PITCH : 0;
453    flags |= shortBlocks  ? FLAG_SHORT : 0;
454    flags |= has_fold     ? FLAG_FOLD  : 0;
455    for (i=0;i<8;i++)
456       if (flags == (flaglist[i]&FLAG_MASK))
457          break;
458    celt_assert(i<8);
459    flag_bits = flaglist[i]&0xf;
460    /*printf ("enc %d: %d %d %d %d\n", flag_bits, intra_ener, has_pitch, shortBlocks, has_fold);*/
461    if (i<2)
462       ec_enc_uint(enc, flag_bits, 4);
463    else if (i<6)
464       ec_enc_uint(enc, flag_bits, 16);
465    else
466       ec_enc_uint(enc, flag_bits, 8);
467 }
468
469 void decode_flags(ec_dec *dec, int *intra_ener, int *has_pitch, int *shortBlocks, int *has_fold)
470 {
471    int i;
472    int flag_bits;
473    flag_bits = ec_dec_uint(dec, 4);
474    /*printf ("(%d) ", flag_bits);*/
475    if (flag_bits==2)
476       flag_bits = (flag_bits<<2) | ec_dec_uint(dec, 4);
477    else if (flag_bits==3)
478       flag_bits = (flag_bits<<1) | ec_dec_uint(dec, 2);
479    for (i=0;i<8;i++)
480       if (flag_bits == (flaglist[i]&0xf))
481          break;
482    celt_assert(i<8);
483    *intra_ener  = (flaglist[i]&FLAG_INTRA) != 0;
484    *has_pitch   = (flaglist[i]&FLAG_PITCH) != 0;
485    *shortBlocks = (flaglist[i]&FLAG_SHORT) != 0;
486    *has_fold    = (flaglist[i]&FLAG_FOLD ) != 0;
487    /*printf ("dec %d: %d %d %d %d\n", flag_bits, *intra_ener, *has_pitch, *shortBlocks, *has_fold);*/
488 }
489
490 #ifdef FIXED_POINT
491 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
492 {
493 #else
494 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
495 {
496 #endif
497    int i, c, N, N4;
498    int has_pitch;
499    int pitch_index;
500    int bits;
501    int has_fold=1;
502    unsigned coarse_needed;
503    ec_byte_buffer buf;
504    ec_enc         enc;
505    VARDECL(celt_sig_t, in);
506    VARDECL(celt_sig_t, freq);
507    VARDECL(celt_sig_t, pitch_freq);
508    VARDECL(celt_norm_t, X);
509    VARDECL(celt_ener_t, bandE);
510    VARDECL(celt_word16_t, bandLogE);
511    VARDECL(int, fine_quant);
512    VARDECL(celt_word16_t, error);
513    VARDECL(int, pulses);
514    VARDECL(int, offsets);
515    VARDECL(int, fine_priority);
516    int intra_ener = 0;
517    int shortBlocks=0;
518    int transient_time;
519    int transient_shift;
520    const int C = CHANNELS(st->mode);
521    int mdct_weight_shift = 0;
522    int mdct_weight_pos=0;
523    int gain_id=0;
524    int norm_rate;
525    SAVE_STACK;
526
527    if (check_encoder(st) != CELT_OK)
528       return CELT_INVALID_STATE;
529
530    if (check_mode(st->mode) != CELT_OK)
531       return CELT_INVALID_MODE;
532
533    if (nbCompressedBytes<0 || pcm==NULL)
534      return CELT_BAD_ARG; 
535
536    /* The memset is important for now in case the encoder doesn't 
537       fill up all the bytes */
538    CELT_MEMSET(compressed, 0, nbCompressedBytes);
539    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
540    ec_enc_init(&enc,&buf);
541
542    N = st->block_size;
543    N4 = (N-st->overlap)>>1;
544    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
545
546    CELT_COPY(in, st->in_mem, C*st->overlap);
547    for (c=0;c<C;c++)
548    {
549       const celt_word16_t * restrict pcmp = pcm+c;
550       celt_sig_t * restrict inp = in+C*st->overlap+c;
551       for (i=0;i<N;i++)
552       {
553          /* Apply pre-emphasis */
554          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
555          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
556          st->preemph_memE[c] = SCALEIN(*pcmp);
557          inp += C;
558          pcmp += C;
559       }
560    }
561    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
562
563    /* Transient handling */
564    transient_time = -1;
565    transient_shift = 0;
566    shortBlocks = 0;
567
568    if (st->mode->nbShortMdcts > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
569    {
570 #ifndef FIXED_POINT
571       float gain_1;
572 #endif
573       /* Apply the inverse shaping window */
574       if (transient_shift)
575       {
576 #ifdef FIXED_POINT
577          for (c=0;c<C;c++)
578             for (i=0;i<16;i++)
579                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
580          for (c=0;c<C;c++)
581             for (i=transient_time;i<N+st->overlap;i++)
582                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
583 #else
584          for (c=0;c<C;c++)
585             for (i=0;i<16;i++)
586                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
587          gain_1 = 1./(1<<transient_shift);
588          for (c=0;c<C;c++)
589             for (i=transient_time;i<N+st->overlap;i++)
590                in[C*i+c] *= gain_1;
591 #endif
592       }
593       shortBlocks = 1;
594       has_fold = 1;
595    }
596
597    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
598    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
599    ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16_t);
600    /* Compute MDCTs */
601    compute_mdcts(st->mode, shortBlocks, in, freq);
602
603    if (shortBlocks && !transient_shift) 
604    {
605       celt_word32_t sum[8]={1,1,1,1,1,1,1,1};
606       int m;
607       for (c=0;c<C;c++)
608       {
609          m=0;
610          do {
611             celt_word32_t tmp=0;
612             for (i=m+c*N;i<(c+1)*N;i+=st->mode->nbShortMdcts)
613                tmp += ABS32(freq[i]);
614             sum[m++] += tmp;
615          } while (m<st->mode->nbShortMdcts);
616       }
617       m=0;
618 #ifdef FIXED_POINT
619       do {
620          if (SHR32(sum[m+1],3) > sum[m])
621          {
622             mdct_weight_shift=2;
623             mdct_weight_pos = m;
624          } else if (SHR32(sum[m+1],1) > sum[m] && mdct_weight_shift < 2)
625          {
626             mdct_weight_shift=1;
627             mdct_weight_pos = m;
628          }
629          m++;
630       } while (m<st->mode->nbShortMdcts-1);
631       if (mdct_weight_shift)
632       {
633          for (c=0;c<C;c++)
634             for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
635                for (i=m+c*N;i<(c+1)*N;i+=st->mode->nbShortMdcts)
636                   freq[i] = SHR32(freq[i],mdct_weight_shift);
637       }
638 #else
639       do {
640          if (sum[m+1] > 8*sum[m])
641          {
642             mdct_weight_shift=2;
643             mdct_weight_pos = m;
644          } else if (sum[m+1] > 2*sum[m] && mdct_weight_shift < 2)
645          {
646             mdct_weight_shift=1;
647             mdct_weight_pos = m;
648          }
649          m++;
650       } while (m<st->mode->nbShortMdcts-1);
651       if (mdct_weight_shift)
652       {
653          for (c=0;c<C;c++)
654             for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
655                for (i=m+c*N;i<(c+1)*N;i+=st->mode->nbShortMdcts)
656                   freq[i] = (1./(1<<mdct_weight_shift))*freq[i];
657       }
658 #endif
659    }
660
661    norm_rate = (nbCompressedBytes-5)*8*(celt_uint32_t)st->mode->Fs/(C*N)>>10;
662    /* Pitch analysis: we do it early to save on the peak stack space */
663    /* Don't use pitch if there isn't enough data available yet, 
664       or if we're using shortBlocks */
665    has_pitch = st->pitch_enabled && st->pitch_permitted && (N <= 512) 
666             && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks)
667             && norm_rate < 50;
668    if (has_pitch)
669    {
670       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
671    }
672
673    /* Deferred allocation after find_spectral_pitch() to reduce 
674       the peak memory usage */
675    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
676
677    ALLOC(pitch_freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
678    if (has_pitch)
679    {
680       
681       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, pitch_freq);
682       has_pitch = compute_pitch_gain(st->mode, freq, pitch_freq, norm_rate, &gain_id);
683    }
684    
685    if (has_pitch)
686       apply_pitch(st->mode, freq, pitch_freq, gain_id, 1);
687
688    compute_band_energies(st->mode, freq, bandE);
689    for (i=0;i<st->mode->nbEBands*C;i++)
690       bandLogE[i] = amp2Log(bandE[i]);
691    
692    /* Band normalisation */
693    normalise_bands(st->mode, freq, X, bandE);
694    if (!shortBlocks && !folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision))
695       has_fold = 0;
696
697    /* Don't use intra energy when we're operating at low bit-rate */
698    intra_ener = st->force_intra || (!has_pitch && st->delayedIntra && nbCompressedBytes > st->mode->nbEBands);
699    if (shortBlocks || intra_decision(bandLogE, st->oldBandE, st->mode->nbEBands))
700       st->delayedIntra = 1;
701    else
702       st->delayedIntra = 0;
703
704
705    encode_flags(&enc, intra_ener, has_pitch, shortBlocks, has_fold);
706    if (has_pitch)
707    {
708       ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
709       ec_enc_uint(&enc, gain_id, 16);
710    }
711    if (shortBlocks)
712    {
713       if (transient_shift)
714       {
715          ec_enc_uint(&enc, transient_shift, 4);
716          ec_enc_uint(&enc, transient_time, N+st->overlap);
717       } else {
718          ec_enc_uint(&enc, mdct_weight_shift, 4);
719          if (mdct_weight_shift && st->mode->nbShortMdcts!=2)
720             ec_enc_uint(&enc, mdct_weight_pos, st->mode->nbShortMdcts-1);
721       }
722    }
723
724    ALLOC(fine_quant, st->mode->nbEBands, int);
725    ALLOC(pulses, st->mode->nbEBands, int);
726
727    /* Bit allocation */
728    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
729    coarse_needed = quant_coarse_energy(st->mode, bandLogE, st->oldBandE, nbCompressedBytes*8/3, intra_ener, st->mode->prob, error, &enc);
730    coarse_needed = ((coarse_needed*3-1)>>3)+1;
731
732    /* Variable bitrate */
733    if (st->VBR_rate>0)
734    {
735      /* The target rate in 16th bits per frame */
736      int target=st->VBR_rate;
737    
738      /* Shortblocks get a large boost in bitrate, but since they 
739         are uncommon long blocks are not greatly effected */
740      if (shortBlocks)
741        target*=2;
742      else if (st->mode->nbShortMdcts > 1)
743        target-=(target+14)/28;     
744
745      /* The average energy is removed from the target and the actual 
746         energy added*/
747      target=target-588+ec_enc_tell(&enc, BITRES);
748
749      /* In VBR mode the frame size must not be reduced so much that it would result in the coarse energy busting its budget */
750      target=IMAX(coarse_needed,(target+64)/128);
751      nbCompressedBytes=IMIN(nbCompressedBytes,target);
752      ec_byte_shrink(&buf, nbCompressedBytes);
753    }
754
755    ALLOC(offsets, st->mode->nbEBands, int);
756    ALLOC(fine_priority, st->mode->nbEBands, int);
757
758    for (i=0;i<st->mode->nbEBands;i++)
759       offsets[i] = 0;
760    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
761    compute_allocation(st->mode, offsets, bits, pulses, fine_quant, fine_priority);
762
763    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
764
765    /* Residual quantisation */
766    if (C==1)
767       quant_bands(st->mode, X, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
768 #ifndef DISABLE_STEREO
769    else
770       quant_bands_stereo(st->mode, X, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
771 #endif
772
773    quant_energy_finalise(st->mode, bandE, st->oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(&enc, 0), &enc);
774
775    /* Re-synthesis of the coded audio if required */
776    if (st->pitch_available>0 || optional_synthesis!=NULL)
777    {
778       if (st->pitch_available>0 && st->pitch_available<MAX_PERIOD)
779         st->pitch_available+=st->frame_size;
780
781       /* Synthesis */
782       denormalise_bands(st->mode, X, freq, bandE);
783       
784       
785       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
786       
787       if (mdct_weight_shift)
788       {
789          int m;
790          for (c=0;c<C;c++)
791             for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
792                for (i=m+c*N;i<(c+1)*N;i+=st->mode->nbShortMdcts)
793 #ifdef FIXED_POINT
794                   freq[i] = SHL32(freq[i], mdct_weight_shift);
795 #else
796                   freq[i] = (1<<mdct_weight_shift)*freq[i];
797 #endif
798       }
799       if (has_pitch)
800          apply_pitch(st->mode, freq, pitch_freq, gain_id, 0);
801       
802       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
803       /* De-emphasis and put everything back at the right place 
804          in the synthesis history */
805       if (optional_synthesis != NULL) {
806          for (c=0;c<C;c++)
807          {
808             int j;
809             for (j=0;j<N;j++)
810             {
811                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
812                                    preemph,st->preemph_memD[c]);
813                st->preemph_memD[c] = tmp;
814                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
815             }
816          }
817       }
818    }
819
820    ec_enc_done(&enc);
821    
822    RESTORE_STACK;
823    return nbCompressedBytes;
824 }
825
826 #ifdef FIXED_POINT
827 #ifndef DISABLE_FLOAT_API
828 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
829 {
830    int j, ret, C, N;
831    VARDECL(celt_int16_t, in);
832    SAVE_STACK;
833
834    if (check_encoder(st) != CELT_OK)
835       return CELT_INVALID_STATE;
836
837    if (check_mode(st->mode) != CELT_OK)
838       return CELT_INVALID_MODE;
839
840    if (pcm==NULL)
841       return CELT_BAD_ARG;
842
843    C = CHANNELS(st->mode);
844    N = st->block_size;
845    ALLOC(in, C*N, celt_int16_t);
846
847    for (j=0;j<C*N;j++)
848      in[j] = FLOAT2INT16(pcm[j]);
849
850    if (optional_synthesis != NULL) {
851      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
852       for (j=0;j<C*N;j++)
853          optional_synthesis[j]=in[j]*(1/32768.);
854    } else {
855      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
856    }
857    RESTORE_STACK;
858    return ret;
859
860 }
861 #endif /*DISABLE_FLOAT_API*/
862 #else
863 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
864 {
865    int j, ret, C, N;
866    VARDECL(celt_sig_t, in);
867    SAVE_STACK;
868
869    if (check_encoder(st) != CELT_OK)
870       return CELT_INVALID_STATE;
871
872    if (check_mode(st->mode) != CELT_OK)
873       return CELT_INVALID_MODE;
874
875    if (pcm==NULL)
876       return CELT_BAD_ARG;
877
878    C=CHANNELS(st->mode);
879    N=st->block_size;
880    ALLOC(in, C*N, celt_sig_t);
881    for (j=0;j<C*N;j++) {
882      in[j] = SCALEOUT(pcm[j]);
883    }
884
885    if (optional_synthesis != NULL) {
886       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
887       for (j=0;j<C*N;j++)
888          optional_synthesis[j] = FLOAT2INT16(in[j]);
889    } else {
890       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
891    }
892    RESTORE_STACK;
893    return ret;
894 }
895 #endif
896
897 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
898 {
899    va_list ap;
900    
901    if (check_encoder(st) != CELT_OK)
902       return CELT_INVALID_STATE;
903
904    va_start(ap, request);
905    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
906      goto bad_mode;
907    switch (request)
908    {
909       case CELT_GET_MODE_REQUEST:
910       {
911          const CELTMode ** value = va_arg(ap, const CELTMode**);
912          if (value==0)
913             goto bad_arg;
914          *value=st->mode;
915       }
916       break;
917       case CELT_SET_COMPLEXITY_REQUEST:
918       {
919          int value = va_arg(ap, celt_int32_t);
920          if (value<0 || value>10)
921             goto bad_arg;
922          if (value<=2) {
923             st->pitch_enabled = 0; 
924             st->pitch_available = 0;
925          } else {
926               st->pitch_enabled = 1;
927               if (st->pitch_available<1)
928                 st->pitch_available = 1;
929          }   
930       }
931       break;
932       case CELT_SET_PREDICTION_REQUEST:
933       {
934          int value = va_arg(ap, celt_int32_t);
935          if (value<0 || value>2)
936             goto bad_arg;
937          if (value==0)
938          {
939             st->force_intra   = 1;
940             st->pitch_permitted = 0;
941          } else if (value==1) {
942             st->force_intra   = 0;
943             st->pitch_permitted = 0;
944          } else {
945             st->force_intra   = 0;
946             st->pitch_permitted = 1;
947          }   
948       }
949       break;
950       case CELT_SET_VBR_RATE_REQUEST:
951       {
952          celt_int32_t value = va_arg(ap, celt_int32_t);
953          if (value<0)
954             goto bad_arg;
955          if (value>3072000)
956             value = 3072000;
957          st->VBR_rate = ((st->mode->Fs<<3)+(st->block_size>>1))/st->block_size;
958          st->VBR_rate = ((value<<7)+(st->VBR_rate>>1))/st->VBR_rate;
959       }
960       break;
961       case CELT_RESET_STATE:
962       {
963          const CELTMode *mode = st->mode;
964          int C = mode->nbChannels;
965
966          if (st->pitch_available > 0) st->pitch_available = 1;
967
968          CELT_MEMSET(st->in_mem, 0, st->overlap*C);
969          CELT_MEMSET(st->out_mem, 0, (MAX_PERIOD+st->overlap)*C);
970
971          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
972
973          CELT_MEMSET(st->preemph_memE, 0, C);
974          CELT_MEMSET(st->preemph_memD, 0, C);
975          st->delayedIntra = 1;
976       }
977       break;
978       default:
979          goto bad_request;
980    }
981    va_end(ap);
982    return CELT_OK;
983 bad_mode:
984   va_end(ap);
985   return CELT_INVALID_MODE;
986 bad_arg:
987    va_end(ap);
988    return CELT_BAD_ARG;
989 bad_request:
990    va_end(ap);
991    return CELT_UNIMPLEMENTED;
992 }
993
994 /**********************************************************************/
995 /*                                                                    */
996 /*                             DECODER                                */
997 /*                                                                    */
998 /**********************************************************************/
999 #ifdef NEW_PLC
1000 #define DECODE_BUFFER_SIZE 2048
1001 #else
1002 #define DECODE_BUFFER_SIZE MAX_PERIOD
1003 #endif
1004
1005 #define DECODERVALID   0x4c434454
1006 #define DECODERPARTIAL 0x5444434c
1007 #define DECODERFREED   0x4c004400
1008
1009 /** Decoder state 
1010  @brief Decoder state
1011  */
1012 struct CELTDecoder {
1013    celt_uint32_t marker;
1014    const CELTMode *mode;
1015    int frame_size;
1016    int block_size;
1017    int overlap;
1018
1019    ec_byte_buffer buf;
1020    ec_enc         enc;
1021
1022    celt_sig_t * restrict preemph_memD;
1023
1024    celt_sig_t *out_mem;
1025    celt_sig_t *decode_mem;
1026
1027    celt_word16_t *oldBandE;
1028    
1029    int last_pitch_index;
1030    int loss_count;
1031 };
1032
1033 int check_decoder(const CELTDecoder *st) 
1034 {
1035    if (st==NULL)
1036    {
1037       celt_warning("NULL passed a decoder structure");  
1038       return CELT_INVALID_STATE;
1039    }
1040    if (st->marker == DECODERVALID)
1041       return CELT_OK;
1042    if (st->marker == DECODERFREED)
1043       celt_warning("Referencing a decoder that has already been freed");
1044    else
1045       celt_warning("This is not a valid CELT decoder structure");
1046    return CELT_INVALID_STATE;
1047 }
1048
1049 CELTDecoder *celt_decoder_create(const CELTMode *mode)
1050 {
1051    int N, C;
1052    CELTDecoder *st;
1053
1054    if (check_mode(mode) != CELT_OK)
1055       return NULL;
1056
1057    N = mode->mdctSize;
1058    C = CHANNELS(mode);
1059    st = celt_alloc(sizeof(CELTDecoder));
1060
1061    if (st==NULL)
1062       return NULL;
1063    
1064    st->marker = DECODERPARTIAL;
1065    st->mode = mode;
1066    st->frame_size = N;
1067    st->block_size = N;
1068    st->overlap = mode->overlap;
1069
1070    st->decode_mem = celt_alloc((DECODE_BUFFER_SIZE+st->overlap)*C*sizeof(celt_sig_t));
1071    st->out_mem = st->decode_mem+DECODE_BUFFER_SIZE-MAX_PERIOD;
1072    
1073    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
1074    
1075    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
1076
1077    st->loss_count = 0;
1078
1079    if ((st->decode_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL) &&
1080        (st->preemph_memD!=NULL))
1081    {
1082       st->marker = DECODERVALID;
1083       return st;
1084    }
1085    /* If the setup fails for some reason deallocate it. */
1086    celt_decoder_destroy(st);
1087    return NULL;
1088 }
1089
1090 void celt_decoder_destroy(CELTDecoder *st)
1091 {
1092    if (st == NULL)
1093    {
1094       celt_warning("NULL passed to celt_decoder_destroy");
1095       return;
1096    }
1097
1098    if (st->marker == DECODERFREED) 
1099    {
1100       celt_warning("Freeing a decoder which has already been freed"); 
1101       return;
1102    }
1103    
1104    if (st->marker != DECODERVALID && st->marker != DECODERPARTIAL)
1105    {
1106       celt_warning("This is not a valid CELT decoder structure");
1107       return;
1108    }
1109    
1110    /*Check_mode is non-fatal here because we can still free
1111      the encoder memory even if the mode is bad, although calling
1112      the free functions in this order is a violation of the API.*/
1113    check_mode(st->mode);
1114    
1115    celt_free(st->decode_mem);
1116    celt_free(st->oldBandE);
1117    celt_free(st->preemph_memD);
1118    
1119    st->marker = DECODERFREED;
1120    
1121    celt_free(st);
1122 }
1123
1124 /** Handles lost packets by just copying past data with the same
1125     offset as the last
1126     pitch period */
1127 #ifdef NEW_PLC
1128 #include "plc.c"
1129 #else
1130 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
1131 {
1132    int c, N;
1133    int pitch_index;
1134    celt_word16_t fade = Q15ONE;
1135    int i, len;
1136    VARDECL(celt_sig_t, freq);
1137    const int C = CHANNELS(st->mode);
1138    int offset;
1139    SAVE_STACK;
1140    N = st->block_size;
1141    ALLOC(freq,C*N, celt_sig_t); /**< Interleaved signal MDCTs */
1142    
1143    len = N+st->mode->overlap;
1144    
1145    if (st->loss_count == 0)
1146    {
1147       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
1148       pitch_index = MAX_PERIOD-len-pitch_index;
1149       st->last_pitch_index = pitch_index;
1150    } else {
1151       pitch_index = st->last_pitch_index;
1152       if (st->loss_count < 5)
1153          fade = QCONST16(.8f,15);
1154       else
1155          fade = 0;
1156    }
1157
1158    offset = MAX_PERIOD-pitch_index;
1159    while (offset+len >= MAX_PERIOD)
1160       offset -= pitch_index;
1161    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
1162    for (i=0;i<C*N;i++)
1163       freq[i] = ADD32(VERY_SMALL, MULT16_32_Q15(fade,freq[i]));
1164
1165    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
1166    /* Compute inverse MDCTs */
1167    compute_inv_mdcts(st->mode, 0, freq, -1, 0, st->out_mem);
1168
1169    for (c=0;c<C;c++)
1170    {
1171       int j;
1172       for (j=0;j<N;j++)
1173       {
1174          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1175                                 preemph,st->preemph_memD[c]);
1176          st->preemph_memD[c] = tmp;
1177          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1178       }
1179    }
1180    
1181    st->loss_count++;
1182
1183    RESTORE_STACK;
1184 }
1185 #endif
1186
1187 #ifdef FIXED_POINT
1188 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
1189 {
1190 #else
1191 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig_t * restrict pcm)
1192 {
1193 #endif
1194    int i, c, N, N4;
1195    int has_pitch, has_fold;
1196    int pitch_index;
1197    int bits;
1198    ec_dec dec;
1199    ec_byte_buffer buf;
1200    VARDECL(celt_sig_t, freq);
1201    VARDECL(celt_sig_t, pitch_freq);
1202    VARDECL(celt_norm_t, X);
1203    VARDECL(celt_ener_t, bandE);
1204    VARDECL(int, fine_quant);
1205    VARDECL(int, pulses);
1206    VARDECL(int, offsets);
1207    VARDECL(int, fine_priority);
1208
1209    int shortBlocks;
1210    int intra_ener;
1211    int transient_time;
1212    int transient_shift;
1213    int mdct_weight_shift=0;
1214    const int C = CHANNELS(st->mode);
1215    int mdct_weight_pos=0;
1216    int gain_id=0;
1217    SAVE_STACK;
1218
1219    if (check_decoder(st) != CELT_OK)
1220       return CELT_INVALID_STATE;
1221
1222    if (check_mode(st->mode) != CELT_OK)
1223       return CELT_INVALID_MODE;
1224
1225    if (pcm==NULL)
1226       return CELT_BAD_ARG;
1227
1228    N = st->block_size;
1229    N4 = (N-st->overlap)>>1;
1230
1231    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
1232    ALLOC(X, C*N, celt_norm_t);   /**< Interleaved normalised MDCTs */
1233    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
1234    
1235    if (data == NULL)
1236    {
1237       celt_decode_lost(st, pcm);
1238       RESTORE_STACK;
1239       return 0;
1240    } else {
1241       st->loss_count = 0;
1242    }
1243    if (len<0) {
1244      RESTORE_STACK;
1245      return CELT_BAD_ARG;
1246    }
1247    
1248    ec_byte_readinit(&buf,(unsigned char*)data,len);
1249    ec_dec_init(&dec,&buf);
1250    
1251    decode_flags(&dec, &intra_ener, &has_pitch, &shortBlocks, &has_fold);
1252    if (shortBlocks)
1253    {
1254       transient_shift = ec_dec_uint(&dec, 4);
1255       if (transient_shift == 3)
1256       {
1257          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
1258       } else {
1259          mdct_weight_shift = transient_shift;
1260          if (mdct_weight_shift && st->mode->nbShortMdcts>2)
1261             mdct_weight_pos = ec_dec_uint(&dec, st->mode->nbShortMdcts-1);
1262          transient_shift = 0;
1263          transient_time = 0;
1264       }
1265    } else {
1266       transient_time = -1;
1267       transient_shift = 0;
1268    }
1269    
1270    if (has_pitch)
1271    {
1272       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
1273       gain_id = ec_dec_uint(&dec, 16);
1274    } else {
1275       pitch_index = 0;
1276    }
1277
1278    ALLOC(fine_quant, st->mode->nbEBands, int);
1279    /* Get band energies */
1280    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, intra_ener, st->mode->prob, &dec);
1281    
1282    ALLOC(pulses, st->mode->nbEBands, int);
1283    ALLOC(offsets, st->mode->nbEBands, int);
1284    ALLOC(fine_priority, st->mode->nbEBands, int);
1285
1286    for (i=0;i<st->mode->nbEBands;i++)
1287       offsets[i] = 0;
1288
1289    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1290    compute_allocation(st->mode, offsets, bits, pulses, fine_quant, fine_priority);
1291    /*bits = ec_dec_tell(&dec, 0);
1292    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1293    
1294    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1295
1296    ALLOC(pitch_freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
1297    if (has_pitch) 
1298    {
1299       /* Pitch MDCT */
1300       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, pitch_freq);
1301    }
1302
1303    /* Decode fixed codebook and merge with pitch */
1304    if (C==1)
1305       unquant_bands(st->mode, X, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1306 #ifndef DISABLE_STEREO
1307    else
1308       unquant_bands_stereo(st->mode, X, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1309 #endif
1310    unquant_energy_finalise(st->mode, bandE, st->oldBandE, fine_quant, fine_priority, len*8-ec_dec_tell(&dec, 0), &dec);
1311    
1312    /* Synthesis */
1313    denormalise_bands(st->mode, X, freq, bandE);
1314
1315
1316    CELT_MOVE(st->decode_mem, st->decode_mem+C*N, C*(DECODE_BUFFER_SIZE+st->overlap-N));
1317    if (mdct_weight_shift)
1318    {
1319       int m;
1320       for (c=0;c<C;c++)
1321          for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
1322             for (i=m+c*N;i<(c+1)*N;i+=st->mode->nbShortMdcts)
1323 #ifdef FIXED_POINT
1324                freq[i] = SHL32(freq[i], mdct_weight_shift);
1325 #else
1326                freq[i] = (1<<mdct_weight_shift)*freq[i];
1327 #endif
1328    }
1329    
1330    if (has_pitch)
1331       apply_pitch(st->mode, freq, pitch_freq, gain_id, 0);
1332
1333    /* Compute inverse MDCTs */
1334    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1335
1336    for (c=0;c<C;c++)
1337    {
1338       int j;
1339       for (j=0;j<N;j++)
1340       {
1341          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1342                                 preemph,st->preemph_memD[c]);
1343          st->preemph_memD[c] = tmp;
1344          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1345       }
1346    }
1347
1348    RESTORE_STACK;
1349    return 0;
1350 }
1351
1352 #ifdef FIXED_POINT
1353 #ifndef DISABLE_FLOAT_API
1354 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm)
1355 {
1356    int j, ret, C, N;
1357    VARDECL(celt_int16_t, out);
1358    SAVE_STACK;
1359
1360    if (check_decoder(st) != CELT_OK)
1361       return CELT_INVALID_STATE;
1362
1363    if (check_mode(st->mode) != CELT_OK)
1364       return CELT_INVALID_MODE;
1365
1366    if (pcm==NULL)
1367       return CELT_BAD_ARG;
1368
1369    C = CHANNELS(st->mode);
1370    N = st->block_size;
1371    
1372    ALLOC(out, C*N, celt_int16_t);
1373    ret=celt_decode(st, data, len, out);
1374    for (j=0;j<C*N;j++)
1375       pcm[j]=out[j]*(1/32768.);
1376      
1377    RESTORE_STACK;
1378    return ret;
1379 }
1380 #endif /*DISABLE_FLOAT_API*/
1381 #else
1382 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
1383 {
1384    int j, ret, C, N;
1385    VARDECL(celt_sig_t, out);
1386    SAVE_STACK;
1387
1388    if (check_decoder(st) != CELT_OK)
1389       return CELT_INVALID_STATE;
1390
1391    if (check_mode(st->mode) != CELT_OK)
1392       return CELT_INVALID_MODE;
1393
1394    if (pcm==NULL)
1395       return CELT_BAD_ARG;
1396
1397    C = CHANNELS(st->mode);
1398    N = st->block_size;
1399    ALLOC(out, C*N, celt_sig_t);
1400
1401    ret=celt_decode_float(st, data, len, out);
1402
1403    for (j=0;j<C*N;j++)
1404       pcm[j] = FLOAT2INT16 (out[j]);
1405    
1406    RESTORE_STACK;
1407    return ret;
1408 }
1409 #endif
1410
1411 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
1412 {
1413    va_list ap;
1414
1415    if (check_decoder(st) != CELT_OK)
1416       return CELT_INVALID_STATE;
1417
1418    va_start(ap, request);
1419    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
1420      goto bad_mode;
1421    switch (request)
1422    {
1423       case CELT_GET_MODE_REQUEST:
1424       {
1425          const CELTMode ** value = va_arg(ap, const CELTMode**);
1426          if (value==0)
1427             goto bad_arg;
1428          *value=st->mode;
1429       }
1430       break;
1431       case CELT_RESET_STATE:
1432       {
1433          const CELTMode *mode = st->mode;
1434          int C = mode->nbChannels;
1435
1436          CELT_MEMSET(st->decode_mem, 0, (DECODE_BUFFER_SIZE+st->overlap)*C);
1437          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
1438
1439          CELT_MEMSET(st->preemph_memD, 0, C);
1440
1441          st->loss_count = 0;
1442       }
1443       break;
1444       default:
1445          goto bad_request;
1446    }
1447    va_end(ap);
1448    return CELT_OK;
1449 bad_mode:
1450   va_end(ap);
1451   return CELT_INVALID_MODE;
1452 bad_arg:
1453    va_end(ap);
1454    return CELT_BAD_ARG;
1455 bad_request:
1456       va_end(ap);
1457   return CELT_UNIMPLEMENTED;
1458 }