fix stack handling
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53 #include "float_cast.h"
54
55 static const celt_word16_t preemph = QCONST16(0.8f,15);
56
57 #ifdef FIXED_POINT
58 static const celt_word16_t transientWindow[16] = {
59      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
60    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
61 #else
62 static const float transientWindow[16] = {
63    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
64    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
65 #endif
66
67    
68 /** Encoder state 
69  @brief Encoder state
70  */
71 struct CELTEncoder {
72    const CELTMode *mode;     /**< Mode used by the encoder */
73    int frame_size;
74    int block_size;
75    int overlap;
76    int channels;
77    
78    int pitch_enabled;
79    
80    ec_byte_buffer buf;
81    ec_enc         enc;
82
83    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
84    celt_sig_t    * restrict preemph_memD;
85
86    celt_sig_t *in_mem;
87    celt_sig_t *out_mem;
88
89    celt_word16_t *oldBandE;
90 #ifdef EXP_PSY
91    celt_word16_t *psy_mem;
92    struct PsyDecay psy;
93 #endif
94 };
95
96 CELTEncoder *celt_encoder_create(const CELTMode *mode)
97 {
98    int N, C;
99    CELTEncoder *st;
100
101    if (check_mode(mode) != CELT_OK)
102       return NULL;
103
104    N = mode->mdctSize;
105    C = mode->nbChannels;
106    st = celt_alloc(sizeof(CELTEncoder));
107    
108    st->mode = mode;
109    st->frame_size = N;
110    st->block_size = N;
111    st->overlap = mode->overlap;
112
113    st->pitch_enabled = 1;
114    
115    ec_byte_writeinit(&st->buf);
116    ec_enc_init(&st->enc,&st->buf);
117
118    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
119    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
120
121    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
122
123    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
124    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
125
126 #ifdef EXP_PSY
127    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
128    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
129 #endif
130
131    return st;
132 }
133
134 void celt_encoder_destroy(CELTEncoder *st)
135 {
136    if (st == NULL)
137    {
138       celt_warning("NULL passed to celt_encoder_destroy");
139       return;
140    }
141    if (check_mode(st->mode) != CELT_OK)
142       return;
143
144    ec_byte_writeclear(&st->buf);
145
146    celt_free(st->in_mem);
147    celt_free(st->out_mem);
148    
149    celt_free(st->oldBandE);
150    
151    celt_free(st->preemph_memE);
152    celt_free(st->preemph_memD);
153    
154 #ifdef EXP_PSY
155    celt_free (st->psy_mem);
156    psydecay_clear(&st->psy);
157 #endif
158    
159    celt_free(st);
160 }
161
162 static inline celt_int16_t FLOAT2INT16(float x)
163 {
164    x = x*32768.;
165    x = MAX32(x, -32768);
166    x = MIN32(x, 32767);
167    return (celt_int16_t)float2int(x);
168 }
169
170 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
171 {
172 #ifdef FIXED_POINT
173    x = PSHR32(x, SIG_SHIFT);
174    x = MAX32(x, -32768);
175    x = MIN32(x, 32767);
176    return EXTRACT16(x);
177 #else
178    return (celt_word16_t)x;
179 #endif
180 }
181
182 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
183 {
184    int c, i, n;
185    celt_word32_t ratio;
186    /* FIXME: Remove the floats here */
187    VARDECL(celt_word32_t, begin);
188    SAVE_STACK;
189    ALLOC(begin, len, celt_word32_t);
190    for (i=0;i<len;i++)
191       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
192    for (c=1;c<C;c++)
193    {
194       for (i=0;i<len;i++)
195          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
196    }
197    for (i=1;i<len;i++)
198       begin[i] = MAX32(begin[i-1],begin[i]);
199    n = -1;
200    for (i=8;i<len-8;i++)
201    {
202       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
203          n=i;
204    }
205    if (n<32)
206    {
207       n = -1;
208       ratio = 0;
209    } else {
210       ratio = DIV32(begin[len-1],1+begin[n-16]);
211    }
212    /*printf ("%d %f\n", n, ratio*ratio);*/
213    if (ratio < 0)
214       ratio = 0;
215    if (ratio > 1000)
216       ratio = 1000;
217    ratio *= ratio;
218    if (ratio < 50)
219       *transient_shift = 0;
220    else if (ratio < 256)
221       *transient_shift = 1;
222    else if (ratio < 4096)
223       *transient_shift = 2;
224    else
225       *transient_shift = 3;
226    *transient_time = n;
227    
228    RESTORE_STACK;
229    return ratio > 20;
230 }
231
232 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
233 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
234 {
235    const int C = CHANNELS(mode);
236    if (C==1 && !shortBlocks)
237    {
238       const mdct_lookup *lookup = MDCT(mode);
239       const int overlap = OVERLAP(mode);
240       mdct_forward(lookup, in, out, mode->window, overlap);
241    } else if (!shortBlocks) {
242       const mdct_lookup *lookup = MDCT(mode);
243       const int overlap = OVERLAP(mode);
244       const int N = FRAMESIZE(mode);
245       int c;
246       VARDECL(celt_word32_t, x);
247       VARDECL(celt_word32_t, tmp);
248       SAVE_STACK;
249       ALLOC(x, N+overlap, celt_word32_t);
250       ALLOC(tmp, N, celt_word32_t);
251       for (c=0;c<C;c++)
252       {
253          int j;
254          for (j=0;j<N+overlap;j++)
255             x[j] = in[C*j+c];
256          mdct_forward(lookup, x, tmp, mode->window, overlap);
257          /* Interleaving the sub-frames */
258          for (j=0;j<N;j++)
259             out[C*j+c] = tmp[j];
260       }
261       RESTORE_STACK;
262    } else {
263       const mdct_lookup *lookup = &mode->shortMdct;
264       const int overlap = mode->overlap;
265       const int N = mode->shortMdctSize;
266       int b, c;
267       VARDECL(celt_word32_t, x);
268       VARDECL(celt_word32_t, tmp);
269       SAVE_STACK;
270       ALLOC(x, N+overlap, celt_word32_t);
271       ALLOC(tmp, N, celt_word32_t);
272       for (c=0;c<C;c++)
273       {
274          int B = mode->nbShortMdcts;
275          for (b=0;b<B;b++)
276          {
277             int j;
278             for (j=0;j<N+overlap;j++)
279                x[j] = in[C*(b*N+j)+c];
280             mdct_forward(lookup, x, tmp, mode->window, overlap);
281             /* Interleaving the sub-frames */
282             for (j=0;j<N;j++)
283                out[C*(j*B+b)+c] = tmp[j];
284          }
285       }
286       RESTORE_STACK;
287    }
288 }
289
290 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
291 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
292 {
293    int c, N4;
294    const int C = CHANNELS(mode);
295    const int N = FRAMESIZE(mode);
296    const int overlap = OVERLAP(mode);
297    N4 = (N-overlap)>>1;
298    for (c=0;c<C;c++)
299    {
300       int j;
301       if (transient_shift==0 && C==1 && !shortBlocks) {
302          const mdct_lookup *lookup = MDCT(mode);
303          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
304       } else if (!shortBlocks) {
305          const mdct_lookup *lookup = MDCT(mode);
306          VARDECL(celt_word32_t, x);
307          VARDECL(celt_word32_t, tmp);
308          SAVE_STACK;
309          ALLOC(x, 2*N, celt_word32_t);
310          ALLOC(tmp, N, celt_word32_t);
311          /* De-interleaving the sub-frames */
312          for (j=0;j<N;j++)
313             tmp[j] = X[C*j+c];
314          /* Prevents problems from the imdct doing the overlap-add */
315          CELT_MEMSET(x+N4, 0, N);
316          mdct_backward(lookup, tmp, x, mode->window, overlap);
317          celt_assert(transient_shift == 0);
318          /* The first and last part would need to be set to zero if we actually
319             wanted to use them. */
320          for (j=0;j<overlap;j++)
321             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
322          for (j=0;j<overlap;j++)
323             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
324          for (j=0;j<2*N4;j++)
325             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
326          RESTORE_STACK;
327       } else {
328          int b;
329          const int N2 = mode->shortMdctSize;
330          const int B = mode->nbShortMdcts;
331          const mdct_lookup *lookup = &mode->shortMdct;
332          VARDECL(celt_word32_t, x);
333          VARDECL(celt_word32_t, tmp);
334          SAVE_STACK;
335          ALLOC(x, 2*N, celt_word32_t);
336          ALLOC(tmp, N, celt_word32_t);
337          /* Prevents problems from the imdct doing the overlap-add */
338          CELT_MEMSET(x+N4, 0, N2);
339          for (b=0;b<B;b++)
340          {
341             /* De-interleaving the sub-frames */
342             for (j=0;j<N2;j++)
343                tmp[j] = X[C*(j*B+b)+c];
344             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
345          }
346          if (transient_shift > 0)
347          {
348 #ifdef FIXED_POINT
349             for (j=0;j<16;j++)
350                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
351             for (j=transient_time;j<N+overlap;j++)
352                x[N4+j] = SHL32(x[N4+j], transient_shift);
353 #else
354             for (j=0;j<16;j++)
355                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
356             for (j=transient_time;j<N+overlap;j++)
357                x[N4+j] *= 1<<transient_shift;
358 #endif
359          }
360          /* The first and last part would need to be set to zero if we actually
361          wanted to use them. */
362          for (j=0;j<overlap;j++)
363             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
364          for (j=0;j<overlap;j++)
365             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
366          for (j=0;j<2*N4;j++)
367             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
368          RESTORE_STACK;
369       }
370    }
371 }
372
373 #ifdef FIXED_POINT
374 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
375 {
376 #else
377 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
378 {
379 #endif
380    int i, c, N, N4;
381    int has_pitch;
382    int pitch_index;
383    int bits;
384    celt_word32_t curr_power, pitch_power=0;
385    VARDECL(celt_sig_t, in);
386    VARDECL(celt_sig_t, freq);
387    VARDECL(celt_norm_t, X);
388    VARDECL(celt_norm_t, P);
389    VARDECL(celt_ener_t, bandE);
390    VARDECL(celt_pgain_t, gains);
391    VARDECL(int, stereo_mode);
392    VARDECL(int, fine_quant);
393    VARDECL(celt_word16_t, error);
394    VARDECL(int, pulses);
395    VARDECL(int, offsets);
396 #ifdef EXP_PSY
397    VARDECL(celt_word32_t, mask);
398 #endif
399    int shortBlocks=0;
400    int transient_time;
401    int transient_shift;
402    const int C = CHANNELS(st->mode);
403    SAVE_STACK;
404
405    if (check_mode(st->mode) != CELT_OK)
406       return CELT_INVALID_MODE;
407
408    N = st->block_size;
409    N4 = (N-st->overlap)>>1;
410    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
411
412    CELT_COPY(in, st->in_mem, C*st->overlap);
413    for (c=0;c<C;c++)
414    {
415       const celt_word16_t * restrict pcmp = pcm+c;
416       celt_sig_t * restrict inp = in+C*st->overlap+c;
417       for (i=0;i<N;i++)
418       {
419          /* Apply pre-emphasis */
420          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
421          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
422          st->preemph_memE[c] = SCALEIN(*pcmp);
423          inp += C;
424          pcmp += C;
425       }
426    }
427    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
428    
429    if (st->mode->nbShortMdcts > 1)
430    {
431       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
432       {
433 #ifndef FIXED_POINT
434          float gain_1;
435 #endif
436          ec_enc_bits(&st->enc, 0, 1); //Pitch off
437          ec_enc_bits(&st->enc, 1, 1); //Transient on
438          ec_enc_bits(&st->enc, transient_shift, 2);
439          if (transient_shift)
440             ec_enc_uint(&st->enc, transient_time, N+st->overlap);
441          if (transient_shift)
442          {
443 #ifdef FIXED_POINT
444             for (c=0;c<C;c++)
445                for (i=0;i<16;i++)
446                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
447             for (c=0;c<C;c++)
448                for (i=transient_time;i<N+st->overlap;i++)
449                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
450 #else
451             for (c=0;c<C;c++)
452                for (i=0;i<16;i++)
453                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
454             gain_1 = 1./(1<<transient_shift);
455             for (c=0;c<C;c++)
456                for (i=transient_time;i<N+st->overlap;i++)
457                   in[C*i+c] *= gain_1;
458 #endif
459          }
460          shortBlocks = 1;
461       } else {
462          transient_time = -1;
463          transient_shift = 0;
464          shortBlocks = 0;
465       }
466    } else {
467       transient_time = -1;
468       transient_shift = 0;
469       shortBlocks = 0;
470    }
471    /* Pitch analysis: we do it early to save on the peak stack space */
472    if (st->pitch_enabled && !shortBlocks)
473       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
474
475    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
476    
477    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
478    /* Compute MDCTs */
479    compute_mdcts(st->mode, shortBlocks, in, freq);
480
481 #ifdef EXP_PSY
482    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
483    for (i=0;i<N;i++)
484       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
485    for (c=1;c<C;c++)
486       for (i=0;i<N;i++)
487          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
488
489    ALLOC(mask, N, celt_sig_t);
490    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
491
492    /* Invert and stretch the mask to length of X 
493       For some reason, I get better results by using the sqrt instead,
494       although there's no valid reason to. Must investigate further */
495    for (i=0;i<C*N;i++)
496       mask[i] = 1/(.1+mask[i]);
497 #endif
498    
499    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
500    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
501    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
502    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
503    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
504
505    /*printf ("%f %f\n", curr_power, pitch_power);*/
506    /*int j;
507    for (j=0;j<B*N;j++)
508       printf ("%f ", X[j]);
509    for (j=0;j<B*N;j++)
510       printf ("%f ", P[j]);
511    printf ("\n");*/
512
513    /* Band normalisation */
514    compute_band_energies(st->mode, freq, bandE);
515    normalise_bands(st->mode, freq, X, bandE);
516    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
517    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
518
519    /* Compute MDCTs of the pitch part */
520    if (st->pitch_enabled && !shortBlocks)
521    {
522       /* Normalise the pitch vector as well (discard the energies) */
523       VARDECL(celt_ener_t, bandEp);
524       
525       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
526       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
527       compute_band_energies(st->mode, freq, bandEp);
528       normalise_bands(st->mode, freq, P, bandEp);
529       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
530    }
531    curr_power = bandE[0]+bandE[1]+bandE[2];
532    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
533    if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
534    {
535       /* Simulates intensity stereo */
536       /*for (i=30;i<N*B;i++)
537          X[i*C+1] = P[i*C+1] = 0;*/
538
539       /* Pitch prediction */
540       compute_pitch_gain(st->mode, X, P, gains);
541       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
542       if (has_pitch)
543          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
544       else if (st->mode->nbShortMdcts > 1)
545          ec_enc_bits(&st->enc, 0, 1); //Transient off
546    } else {
547       if (!shortBlocks)
548       {
549          ec_enc_bits(&st->enc, 0, 1); //Pitch off
550          if (st->mode->nbShortMdcts > 1)
551            ec_enc_bits(&st->enc, 0, 1); //Transient off
552       }
553       /* No pitch, so we just pretend we found a gain of zero */
554       for (i=0;i<st->mode->nbPBands;i++)
555          gains[i] = 0;
556       for (i=0;i<C*N;i++)
557          P[i] = 0;
558    }
559
560 #ifdef STDIN_TUNING2
561    static int fine_quant[30];
562    static int pulses[30];
563    static int init=0;
564    if (!init)
565    {
566       for (i=0;i<st->mode->nbEBands;i++)
567          scanf("%d ", &fine_quant[i]);
568       for (i=0;i<st->mode->nbEBands;i++)
569          scanf("%d ", &pulses[i]);
570       init = 1;
571    }
572 #else
573    ALLOC(fine_quant, st->mode->nbEBands, int);
574    ALLOC(pulses, st->mode->nbEBands, int);
575 #endif
576    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
577    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &st->enc);
578    
579    ALLOC(offsets, st->mode->nbEBands, int);
580    ALLOC(stereo_mode, st->mode->nbEBands, int);
581    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
582
583    for (i=0;i<st->mode->nbEBands;i++)
584       offsets[i] = 0;
585    bits = nbCompressedBytes*8 - ec_enc_tell(&st->enc, 0) - 1;
586 #ifndef STDIN_TUNING
587    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
588 #endif
589    /*for (i=0;i<st->mode->nbEBands;i++)
590       printf("%d ", fine_quant[i]);
591    for (i=0;i<st->mode->nbEBands;i++)
592       printf("%d ", pulses[i]);
593    printf ("\n");*/
594    /*bits = ec_enc_tell(&st->enc, 0);
595    compute_fine_allocation(st->mode, fine_quant, (20*C+nbCompressedBytes*8/5-(ec_enc_tell(&st->enc, 0)-bits))/C);*/
596    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &st->enc);
597
598    pitch_quant_bands(st->mode, P, gains);
599
600    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
601
602    /* Residual quantisation */
603    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, nbCompressedBytes*8, &st->enc);
604    
605    if (st->pitch_enabled || optional_synthesis!=NULL)
606    {
607       if (C==2)
608          renormalise_bands(st->mode, X);
609       /* Synthesis */
610       denormalise_bands(st->mode, X, freq, bandE);
611       
612       
613       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
614       
615       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
616       /* De-emphasis and put everything back at the right place in the synthesis history */
617       if (optional_synthesis != NULL) {
618          for (c=0;c<C;c++)
619          {
620             int j;
621             for (j=0;j<N;j++)
622             {
623                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
624                                    preemph,st->preemph_memD[c]);
625                st->preemph_memD[c] = tmp;
626                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
627             }
628          }
629       }
630    }
631    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
632    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
633       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
634    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
635    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
636    {
637       int val = 0;
638       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
639       {
640          ec_enc_uint(&st->enc, val, 2);
641          val = 1-val;
642       }
643    }
644    ec_enc_done(&st->enc);
645    {
646       unsigned char *data;
647       int nbBytes = ec_byte_bytes(&st->buf);
648       if (nbBytes > nbCompressedBytes)
649       {
650          celt_warning_int ("got too many bytes:", nbBytes);
651          RESTORE_STACK;
652          return CELT_INTERNAL_ERROR;
653       }
654       /*printf ("%d\n", *nbBytes);*/
655       data = ec_byte_get_buffer(&st->buf);
656       for (i=0;i<nbBytes;i++)
657          compressed[i] = data[i];
658       for (;i<nbCompressedBytes;i++)
659          compressed[i] = 0;
660    }
661    /* Reset the packing for the next encoding */
662    ec_byte_reset(&st->buf);
663    ec_enc_init(&st->enc,&st->buf);
664
665    RESTORE_STACK;
666    return nbCompressedBytes;
667 }
668
669 #ifdef FIXED_POINT
670 #ifndef DISABLE_FLOAT_API
671 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
672 {
673    int j, ret;
674    const int C = CHANNELS(st->mode);
675    const int N = st->block_size;
676    VARDECL(celt_int16_t, in);
677    SAVE_STACK;
678    ALLOC(in, C*N, celt_int16_t);
679
680    for (j=0;j<C*N;j++)
681      in[j] = FLOAT2INT16(pcm[j]);
682
683    if (optional_synthesis != NULL) {
684      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
685    /*Converts backwards for inplace operation*/
686       for (j=0;j=C*N;j++)
687          optional_synthesis[j]=in[j]*(1/32768.);
688    } else {
689      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
690    }
691    RESTORE_STACK;
692    return ret;
693
694 }
695 #endif /*DISABLE_FLOAT_API*/
696 #else
697 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
698 {
699    int j, ret;
700    VARDECL(celt_sig_t, in);
701    const int C = CHANNELS(st->mode);
702    const int N = st->block_size;
703    SAVE_STACK;
704    ALLOC(in, C*N, celt_sig_t);
705    for (j=0;j<C*N;j++) {
706      in[j] = SCALEOUT(pcm[j]);
707    }
708
709    if (optional_synthesis != NULL) {
710       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
711       for (j=0;j<C*N;j++)
712          optional_synthesis[j] = FLOAT2INT16(in[j]);
713    } else {
714       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
715    }
716    RESTORE_STACK;
717    return ret;
718 }
719 #endif
720
721 int celt_encoder_ctl(CELTEncoder * restrict st, int request, celt_int32_t *value)
722 {
723    switch (request)
724    {
725       case CELT_SET_COMPLEXITY:
726       {
727          if (*value<0 || *value>10)
728             return CELT_BAD_ARG;
729          if (*value<=2)
730             st->pitch_enabled = 0;
731          else
732             st->pitch_enabled = 1;
733       }
734       break;
735       default:
736          return CELT_BAD_REQUEST;
737    }
738    return CELT_OK;
739 }
740
741 /****************************************************************************/
742 /*                                                                          */
743 /*                                DECODER                                   */
744 /*                                                                          */
745 /****************************************************************************/
746
747
748 /** Decoder state 
749  @brief Decoder state
750  */
751 struct CELTDecoder {
752    const CELTMode *mode;
753    int frame_size;
754    int block_size;
755    int overlap;
756
757    ec_byte_buffer buf;
758    ec_enc         enc;
759
760    celt_sig_t * restrict preemph_memD;
761
762    celt_sig_t *out_mem;
763
764    celt_word16_t *oldBandE;
765    
766    int last_pitch_index;
767 };
768
769 CELTDecoder *celt_decoder_create(const CELTMode *mode)
770 {
771    int N, C;
772    CELTDecoder *st;
773
774    if (check_mode(mode) != CELT_OK)
775       return NULL;
776
777    N = mode->mdctSize;
778    C = CHANNELS(mode);
779    st = celt_alloc(sizeof(CELTDecoder));
780    
781    st->mode = mode;
782    st->frame_size = N;
783    st->block_size = N;
784    st->overlap = mode->overlap;
785
786    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
787    
788    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
789
790    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
791
792    st->last_pitch_index = 0;
793    return st;
794 }
795
796 void celt_decoder_destroy(CELTDecoder *st)
797 {
798    if (st == NULL)
799    {
800       celt_warning("NULL passed to celt_encoder_destroy");
801       return;
802    }
803    if (check_mode(st->mode) != CELT_OK)
804       return;
805
806
807    celt_free(st->out_mem);
808    
809    celt_free(st->oldBandE);
810    
811    celt_free(st->preemph_memD);
812
813    celt_free(st);
814 }
815
816 /** Handles lost packets by just copying past data with the same offset as the last
817     pitch period */
818 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
819 {
820    int c, N;
821    int pitch_index;
822    int i, len;
823    VARDECL(celt_sig_t, freq);
824    const int C = CHANNELS(st->mode);
825    int offset;
826    SAVE_STACK;
827    N = st->block_size;
828    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
829    
830    len = N+st->mode->overlap;
831 #if 0
832    pitch_index = st->last_pitch_index;
833    
834    /* Use the pitch MDCT as the "guessed" signal */
835    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
836
837 #else
838    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
839    pitch_index = MAX_PERIOD-len-pitch_index;
840    offset = MAX_PERIOD-pitch_index;
841    while (offset+len >= MAX_PERIOD)
842       offset -= pitch_index;
843    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
844    for (i=0;i<N;i++)
845       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
846 #endif
847    
848    
849    
850    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
851    /* Compute inverse MDCTs */
852    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
853
854    for (c=0;c<C;c++)
855    {
856       int j;
857       for (j=0;j<N;j++)
858       {
859          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
860                                 preemph,st->preemph_memD[c]);
861          st->preemph_memD[c] = tmp;
862          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
863       }
864    }
865    RESTORE_STACK;
866 }
867
868 #ifdef FIXED_POINT
869 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
870 {
871 #else
872 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
873 {
874 #endif
875    int i, c, N, N4;
876    int has_pitch, has_fold;
877    int pitch_index;
878    int bits;
879    ec_dec dec;
880    ec_byte_buffer buf;
881    VARDECL(celt_sig_t, freq);
882    VARDECL(celt_norm_t, X);
883    VARDECL(celt_norm_t, P);
884    VARDECL(celt_ener_t, bandE);
885    VARDECL(celt_pgain_t, gains);
886    VARDECL(int, stereo_mode);
887    VARDECL(int, fine_quant);
888    VARDECL(int, pulses);
889    VARDECL(int, offsets);
890
891    int shortBlocks;
892    int transient_time;
893    int transient_shift;
894    const int C = CHANNELS(st->mode);
895    SAVE_STACK;
896
897    if (check_mode(st->mode) != CELT_OK)
898       return CELT_INVALID_MODE;
899
900    N = st->block_size;
901    N4 = (N-st->overlap)>>1;
902
903    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
904    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
905    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
906    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
907    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
908    
909    if (check_mode(st->mode) != CELT_OK)
910    {
911       RESTORE_STACK;
912       return CELT_INVALID_MODE;
913    }
914    if (data == NULL)
915    {
916       celt_decode_lost(st, pcm);
917       RESTORE_STACK;
918       return 0;
919    }
920    
921    ec_byte_readinit(&buf,data,len);
922    ec_dec_init(&dec,&buf);
923    
924    has_pitch = ec_dec_bits(&dec, 1);
925    if (has_pitch)
926    {
927       has_fold = ec_dec_bits(&dec, 1);
928       shortBlocks = 0;
929    } else if (st->mode->nbShortMdcts > 1){
930       shortBlocks = ec_dec_bits(&dec, 1);
931       has_fold = 1;
932    } else {
933       shortBlocks = 0;
934       has_fold = 1;
935    }
936    if (shortBlocks)
937    {
938       transient_shift = ec_dec_bits(&dec, 2);
939       if (transient_shift)
940          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
941       else
942          transient_time = 0;
943    } else {
944       transient_time = -1;
945       transient_shift = 0;
946    }
947    /* Get the pitch gains */
948    
949    /* Get the pitch index */
950    if (has_pitch)
951    {
952       has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
953       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
954       st->last_pitch_index = pitch_index;
955    } else {
956       /* FIXME: We could be more intelligent here and just not compute the MDCT */
957       pitch_index = 0;
958       for (i=0;i<st->mode->nbPBands;i++)
959          gains[i] = 0;
960    }
961
962    ALLOC(fine_quant, st->mode->nbEBands, int);
963    /* Get band energies */
964    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
965    
966    ALLOC(pulses, st->mode->nbEBands, int);
967    ALLOC(offsets, st->mode->nbEBands, int);
968    ALLOC(stereo_mode, st->mode->nbEBands, int);
969    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
970
971    for (i=0;i<st->mode->nbEBands;i++)
972       offsets[i] = 0;
973
974    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
975    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
976    /*bits = ec_dec_tell(&dec, 0);
977    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
978    
979    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
980
981
982    if (has_pitch) 
983    {
984       VARDECL(celt_ener_t, bandEp);
985       
986       /* Pitch MDCT */
987       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
988       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
989       compute_band_energies(st->mode, freq, bandEp);
990       normalise_bands(st->mode, freq, P, bandEp);
991    } else {
992       for (i=0;i<C*N;i++)
993          P[i] = 0;
994    }
995
996    /* Apply pitch gains */
997    pitch_quant_bands(st->mode, P, gains);
998
999    /* Decode fixed codebook and merge with pitch */
1000    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, len*8, &dec);
1001
1002    if (C==2)
1003    {
1004       renormalise_bands(st->mode, X);
1005    }
1006    /* Synthesis */
1007    denormalise_bands(st->mode, X, freq, bandE);
1008
1009
1010    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
1011    /* Compute inverse MDCTs */
1012    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1013
1014    for (c=0;c<C;c++)
1015    {
1016       int j;
1017       for (j=0;j<N;j++)
1018       {
1019          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1020                                 preemph,st->preemph_memD[c]);
1021          st->preemph_memD[c] = tmp;
1022          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1023       }
1024    }
1025
1026    {
1027       unsigned int val = 0;
1028       while (ec_dec_tell(&dec, 0) < len*8)
1029       {
1030          if (ec_dec_uint(&dec, 2) != val)
1031          {
1032             celt_warning("decode error");
1033             RESTORE_STACK;
1034             return CELT_CORRUPTED_DATA;
1035          }
1036          val = 1-val;
1037       }
1038    }
1039
1040    RESTORE_STACK;
1041    return 0;
1042    /*printf ("\n");*/
1043 }
1044
1045 #ifdef FIXED_POINT
1046 #ifndef DISABLE_FLOAT_API
1047 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1048 {
1049    int j, ret;
1050    const int C = CHANNELS(st->mode);
1051    const int N = st->block_size;
1052    VARDECL(celt_int16_t, out);
1053    SAVE_STACK;
1054    ALLOC(out, C*N, celt_int16_t);
1055
1056    ret=celt_decode(st, data, len, out);
1057
1058    for (j=0;j<C*N;j++)
1059      pcm[j]=out[j]*(1/32768.);
1060    RESTORE_STACK;
1061    return ret;
1062 }
1063 #endif /*DISABLE_FLOAT_API*/
1064 #else
1065 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1066 {
1067    int j, ret;
1068    VARDECL(celt_sig_t, out);
1069    const int C = CHANNELS(st->mode);
1070    const int N = st->block_size;
1071    SAVE_STACK;
1072    ALLOC(out, C*N, celt_sig_t);
1073
1074    ret=celt_decode_float(st, data, len, out);
1075
1076    for (j=0;j<C*N;j++)
1077      pcm[j] = FLOAT2INT16 (out[j]);
1078
1079    RESTORE_STACK;
1080    return ret;
1081 }
1082 #endif