For celt_encoder_ctl CELT_SET_LTP is replaced with CELT_SET_PREDICTION
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2    (C) 2008 Gregory Maxwell */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_bands.h"
48 #include "psy.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54
55 static const celt_word16_t preemph = QCONST16(0.8f,15);
56
57 #ifdef FIXED_POINT
58 static const celt_word16_t transientWindow[16] = {
59      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
60    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
61 #else
62 static const float transientWindow[16] = {
63    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
64    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
65 #endif
66
67 #define ENCODERVALID   0x4c434554
68 #define ENCODERPARTIAL 0x5445434c
69 #define ENCODERFREED   0x4c004500
70    
71 /** Encoder state 
72  @brief Encoder state
73  */
74 struct CELTEncoder {
75    celt_uint32_t marker;
76    const CELTMode *mode;     /**< Mode used by the encoder */
77    int frame_size;
78    int block_size;
79    int overlap;
80    int channels;
81    
82    int pitch_enabled;       /* Complexity level is allowed to use pitch */
83    int pitch_permitted;     /*  Use of the LTP is permitted by the user */
84    int pitch_available;     /*  Amount of pitch buffer available */
85    int force_intra;
86    int delayedIntra;
87    celt_word16_t tonal_average;
88    int fold_decision;
89
90    int VBR_rate; /* Target number of 16th bits per frame */
91    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
92    celt_sig_t    * restrict preemph_memD;
93
94    celt_sig_t *in_mem;
95    celt_sig_t *out_mem;
96
97    celt_word16_t *oldBandE;
98 #ifdef EXP_PSY
99    celt_word16_t *psy_mem;
100    struct PsyDecay psy;
101 #endif
102 };
103
104 int check_encoder(const CELTEncoder *st) 
105 {
106    if (st==NULL)
107    {
108       celt_warning("NULL passed as an encoder structure");  
109       return CELT_INVALID_STATE;
110    }
111    if (st->marker == ENCODERVALID)
112       return CELT_OK;
113    if (st->marker == ENCODERFREED)
114       celt_warning("Referencing an encoder that has already been freed");
115    else
116       celt_warning("This is not a valid CELT encoder structure");
117    return CELT_INVALID_STATE;
118 }
119
120 CELTEncoder *celt_encoder_create(const CELTMode *mode)
121 {
122    int N, C;
123    CELTEncoder *st;
124
125    if (check_mode(mode) != CELT_OK)
126       return NULL;
127
128    N = mode->mdctSize;
129    C = mode->nbChannels;
130    st = celt_alloc(sizeof(CELTEncoder));
131    
132    if (st==NULL) 
133       return NULL;   
134    st->marker = ENCODERPARTIAL;
135    st->mode = mode;
136    st->frame_size = N;
137    st->block_size = N;
138    st->overlap = mode->overlap;
139
140    st->VBR_rate = 0;
141    st->pitch_enabled = 1;
142    st->pitch_permitted = 1;
143    st->pitch_available = 1;
144    st->force_intra  = 0;
145    st->delayedIntra = 1;
146    st->tonal_average = QCONST16(1.,8);
147    st->fold_decision = 1;
148
149    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
150    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
151
152    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
153
154    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
155    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
156
157 #ifdef EXP_PSY
158    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
159    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
160 #endif
161
162    if ((st->in_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL) 
163 #ifdef EXP_PSY
164        && (st->psy_mem!=NULL) 
165 #endif   
166        && (st->preemph_memE!=NULL) && (st->preemph_memD!=NULL))
167    {
168       st->marker   = ENCODERVALID;
169       return st;
170    }
171    /* If the setup fails for some reason deallocate it. */
172    celt_encoder_destroy(st);  
173    return NULL;
174 }
175
176 void celt_encoder_destroy(CELTEncoder *st)
177 {
178    if (st == NULL)
179    {
180       celt_warning("NULL passed to celt_encoder_destroy");
181       return;
182    }
183
184    if (st->marker == ENCODERFREED)
185    {
186       celt_warning("Freeing an encoder which has already been freed"); 
187       return;
188    }
189
190    if (st->marker != ENCODERVALID && st->marker != ENCODERPARTIAL)
191    {
192       celt_warning("This is not a valid CELT encoder structure");
193       return;
194    }
195    /*Check_mode is non-fatal here because we can still free
196     the encoder memory even if the mode is bad, although calling
197     the free functions in this order is a violation of the API.*/
198    check_mode(st->mode);
199    
200    celt_free(st->in_mem);
201    celt_free(st->out_mem);
202    
203    celt_free(st->oldBandE);
204    
205    celt_free(st->preemph_memE);
206    celt_free(st->preemph_memD);
207    
208 #ifdef EXP_PSY
209    celt_free (st->psy_mem);
210    psydecay_clear(&st->psy);
211 #endif
212    st->marker = ENCODERFREED;
213    
214    celt_free(st);
215 }
216
217 static inline celt_int16_t FLOAT2INT16(float x)
218 {
219    x = x*CELT_SIG_SCALE;
220    x = MAX32(x, -32768);
221    x = MIN32(x, 32767);
222    return (celt_int16_t)float2int(x);
223 }
224
225 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
226 {
227 #ifdef FIXED_POINT
228    x = PSHR32(x, SIG_SHIFT);
229    x = MAX32(x, -32768);
230    x = MIN32(x, 32767);
231    return EXTRACT16(x);
232 #else
233    return (celt_word16_t)x;
234 #endif
235 }
236
237 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
238 {
239    int c, i, n;
240    celt_word32_t ratio;
241    /* FIXME: Remove the floats here */
242    VARDECL(celt_word32_t, begin);
243    SAVE_STACK;
244    ALLOC(begin, len, celt_word32_t);
245    for (i=0;i<len;i++)
246       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
247    for (c=1;c<C;c++)
248    {
249       for (i=0;i<len;i++)
250          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
251    }
252    for (i=1;i<len;i++)
253       begin[i] = MAX32(begin[i-1],begin[i]);
254    n = -1;
255    for (i=8;i<len-8;i++)
256    {
257       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
258          n=i;
259    }
260    if (n<32)
261    {
262       n = -1;
263       ratio = 0;
264    } else {
265       ratio = DIV32(begin[len-1],1+begin[n-16]);
266    }
267    /*printf ("%d %f\n", n, ratio*ratio);*/
268    if (ratio < 0)
269       ratio = 0;
270    if (ratio > 1000)
271       ratio = 1000;
272    ratio *= ratio;
273    
274    if (ratio > 2048)
275       *transient_shift = 3;
276    else
277       *transient_shift = 0;
278    
279    *transient_time = n;
280    
281    RESTORE_STACK;
282    return ratio > 20;
283 }
284
285 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
286 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
287 {
288    const int C = CHANNELS(mode);
289    if (C==1 && !shortBlocks)
290    {
291       const mdct_lookup *lookup = MDCT(mode);
292       const int overlap = OVERLAP(mode);
293       mdct_forward(lookup, in, out, mode->window, overlap);
294    } else if (!shortBlocks) {
295       const mdct_lookup *lookup = MDCT(mode);
296       const int overlap = OVERLAP(mode);
297       const int N = FRAMESIZE(mode);
298       int c;
299       VARDECL(celt_word32_t, x);
300       VARDECL(celt_word32_t, tmp);
301       SAVE_STACK;
302       ALLOC(x, N+overlap, celt_word32_t);
303       ALLOC(tmp, N, celt_word32_t);
304       for (c=0;c<C;c++)
305       {
306          int j;
307          for (j=0;j<N+overlap;j++)
308             x[j] = in[C*j+c];
309          mdct_forward(lookup, x, tmp, mode->window, overlap);
310          /* Interleaving the sub-frames */
311          for (j=0;j<N;j++)
312             out[C*j+c] = tmp[j];
313       }
314       RESTORE_STACK;
315    } else {
316       const mdct_lookup *lookup = &mode->shortMdct;
317       const int overlap = mode->overlap;
318       const int N = mode->shortMdctSize;
319       int b, c;
320       VARDECL(celt_word32_t, x);
321       VARDECL(celt_word32_t, tmp);
322       SAVE_STACK;
323       ALLOC(x, N+overlap, celt_word32_t);
324       ALLOC(tmp, N, celt_word32_t);
325       for (c=0;c<C;c++)
326       {
327          int B = mode->nbShortMdcts;
328          for (b=0;b<B;b++)
329          {
330             int j;
331             for (j=0;j<N+overlap;j++)
332                x[j] = in[C*(b*N+j)+c];
333             mdct_forward(lookup, x, tmp, mode->window, overlap);
334             /* Interleaving the sub-frames */
335             for (j=0;j<N;j++)
336                out[C*(j*B+b)+c] = tmp[j];
337          }
338       }
339       RESTORE_STACK;
340    }
341 }
342
343 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
344 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
345 {
346    int c, N4;
347    const int C = CHANNELS(mode);
348    const int N = FRAMESIZE(mode);
349    const int overlap = OVERLAP(mode);
350    N4 = (N-overlap)>>1;
351    for (c=0;c<C;c++)
352    {
353       int j;
354       if (transient_shift==0 && C==1 && !shortBlocks) {
355          const mdct_lookup *lookup = MDCT(mode);
356          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
357       } else if (!shortBlocks) {
358          const mdct_lookup *lookup = MDCT(mode);
359          VARDECL(celt_word32_t, x);
360          VARDECL(celt_word32_t, tmp);
361          SAVE_STACK;
362          ALLOC(x, 2*N, celt_word32_t);
363          ALLOC(tmp, N, celt_word32_t);
364          /* De-interleaving the sub-frames */
365          for (j=0;j<N;j++)
366             tmp[j] = X[C*j+c];
367          /* Prevents problems from the imdct doing the overlap-add */
368          CELT_MEMSET(x+N4, 0, N);
369          mdct_backward(lookup, tmp, x, mode->window, overlap);
370          celt_assert(transient_shift == 0);
371          /* The first and last part would need to be set to zero if we actually
372             wanted to use them. */
373          for (j=0;j<overlap;j++)
374             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
375          for (j=0;j<overlap;j++)
376             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
377          for (j=0;j<2*N4;j++)
378             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
379          RESTORE_STACK;
380       } else {
381          int b;
382          const int N2 = mode->shortMdctSize;
383          const int B = mode->nbShortMdcts;
384          const mdct_lookup *lookup = &mode->shortMdct;
385          VARDECL(celt_word32_t, x);
386          VARDECL(celt_word32_t, tmp);
387          SAVE_STACK;
388          ALLOC(x, 2*N, celt_word32_t);
389          ALLOC(tmp, N, celt_word32_t);
390          /* Prevents problems from the imdct doing the overlap-add */
391          CELT_MEMSET(x+N4, 0, N2);
392          for (b=0;b<B;b++)
393          {
394             /* De-interleaving the sub-frames */
395             for (j=0;j<N2;j++)
396                tmp[j] = X[C*(j*B+b)+c];
397             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
398          }
399          if (transient_shift > 0)
400          {
401 #ifdef FIXED_POINT
402             for (j=0;j<16;j++)
403                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
404             for (j=transient_time;j<N+overlap;j++)
405                x[N4+j] = SHL32(x[N4+j], transient_shift);
406 #else
407             for (j=0;j<16;j++)
408                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
409             for (j=transient_time;j<N+overlap;j++)
410                x[N4+j] *= 1<<transient_shift;
411 #endif
412          }
413          /* The first and last part would need to be set to zero if we actually
414          wanted to use them. */
415          for (j=0;j<overlap;j++)
416             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
417          for (j=0;j<overlap;j++)
418             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
419          for (j=0;j<2*N4;j++)
420             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
421          RESTORE_STACK;
422       }
423    }
424 }
425
426 #define FLAG_NONE        0
427 #define FLAG_INTRA       1U<<16
428 #define FLAG_PITCH       1U<<15
429 #define FLAG_SHORT       1U<<14
430 #define FLAG_FOLD        1U<<13
431 #define FLAG_MASK        (FLAG_INTRA|FLAG_PITCH|FLAG_SHORT|FLAG_FOLD)
432
433 celt_int32_t flaglist[8] = {
434       0 /*00  */ | FLAG_FOLD,
435       1 /*01  */ | FLAG_PITCH|FLAG_FOLD,
436       8 /*1000*/ | FLAG_NONE,
437       9 /*1001*/ | FLAG_SHORT|FLAG_FOLD,
438      10 /*1010*/ | FLAG_PITCH,
439      11 /*1011*/ | FLAG_INTRA,
440       6 /*110 */ | FLAG_INTRA|FLAG_FOLD,
441       7 /*111 */ | FLAG_INTRA|FLAG_SHORT|FLAG_FOLD
442 };
443
444 void encode_flags(ec_enc *enc, int intra_ener, int has_pitch, int shortBlocks, int has_fold)
445 {
446    int i;
447    int flags=FLAG_NONE;
448    int flag_bits;
449    flags |= intra_ener   ? FLAG_INTRA : 0;
450    flags |= has_pitch    ? FLAG_PITCH : 0;
451    flags |= shortBlocks  ? FLAG_SHORT : 0;
452    flags |= has_fold     ? FLAG_FOLD  : 0;
453    for (i=0;i<8;i++)
454       if (flags == (flaglist[i]&FLAG_MASK))
455          break;
456    celt_assert(i<8);
457    flag_bits = flaglist[i]&0xf;
458    /*printf ("enc %d: %d %d %d %d\n", flag_bits, intra_ener, has_pitch, shortBlocks, has_fold);*/
459    if (i<2)
460       ec_enc_bits(enc, flag_bits, 2);
461    else if (i<6)
462       ec_enc_bits(enc, flag_bits, 4);
463    else
464       ec_enc_bits(enc, flag_bits, 3);
465 }
466
467 void decode_flags(ec_dec *dec, int *intra_ener, int *has_pitch, int *shortBlocks, int *has_fold)
468 {
469    int i;
470    int flag_bits;
471    flag_bits = ec_dec_bits(dec, 2);
472    /*printf ("(%d) ", flag_bits);*/
473    if (flag_bits==2)
474       flag_bits = (flag_bits<<2) | ec_dec_bits(dec, 2);
475    else if (flag_bits==3)
476       flag_bits = (flag_bits<<1) | ec_dec_bits(dec, 1);
477    for (i=0;i<8;i++)
478       if (flag_bits == (flaglist[i]&0xf))
479          break;
480    celt_assert(i<8);
481    *intra_ener  = (flaglist[i]&FLAG_INTRA) != 0;
482    *has_pitch   = (flaglist[i]&FLAG_PITCH) != 0;
483    *shortBlocks = (flaglist[i]&FLAG_SHORT) != 0;
484    *has_fold    = (flaglist[i]&FLAG_FOLD ) != 0;
485    /*printf ("dec %d: %d %d %d %d\n", flag_bits, *intra_ener, *has_pitch, *shortBlocks, *has_fold);*/
486 }
487
488 #ifdef FIXED_POINT
489 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
490 {
491 #else
492 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
493 {
494 #endif
495    int i, c, N, N4;
496    int has_pitch;
497    int pitch_index;
498    int bits;
499    int has_fold=1;
500    unsigned coarse_needed;
501    ec_byte_buffer buf;
502    ec_enc         enc;
503    VARDECL(celt_sig_t, in);
504    VARDECL(celt_sig_t, freq);
505    VARDECL(celt_norm_t, X);
506    VARDECL(celt_norm_t, P);
507    VARDECL(celt_ener_t, bandE);
508    VARDECL(celt_pgain_t, gains);
509    VARDECL(int, fine_quant);
510    VARDECL(celt_word16_t, error);
511    VARDECL(int, pulses);
512    VARDECL(int, offsets);
513 #ifdef EXP_PSY
514    VARDECL(celt_word32_t, mask);
515    VARDECL(celt_word32_t, tonality);
516    VARDECL(celt_word32_t, bandM);
517    VARDECL(celt_ener_t, bandN);
518 #endif
519    int intra_ener = 0;
520    int shortBlocks=0;
521    int transient_time;
522    int transient_shift;
523    const int C = CHANNELS(st->mode);
524    int mdct_weight_shift = 0;
525    int mdct_weight_pos=0;
526    SAVE_STACK;
527
528    if (check_encoder(st) != CELT_OK)
529       return CELT_INVALID_STATE;
530
531    if (check_mode(st->mode) != CELT_OK)
532       return CELT_INVALID_MODE;
533
534    if (nbCompressedBytes<0)
535      return CELT_BAD_ARG; 
536
537    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
538    CELT_MEMSET(compressed, 0, nbCompressedBytes);
539    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
540    ec_enc_init(&enc,&buf);
541
542    N = st->block_size;
543    N4 = (N-st->overlap)>>1;
544    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
545
546    CELT_COPY(in, st->in_mem, C*st->overlap);
547    for (c=0;c<C;c++)
548    {
549       const celt_word16_t * restrict pcmp = pcm+c;
550       celt_sig_t * restrict inp = in+C*st->overlap+c;
551       for (i=0;i<N;i++)
552       {
553          /* Apply pre-emphasis */
554          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
555          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
556          st->preemph_memE[c] = SCALEIN(*pcmp);
557          inp += C;
558          pcmp += C;
559       }
560    }
561    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
562    
563    /* Transient handling */
564    if (st->mode->nbShortMdcts > 1)
565    {
566       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
567       {
568 #ifndef FIXED_POINT
569          float gain_1;
570 #endif
571          /* Apply the inverse shaping window */
572          if (transient_shift)
573          {
574 #ifdef FIXED_POINT
575             for (c=0;c<C;c++)
576                for (i=0;i<16;i++)
577                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
578             for (c=0;c<C;c++)
579                for (i=transient_time;i<N+st->overlap;i++)
580                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
581 #else
582             for (c=0;c<C;c++)
583                for (i=0;i<16;i++)
584                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
585             gain_1 = 1./(1<<transient_shift);
586             for (c=0;c<C;c++)
587                for (i=transient_time;i<N+st->overlap;i++)
588                   in[C*i+c] *= gain_1;
589 #endif
590          }
591          shortBlocks = 1;
592          has_fold = 1;
593       } else {
594          transient_time = -1;
595          transient_shift = 0;
596          shortBlocks = 0;
597       }
598    } else {
599       transient_time = -1;
600       transient_shift = 0;
601       shortBlocks = 0;
602    }
603
604    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
605    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
606    /* Compute MDCTs */
607    compute_mdcts(st->mode, shortBlocks, in, freq);
608    if (shortBlocks && !transient_shift) 
609    {
610       celt_word32_t sum[4]={1,1,1,1};
611       int m;
612       for (c=0;c<C;c++)
613       {
614          m=0;
615          do {
616             celt_word32_t tmp=0;
617             for (i=m*C+c;i<N;i+=C*st->mode->nbShortMdcts)
618                tmp += ABS32(freq[i]);
619             sum[m++] += tmp;
620          } while (m<st->mode->nbShortMdcts);
621       }
622       m=0;
623 #ifdef FIXED_POINT
624       do {
625          if (SHR32(sum[m+1],3) > sum[m])
626          {
627             mdct_weight_shift=2;
628             mdct_weight_pos = m;
629          } else if (SHR32(sum[m+1],1) > sum[m] && mdct_weight_shift < 2)
630          {
631             mdct_weight_shift=1;
632             mdct_weight_pos = m;
633          }
634          m++;
635       } while (m<st->mode->nbShortMdcts-1);
636       if (mdct_weight_shift)
637       {
638          for (c=0;c<C;c++)
639             for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
640                for (i=m*C+c;i<N;i+=C*st->mode->nbShortMdcts)
641                   freq[i] = SHR32(freq[i],mdct_weight_shift);
642       }
643 #else
644       do {
645          if (sum[m+1] > 8*sum[m])
646          {
647             mdct_weight_shift=2;
648             mdct_weight_pos = m;
649          } else if (sum[m+1] > 2*sum[m] && mdct_weight_shift < 2)
650          {
651             mdct_weight_shift=1;
652             mdct_weight_pos = m;
653          }
654          m++;
655       } while (m<st->mode->nbShortMdcts-1);
656       if (mdct_weight_shift)
657       {
658          for (c=0;c<C;c++)
659             for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
660                for (i=m*C+c;i<N;i+=C*st->mode->nbShortMdcts)
661                   freq[i] = (1./(1<<mdct_weight_shift))*freq[i];
662       }
663 #endif
664       /*printf ("%f\n", short_ratio);*/
665       /*if (short_ratio < 1)
666          short_ratio = 1;
667       short_ratio = 1<<(int)floor(.5+log2(short_ratio));
668       if (short_ratio>4)
669          short_ratio = 4;*/
670    }/* else if (transient_shift)
671       printf ("8\n");
672       else printf ("1\n");*/
673
674    compute_band_energies(st->mode, freq, bandE);
675
676    intra_ener = (st->force_intra || st->delayedIntra);
677    if (shortBlocks || intra_decision(bandE, st->oldBandE, st->mode->nbEBands))
678       st->delayedIntra = 1;
679    else
680       st->delayedIntra = 0;
681    /* Pitch analysis: we do it early to save on the peak stack space */
682    /* Don't use pitch if there isn't enough data available yet, or if we're using shortBlocks */
683    has_pitch = st->pitch_enabled && st->pitch_permitted && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks) && !intra_ener;
684 #ifdef EXP_PSY
685    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
686    {
687       VARDECL(celt_word16_t, X);
688       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
689       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
690       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
691    }
692 #else
693    if (has_pitch)
694    {
695       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
696    }
697 #endif
698
699 #ifdef EXP_PSY
700    ALLOC(mask, N, celt_sig_t);
701    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
702    /*for (i=0;i<256;i++)
703       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
704    printf ("\n");*/
705 #endif
706
707    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
708    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
709    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
710    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
711
712
713    /* Band normalisation */
714    normalise_bands(st->mode, freq, X, bandE);
715    if (!shortBlocks && !folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision))
716       has_fold = 0;
717 #ifdef EXP_PSY
718    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
719    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
720    compute_noise_energies(st->mode, freq, tonality, bandN);
721
722    /*for (i=0;i<st->mode->nbEBands;i++)
723       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
724    printf ("\n");*/
725    has_fold = 0;
726    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
727       if (bandN[i] < .4*bandE[i])
728          has_fold++;
729    /*printf ("%d\n", has_fold);*/
730    if (has_fold>=2)
731       has_fold = 0;
732    else
733       has_fold = 1;
734    for (i=0;i<N;i++)
735       mask[i] = sqrt(mask[i]);
736    compute_band_energies(st->mode, mask, bandM);
737    /*for (i=0;i<st->mode->nbEBands;i++)
738       printf ("%f %f ", bandE[i], bandM[i]);
739    printf ("\n");*/
740 #endif
741
742    /* Compute MDCTs of the pitch part */
743    if (has_pitch)
744    {
745       celt_word32_t curr_power, pitch_power=0;
746       /* Normalise the pitch vector as well (discard the energies) */
747       VARDECL(celt_ener_t, bandEp);
748       
749       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
750       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
751       compute_band_energies(st->mode, freq, bandEp);
752       normalise_bands(st->mode, freq, P, bandEp);
753       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
754       /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
755       curr_power = bandE[0]+bandE[1]+bandE[2];
756       if ((MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
757       {
758          /* Pitch prediction */
759          has_pitch = compute_pitch_gain(st->mode, X, P, gains);
760       } else {
761          has_pitch = 0;
762       }
763    }
764    
765    encode_flags(&enc, intra_ener, has_pitch, shortBlocks, has_fold);
766    if (has_pitch)
767    {
768       ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
769    } else {
770       for (i=0;i<st->mode->nbPBands;i++)
771          gains[i] = 0;
772       for (i=0;i<C*N;i++)
773          P[i] = 0;
774    }
775    if (shortBlocks)
776    {
777       if (transient_shift)
778       {
779          ec_enc_bits(&enc, transient_shift, 2);
780          ec_enc_uint(&enc, transient_time, N+st->overlap);
781       } else {
782          ec_enc_bits(&enc, mdct_weight_shift, 2);
783          if (mdct_weight_shift && st->mode->nbShortMdcts!=2)
784             ec_enc_uint(&enc, mdct_weight_pos, st->mode->nbShortMdcts-1);
785       }
786    }
787
788 #ifdef STDIN_TUNING2
789    static int fine_quant[30];
790    static int pulses[30];
791    static int init=0;
792    if (!init)
793    {
794       for (i=0;i<st->mode->nbEBands;i++)
795          scanf("%d ", &fine_quant[i]);
796       for (i=0;i<st->mode->nbEBands;i++)
797          scanf("%d ", &pulses[i]);
798       init = 1;
799    }
800 #else
801    ALLOC(fine_quant, st->mode->nbEBands, int);
802    ALLOC(pulses, st->mode->nbEBands, int);
803 #endif
804
805    /* Bit allocation */
806    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
807    coarse_needed = quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, intra_ener, st->mode->prob, error, &enc);
808    coarse_needed = ((coarse_needed*3-1)>>3)+1;
809
810    /* Variable bitrate */
811    if (st->VBR_rate>0)
812    {
813      /* The target rate in 16th bits per frame */
814      int target=st->VBR_rate;
815    
816      /* Shortblocks get a large boost in bitrate, but since they are uncommon long blocks are not greatly effected */
817      if (shortBlocks)
818        target*=2;
819      else if (st->mode->nbShortMdcts > 1)
820        target-=(target+14)/28;     
821
822      /*The average energy is removed from the target and the actual energy added*/
823      target=target-588+ec_enc_tell(&enc, 4);
824
825      /* In VBR mode the frame size must not be reduced so much that it would result in the coarse energy busting its budget */
826      target=IMAX(coarse_needed,(target+64)/128);
827      nbCompressedBytes=IMIN(nbCompressedBytes,target);
828    }
829
830    ALLOC(offsets, st->mode->nbEBands, int);
831
832    for (i=0;i<st->mode->nbEBands;i++)
833       offsets[i] = 0;
834    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
835    if (has_pitch)
836       bits -= st->mode->nbPBands;
837 #ifndef STDIN_TUNING
838    compute_allocation(st->mode, offsets, bits, pulses, fine_quant);
839 #endif
840
841    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
842
843    /* Residual quantisation */
844    if (C==1)
845       quant_bands(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
846 #ifndef DISABLE_STEREO
847    else
848       quant_bands_stereo(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
849 #endif
850    /* Re-synthesis of the coded audio if required */
851    if (st->pitch_available>0 || optional_synthesis!=NULL)
852    {
853       if (st->pitch_available>0 && st->pitch_available<MAX_PERIOD)
854         st->pitch_available+=st->frame_size;
855
856       /* Synthesis */
857       denormalise_bands(st->mode, X, freq, bandE);
858       
859       
860       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
861       
862       if (mdct_weight_shift)
863       {
864          int m;
865          for (c=0;c<C;c++)
866             for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
867                for (i=m*C+c;i<N;i+=C*st->mode->nbShortMdcts)
868 #ifdef FIXED_POINT
869                   freq[i] = SHL32(freq[i], mdct_weight_shift);
870 #else
871                   freq[i] = (1<<mdct_weight_shift)*freq[i];
872 #endif
873       }
874       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
875       /* De-emphasis and put everything back at the right place in the synthesis history */
876       if (optional_synthesis != NULL) {
877          for (c=0;c<C;c++)
878          {
879             int j;
880             for (j=0;j<N;j++)
881             {
882                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
883                                    preemph,st->preemph_memD[c]);
884                st->preemph_memD[c] = tmp;
885                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
886             }
887          }
888       }
889    }
890
891    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
892    /*if (ec_enc_tell(&enc, 0) < nbCompressedBytes*8 - 7)
893       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
894
895    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
896    {
897       int val = 0;
898       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
899       {
900          ec_enc_uint(&enc, val, 2);
901          val = 1-val;
902       }
903    }
904    ec_enc_done(&enc);
905    {
906       /*unsigned char *data;*/
907       int nbBytes = ec_byte_bytes(&buf);
908       if (nbBytes > nbCompressedBytes)
909       {
910          celt_warning_int ("got too many bytes:", nbBytes);
911          RESTORE_STACK;
912          return CELT_INTERNAL_ERROR;
913       }
914    }
915
916    RESTORE_STACK;
917    return nbCompressedBytes;
918 }
919
920 #ifdef FIXED_POINT
921 #ifndef DISABLE_FLOAT_API
922 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
923 {
924    int j, ret, C, N;
925    VARDECL(celt_int16_t, in);
926
927    if (check_encoder(st) != CELT_OK)
928       return CELT_INVALID_STATE;
929
930    if (check_mode(st->mode) != CELT_OK)
931       return CELT_INVALID_MODE;
932
933    SAVE_STACK;
934    C = CHANNELS(st->mode);
935    N = st->block_size;
936    ALLOC(in, C*N, celt_int16_t);
937
938    for (j=0;j<C*N;j++)
939      in[j] = FLOAT2INT16(pcm[j]);
940
941    if (optional_synthesis != NULL) {
942      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
943       for (j=0;j<C*N;j++)
944          optional_synthesis[j]=in[j]*(1/32768.);
945    } else {
946      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
947    }
948    RESTORE_STACK;
949    return ret;
950
951 }
952 #endif /*DISABLE_FLOAT_API*/
953 #else
954 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
955 {
956    int j, ret, C, N;
957    VARDECL(celt_sig_t, in);
958
959    if (check_encoder(st) != CELT_OK)
960       return CELT_INVALID_STATE;
961
962    if (check_mode(st->mode) != CELT_OK)
963       return CELT_INVALID_MODE;
964
965    SAVE_STACK;
966    C=CHANNELS(st->mode);
967    N=st->block_size;
968    ALLOC(in, C*N, celt_sig_t);
969    for (j=0;j<C*N;j++) {
970      in[j] = SCALEOUT(pcm[j]);
971    }
972
973    if (optional_synthesis != NULL) {
974       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
975       for (j=0;j<C*N;j++)
976          optional_synthesis[j] = FLOAT2INT16(in[j]);
977    } else {
978       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
979    }
980    RESTORE_STACK;
981    return ret;
982 }
983 #endif
984
985 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
986 {
987    va_list ap;
988    
989    if (check_encoder(st) != CELT_OK)
990       return CELT_INVALID_STATE;
991
992    va_start(ap, request);
993    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
994      goto bad_mode;
995    switch (request)
996    {
997       case CELT_GET_MODE_REQUEST:
998       {
999          const CELTMode ** value = va_arg(ap, const CELTMode**);
1000          if (value==0)
1001             goto bad_arg;
1002          *value=st->mode;
1003       }
1004       break;
1005       case CELT_SET_COMPLEXITY_REQUEST:
1006       {
1007          int value = va_arg(ap, celt_int32_t);
1008          if (value<0 || value>10)
1009             goto bad_arg;
1010          if (value<=2) {
1011             st->pitch_enabled = 0; 
1012             st->pitch_available = 0;
1013          } else {
1014               st->pitch_enabled = 1;
1015               if (st->pitch_available<1)
1016                 st->pitch_available = 1;
1017          }   
1018       }
1019       break;
1020       case CELT_SET_PREDICTION_REQUEST:
1021       {
1022          int value = va_arg(ap, celt_int32_t);
1023          if (value<0 || value>2)
1024             goto bad_arg;
1025          if (value==0)
1026          {
1027             st->force_intra   = 1;
1028             st->pitch_permitted = 0;
1029          } else if (value=1) {
1030             st->force_intra   = 0;
1031             st->pitch_permitted = 0;
1032          } else {
1033             st->force_intra   = 0;
1034             st->pitch_permitted = 1;
1035          }   
1036       }
1037       break;
1038       case CELT_SET_VBR_RATE_REQUEST:
1039       {
1040          int value = va_arg(ap, celt_int32_t);
1041          if (value<0)
1042             goto bad_arg;
1043          if (value>3072000)
1044             value = 3072000;
1045          st->VBR_rate = ((st->mode->Fs<<3)+(st->block_size>>1))/st->block_size;
1046          st->VBR_rate = ((value<<7)+(st->VBR_rate>>1))/st->VBR_rate;
1047       }
1048       break;
1049       case CELT_RESET_STATE:
1050       {
1051          const CELTMode *mode = st->mode;
1052          int C = mode->nbChannels;
1053
1054          if (st->pitch_available > 0) st->pitch_available = 1;
1055
1056          CELT_MEMSET(st->in_mem, 0, st->overlap*C);
1057          CELT_MEMSET(st->out_mem, 0, (MAX_PERIOD+st->overlap)*C);
1058
1059          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
1060
1061          CELT_MEMSET(st->preemph_memE, 0, C);
1062          CELT_MEMSET(st->preemph_memD, 0, C);
1063          st->delayedIntra = 1;
1064       }
1065       break;
1066       default:
1067          goto bad_request;
1068    }
1069    va_end(ap);
1070    return CELT_OK;
1071 bad_mode:
1072   va_end(ap);
1073   return CELT_INVALID_MODE;
1074 bad_arg:
1075    va_end(ap);
1076    return CELT_BAD_ARG;
1077 bad_request:
1078    va_end(ap);
1079    return CELT_UNIMPLEMENTED;
1080 }
1081
1082 /****************************************************************************/
1083 /*                                                                          */
1084 /*                                DECODER                                   */
1085 /*                                                                          */
1086 /****************************************************************************/
1087 #ifdef NEW_PLC
1088 #define DECODE_BUFFER_SIZE 2048
1089 #else
1090 #define DECODE_BUFFER_SIZE MAX_PERIOD
1091 #endif
1092
1093 #define DECODERVALID   0x4c434454
1094 #define DECODERPARTIAL 0x5444434c
1095 #define DECODERFREED   0x4c004400
1096
1097 /** Decoder state 
1098  @brief Decoder state
1099  */
1100 struct CELTDecoder {
1101    celt_uint32_t marker;
1102    const CELTMode *mode;
1103    int frame_size;
1104    int block_size;
1105    int overlap;
1106
1107    ec_byte_buffer buf;
1108    ec_enc         enc;
1109
1110    celt_sig_t * restrict preemph_memD;
1111
1112    celt_sig_t *out_mem;
1113    celt_sig_t *decode_mem;
1114
1115    celt_word16_t *oldBandE;
1116    
1117    int last_pitch_index;
1118 };
1119
1120 int check_decoder(const CELTDecoder *st) 
1121 {
1122    if (st==NULL)
1123    {
1124       celt_warning("NULL passed a decoder structure");  
1125       return CELT_INVALID_STATE;
1126    }
1127    if (st->marker == DECODERVALID)
1128       return CELT_OK;
1129    if (st->marker == DECODERFREED)
1130       celt_warning("Referencing a decoder that has already been freed");
1131    else
1132       celt_warning("This is not a valid CELT decoder structure");
1133    return CELT_INVALID_STATE;
1134 }
1135
1136 CELTDecoder *celt_decoder_create(const CELTMode *mode)
1137 {
1138    int N, C;
1139    CELTDecoder *st;
1140
1141    if (check_mode(mode) != CELT_OK)
1142       return NULL;
1143
1144    N = mode->mdctSize;
1145    C = CHANNELS(mode);
1146    st = celt_alloc(sizeof(CELTDecoder));
1147
1148    if (st==NULL)
1149       return NULL;
1150    
1151    st->marker = DECODERPARTIAL;
1152    st->mode = mode;
1153    st->frame_size = N;
1154    st->block_size = N;
1155    st->overlap = mode->overlap;
1156
1157    st->decode_mem = celt_alloc((DECODE_BUFFER_SIZE+st->overlap)*C*sizeof(celt_sig_t));
1158    st->out_mem = st->decode_mem+DECODE_BUFFER_SIZE-MAX_PERIOD;
1159    
1160    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
1161    
1162    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
1163
1164    st->last_pitch_index = 0;
1165
1166    if ((st->decode_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL) &&
1167        (st->preemph_memD!=NULL))
1168    {
1169       st->marker = DECODERVALID;
1170       return st;
1171    }
1172    /* If the setup fails for some reason deallocate it. */
1173    celt_decoder_destroy(st);
1174    return NULL;
1175 }
1176
1177 void celt_decoder_destroy(CELTDecoder *st)
1178 {
1179    if (st == NULL)
1180    {
1181       celt_warning("NULL passed to celt_decoder_destroy");
1182       return;
1183    }
1184
1185    if (st->marker == DECODERFREED) 
1186    {
1187       celt_warning("Freeing a decoder which has already been freed"); 
1188       return;
1189    }
1190    
1191    if (st->marker != DECODERVALID && st->marker != DECODERPARTIAL)
1192    {
1193       celt_warning("This is not a valid CELT decoder structure");
1194       return;
1195    }
1196    
1197    /*Check_mode is non-fatal here because we can still free
1198      the encoder memory even if the mode is bad, although calling
1199      the free functions in this order is a violation of the API.*/
1200    check_mode(st->mode);
1201    
1202    celt_free(st->decode_mem);
1203    celt_free(st->oldBandE);
1204    celt_free(st->preemph_memD);
1205    
1206    st->marker = DECODERFREED;
1207    
1208    celt_free(st);
1209 }
1210
1211 /** Handles lost packets by just copying past data with the same offset as the last
1212     pitch period */
1213 #ifdef NEW_PLC
1214 #include "plc.c"
1215 #else
1216 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
1217 {
1218    int c, N;
1219    int pitch_index;
1220    int i, len;
1221    VARDECL(celt_sig_t, freq);
1222    const int C = CHANNELS(st->mode);
1223    int offset;
1224    SAVE_STACK;
1225    N = st->block_size;
1226    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
1227    
1228    len = N+st->mode->overlap;
1229 #if 0
1230    pitch_index = st->last_pitch_index;
1231    
1232    /* Use the pitch MDCT as the "guessed" signal */
1233    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
1234
1235 #else
1236    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
1237    pitch_index = MAX_PERIOD-len-pitch_index;
1238    offset = MAX_PERIOD-pitch_index;
1239    while (offset+len >= MAX_PERIOD)
1240       offset -= pitch_index;
1241    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
1242    for (i=0;i<N;i++)
1243       freq[i] = ADD32(EPSILON, MULT16_32_Q15(QCONST16(.9f,15),freq[i]));
1244 #endif
1245    
1246    
1247    
1248    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
1249    /* Compute inverse MDCTs */
1250    compute_inv_mdcts(st->mode, 0, freq, -1, 0, st->out_mem);
1251
1252    for (c=0;c<C;c++)
1253    {
1254       int j;
1255       for (j=0;j<N;j++)
1256       {
1257          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1258                                 preemph,st->preemph_memD[c]);
1259          st->preemph_memD[c] = tmp;
1260          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1261       }
1262    }
1263    RESTORE_STACK;
1264 }
1265 #endif
1266
1267 #ifdef FIXED_POINT
1268 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
1269 {
1270 #else
1271 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig_t * restrict pcm)
1272 {
1273 #endif
1274    int i, c, N, N4;
1275    int has_pitch, has_fold;
1276    int pitch_index;
1277    int bits;
1278    ec_dec dec;
1279    ec_byte_buffer buf;
1280    VARDECL(celt_sig_t, freq);
1281    VARDECL(celt_norm_t, X);
1282    VARDECL(celt_norm_t, P);
1283    VARDECL(celt_ener_t, bandE);
1284    VARDECL(celt_pgain_t, gains);
1285    VARDECL(int, fine_quant);
1286    VARDECL(int, pulses);
1287    VARDECL(int, offsets);
1288
1289    int shortBlocks;
1290    int intra_ener;
1291    int transient_time;
1292    int transient_shift;
1293    int mdct_weight_shift=0;
1294    const int C = CHANNELS(st->mode);
1295    int mdct_weight_pos=0;
1296    SAVE_STACK;
1297
1298    if (check_decoder(st) != CELT_OK)
1299       return CELT_INVALID_STATE;
1300
1301    if (check_mode(st->mode) != CELT_OK)
1302       return CELT_INVALID_MODE;
1303
1304    N = st->block_size;
1305    N4 = (N-st->overlap)>>1;
1306
1307    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
1308    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
1309    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
1310    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
1311    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
1312    
1313    if (data == NULL)
1314    {
1315       celt_decode_lost(st, pcm);
1316       RESTORE_STACK;
1317       return 0;
1318    }
1319    if (len<0) {
1320      RESTORE_STACK;
1321      return CELT_BAD_ARG;
1322    }
1323    
1324    ec_byte_readinit(&buf,(unsigned char*)data,len);
1325    ec_dec_init(&dec,&buf);
1326    
1327    decode_flags(&dec, &intra_ener, &has_pitch, &shortBlocks, &has_fold);
1328    if (shortBlocks)
1329    {
1330       transient_shift = ec_dec_bits(&dec, 2);
1331       if (transient_shift == 3)
1332       {
1333          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
1334       } else {
1335          mdct_weight_shift = transient_shift;
1336          if (mdct_weight_shift && st->mode->nbShortMdcts>2)
1337             mdct_weight_pos = ec_dec_uint(&dec, st->mode->nbShortMdcts-1);
1338          transient_shift = 0;
1339          transient_time = 0;
1340       }
1341    } else {
1342       transient_time = -1;
1343       transient_shift = 0;
1344    }
1345    
1346    if (has_pitch)
1347    {
1348       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
1349       st->last_pitch_index = pitch_index;
1350    } else {
1351       pitch_index = 0;
1352       for (i=0;i<st->mode->nbPBands;i++)
1353          gains[i] = 0;
1354    }
1355
1356    ALLOC(fine_quant, st->mode->nbEBands, int);
1357    /* Get band energies */
1358    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, intra_ener, st->mode->prob, &dec);
1359    
1360    ALLOC(pulses, st->mode->nbEBands, int);
1361    ALLOC(offsets, st->mode->nbEBands, int);
1362
1363    for (i=0;i<st->mode->nbEBands;i++)
1364       offsets[i] = 0;
1365
1366    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1367    if (has_pitch)
1368       bits -= st->mode->nbPBands;
1369    compute_allocation(st->mode, offsets, bits, pulses, fine_quant);
1370    /*bits = ec_dec_tell(&dec, 0);
1371    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1372    
1373    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1374
1375
1376    if (has_pitch) 
1377    {
1378       VARDECL(celt_ener_t, bandEp);
1379       
1380       /* Pitch MDCT */
1381       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1382       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1383       compute_band_energies(st->mode, freq, bandEp);
1384       normalise_bands(st->mode, freq, P, bandEp);
1385       /* Apply pitch gains */
1386    } else {
1387       for (i=0;i<C*N;i++)
1388          P[i] = 0;
1389    }
1390
1391    /* Decode fixed codebook and merge with pitch */
1392    if (C==1)
1393       unquant_bands(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1394 #ifndef DISABLE_STEREO
1395    else
1396       unquant_bands_stereo(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1397 #endif
1398    /* Synthesis */
1399    denormalise_bands(st->mode, X, freq, bandE);
1400
1401
1402    CELT_MOVE(st->decode_mem, st->decode_mem+C*N, C*(DECODE_BUFFER_SIZE+st->overlap-N));
1403    if (mdct_weight_shift)
1404    {
1405       int m;
1406       for (c=0;c<C;c++)
1407          for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
1408             for (i=m*C+c;i<N;i+=C*st->mode->nbShortMdcts)
1409 #ifdef FIXED_POINT
1410                freq[i] = SHL32(freq[i], mdct_weight_shift);
1411 #else
1412                freq[i] = (1<<mdct_weight_shift)*freq[i];
1413 #endif
1414    }
1415    /* Compute inverse MDCTs */
1416    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1417
1418    for (c=0;c<C;c++)
1419    {
1420       int j;
1421       for (j=0;j<N;j++)
1422       {
1423          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1424                                 preemph,st->preemph_memD[c]);
1425          st->preemph_memD[c] = tmp;
1426          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1427       }
1428    }
1429
1430    {
1431       unsigned int val = 0;
1432       while (ec_dec_tell(&dec, 0) < len*8)
1433       {
1434          if (ec_dec_uint(&dec, 2) != val)
1435          {
1436             celt_warning("decode error");
1437             RESTORE_STACK;
1438             return CELT_CORRUPTED_DATA;
1439          }
1440          val = 1-val;
1441       }
1442    }
1443
1444    RESTORE_STACK;
1445    return 0;
1446    /*printf ("\n");*/
1447 }
1448
1449 #ifdef FIXED_POINT
1450 #ifndef DISABLE_FLOAT_API
1451 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm)
1452 {
1453    int j, ret, C, N;
1454    VARDECL(celt_int16_t, out);
1455
1456    if (check_decoder(st) != CELT_OK)
1457       return CELT_INVALID_STATE;
1458
1459    if (check_mode(st->mode) != CELT_OK)
1460       return CELT_INVALID_MODE;
1461
1462    SAVE_STACK;
1463    C = CHANNELS(st->mode);
1464    N = st->block_size;
1465    ALLOC(out, C*N, celt_int16_t);
1466
1467    ret=celt_decode(st, data, len, out);
1468
1469    for (j=0;j<C*N;j++)
1470      pcm[j]=out[j]*(1/32768.);
1471    RESTORE_STACK;
1472    return ret;
1473 }
1474 #endif /*DISABLE_FLOAT_API*/
1475 #else
1476 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
1477 {
1478    int j, ret, C, N;
1479    VARDECL(celt_sig_t, out);
1480
1481    if (check_decoder(st) != CELT_OK)
1482       return CELT_INVALID_STATE;
1483
1484    if (check_mode(st->mode) != CELT_OK)
1485       return CELT_INVALID_MODE;
1486
1487    SAVE_STACK;
1488    C = CHANNELS(st->mode);
1489    N = st->block_size;
1490    ALLOC(out, C*N, celt_sig_t);
1491
1492    ret=celt_decode_float(st, data, len, out);
1493
1494    for (j=0;j<C*N;j++)
1495      pcm[j] = FLOAT2INT16 (out[j]);
1496
1497    RESTORE_STACK;
1498    return ret;
1499 }
1500 #endif
1501
1502 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
1503 {
1504    va_list ap;
1505
1506    if (check_decoder(st) != CELT_OK)
1507       return CELT_INVALID_STATE;
1508
1509    va_start(ap, request);
1510    if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
1511      goto bad_mode;
1512    switch (request)
1513    {
1514       case CELT_GET_MODE_REQUEST:
1515       {
1516          const CELTMode ** value = va_arg(ap, const CELTMode**);
1517          if (value==0)
1518             goto bad_arg;
1519          *value=st->mode;
1520       }
1521       break;
1522       case CELT_RESET_STATE:
1523       {
1524          const CELTMode *mode = st->mode;
1525          int C = mode->nbChannels;
1526
1527          CELT_MEMSET(st->decode_mem, 0, (DECODE_BUFFER_SIZE+st->overlap)*C);
1528          CELT_MEMSET(st->oldBandE, 0, C*mode->nbEBands);
1529
1530          CELT_MEMSET(st->preemph_memD, 0, C);
1531
1532          st->last_pitch_index = 0;
1533       }
1534       break;
1535       default:
1536          goto bad_request;
1537    }
1538    va_end(ap);
1539    return CELT_OK;
1540 bad_mode:
1541   va_end(ap);
1542   return CELT_INVALID_MODE;
1543 bad_arg:
1544    va_end(ap);
1545    return CELT_BAD_ARG;
1546 bad_request:
1547       va_end(ap);
1548   return CELT_UNIMPLEMENTED;
1549 }