Cleaned up the pre-echo avoidance code so it works when compiled as fixed-point
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52
53 static const celt_word16_t preemph = QCONST16(0.8f,15);
54
55 static const float gainWindow[16] = {
56    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
57    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
58    
59 /** Encoder state 
60  @brief Encoder state
61  */
62 struct CELTEncoder {
63    const CELTMode *mode;     /**< Mode used by the encoder */
64    int frame_size;
65    int block_size;
66    int overlap;
67    int channels;
68    
69    ec_byte_buffer buf;
70    ec_enc         enc;
71
72    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
73    celt_sig_t    * restrict preemph_memD;
74
75    celt_sig_t *in_mem;
76    celt_sig_t *out_mem;
77
78    celt_word16_t *oldBandE;
79 #ifdef EXP_PSY
80    celt_word16_t *psy_mem;
81    struct PsyDecay psy;
82 #endif
83 };
84
85 CELTEncoder *celt_encoder_create(const CELTMode *mode)
86 {
87    int N, C;
88    CELTEncoder *st;
89
90    if (check_mode(mode) != CELT_OK)
91       return NULL;
92
93    N = mode->mdctSize;
94    C = mode->nbChannels;
95    st = celt_alloc(sizeof(CELTEncoder));
96    
97    st->mode = mode;
98    st->frame_size = N;
99    st->block_size = N;
100    st->overlap = mode->overlap;
101
102    ec_byte_writeinit(&st->buf);
103    ec_enc_init(&st->enc,&st->buf);
104
105    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
106    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
107
108    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
109
110    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));;
111    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
112
113 #ifdef EXP_PSY
114    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
115    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
116 #endif
117
118    return st;
119 }
120
121 void celt_encoder_destroy(CELTEncoder *st)
122 {
123    if (st == NULL)
124    {
125       celt_warning("NULL passed to celt_encoder_destroy");
126       return;
127    }
128    if (check_mode(st->mode) != CELT_OK)
129       return;
130
131    ec_byte_writeclear(&st->buf);
132
133    celt_free(st->in_mem);
134    celt_free(st->out_mem);
135    
136    celt_free(st->oldBandE);
137    
138    celt_free(st->preemph_memE);
139    celt_free(st->preemph_memD);
140    
141 #ifdef EXP_PSY
142    celt_free (st->psy_mem);
143    psydecay_clear(&st->psy);
144 #endif
145    
146    celt_free(st);
147 }
148
149 static inline celt_int16_t SIG2INT16(celt_sig_t x)
150 {
151    x = PSHR32(x, SIG_SHIFT);
152    x = MAX32(x, -32768);
153    x = MIN32(x, 32767);
154 #ifdef FIXED_POINT
155    return EXTRACT16(x);
156 #else
157    return (celt_int16_t)floor(.5+x);
158 #endif
159 }
160
161 static int transient_analysis(celt_word32_t *in, int len, int C, float *r)
162 {
163    int c, i, n;
164    float ratio, maxN, maxD;
165    float x[len];
166    float begin[len], end[len];
167    
168    for (i=0;i<len;i++)
169       x[i] = in[C*i];
170    for (c=1;c<C;c++)
171    {
172       for (i=0;i<len;i++)
173          x[i] = x[i] + in[C*i+c];
174    }
175    begin[0] = x[0]*x[0];
176    for (i=1;i<len;i++)
177       begin[i] = begin[i-1]+x[i]*x[i];
178    end[len-1] = x[len-1]*x[len-1];
179    for (i=len-2;i>=0;i--)
180       end[i] = end[i+1] + x[i]*x[i];
181    maxD = VERY_LARGE32;
182    maxN = 0;
183    n = -1;
184    for (i=8;i<len-8;i++)
185    {
186       float num, den;
187       num = end[i]*i;
188       den = (1000+begin[i])*(len-i)+.01*end[i]*len;
189       if (num*maxD > den*maxN && end[i] > .05*begin[i])
190       {
191          maxN = num;
192          maxD = den;
193          n = i;
194       }
195    }
196    ratio = (end[n]*n)/((100+begin[n])*(len-n));
197    if (n<32)
198    {
199       n = -1;
200       ratio = 0;
201    }
202    *r = ratio;
203    return n;
204 }
205
206 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
207 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
208 {
209    const int C = CHANNELS(mode);
210    if (C==1 && !shortBlocks)
211    {
212       const mdct_lookup *lookup = MDCT(mode);
213       const int overlap = OVERLAP(mode);
214       mdct_forward(lookup, in, out, mode->window, overlap);
215    } else if (!shortBlocks) {
216       const mdct_lookup *lookup = MDCT(mode);
217       const int overlap = OVERLAP(mode);
218       const int N = FRAMESIZE(mode);
219       int c;
220       VARDECL(celt_word32_t, x);
221       VARDECL(celt_word32_t, tmp);
222       SAVE_STACK;
223       ALLOC(x, N+overlap, celt_word32_t);
224       ALLOC(tmp, N, celt_word32_t);
225       for (c=0;c<C;c++)
226       {
227          int j;
228          for (j=0;j<N+overlap;j++)
229             x[j] = in[C*j+c];
230          mdct_forward(lookup, x, tmp, mode->window, overlap);
231          /* Interleaving the sub-frames */
232          for (j=0;j<N;j++)
233             out[C*j+c] = tmp[j];
234       }
235       RESTORE_STACK;
236    } else {
237       const mdct_lookup *lookup = &mode->shortMdct;
238       const int overlap = mode->shortMdctSize;
239       const int N = mode->shortMdctSize;
240       int b, c;
241       VARDECL(celt_word32_t, x);
242       VARDECL(celt_word32_t, tmp);
243       SAVE_STACK;
244       ALLOC(x, N+overlap, celt_word32_t);
245       ALLOC(tmp, N, celt_word32_t);
246       for (c=0;c<C;c++)
247       {
248          int B = mode->nbShortMdcts;
249          for (b=0;b<B;b++)
250          {
251             int j;
252             for (j=0;j<N+overlap;j++)
253                x[j] = in[C*(b*N+j)+c];
254             mdct_forward(lookup, x, tmp, mode->window, overlap);
255             /* Interleaving the sub-frames */
256             for (j=0;j<N;j++)
257                out[C*(j*B+b)+c] = tmp[j];
258          }
259       }
260       RESTORE_STACK;
261    }
262 }
263
264 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
265 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, float transient_gain, celt_sig_t * restrict out_mem)
266 {
267    int c, N4;
268    const int C = CHANNELS(mode);
269    const int N = FRAMESIZE(mode);
270    const int overlap = OVERLAP(mode);
271    N4 = (N-overlap)>>1;
272    for (c=0;c<C;c++)
273    {
274       int j;
275       if (transient_time<0 && C==1 && !shortBlocks) {
276          const mdct_lookup *lookup = MDCT(mode);
277          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
278       } else if (!shortBlocks) {
279          const mdct_lookup *lookup = MDCT(mode);
280          VARDECL(celt_word32_t, x);
281          VARDECL(celt_word32_t, tmp);
282          SAVE_STACK;
283          ALLOC(x, 2*N, celt_word32_t);
284          ALLOC(tmp, N, celt_word32_t);
285          /* De-interleaving the sub-frames */
286          for (j=0;j<N;j++)
287             tmp[j] = X[C*j+c];
288          /* Prevents problems from the imdct doing the overlap-add */
289          CELT_MEMSET(x+N4, 0, overlap);
290          mdct_backward(lookup, tmp, x, mode->window, overlap);
291          if (transient_time >= 0)
292          {
293             for (j=0;j<16;j++)
294                x[N4+transient_time+j-16] *= 1+gainWindow[j]*(transient_gain-1);
295             for (j=transient_time;j<N+overlap;j++)
296                x[N4+j] *= transient_gain;
297          }
298          /* The first and last part would need to be set to zero if we actually
299          wanted to use them. */
300          for (j=0;j<overlap;j++)
301             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
302          for (j=0;j<overlap;j++)
303             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
304          for (j=0;j<2*N4;j++)
305             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
306          RESTORE_STACK;
307       } else {
308          int b;
309          const int N2 = mode->shortMdctSize;
310          const int B = mode->nbShortMdcts;
311          const mdct_lookup *lookup = &mode->shortMdct;
312          VARDECL(celt_word32_t, x);
313          VARDECL(celt_word32_t, tmp);
314          SAVE_STACK;
315          ALLOC(x, 2*N, celt_word32_t);
316          ALLOC(tmp, N, celt_word32_t);
317          /* Prevents problems from the imdct doing the overlap-add */
318          CELT_MEMSET(x+N4, 0, overlap);
319          for (b=0;b<B;b++)
320          {
321             /* De-interleaving the sub-frames */
322             for (j=0;j<N2;j++)
323                tmp[j] = X[C*(j*B+b)+c];
324             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
325          }
326          if (transient_time >= 0)
327          {
328             for (j=0;j<16;j++)
329                x[N4+transient_time+j-16] *= 1+gainWindow[j]*(transient_gain-1);
330             for (j=transient_time;j<N+overlap;j++)
331                x[N4+j] *= transient_gain;
332          }
333          /* The first and last part would need to be set to zero if we actually
334          wanted to use them. */
335          for (j=0;j<overlap;j++)
336             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
337          for (j=0;j<overlap;j++)
338             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
339          for (j=0;j<2*N4;j++)
340             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
341          RESTORE_STACK;
342       }
343    }
344 }
345
346 int celt_encode(CELTEncoder * restrict st, celt_int16_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
347 {
348    int i, c, N, N4;
349    int has_pitch;
350    int pitch_index;
351    celt_word32_t curr_power, pitch_power;
352    VARDECL(celt_sig_t, in);
353    VARDECL(celt_sig_t, freq);
354    VARDECL(celt_norm_t, X);
355    VARDECL(celt_norm_t, P);
356    VARDECL(celt_ener_t, bandE);
357    VARDECL(celt_pgain_t, gains);
358    VARDECL(int, stereo_mode);
359 #ifdef EXP_PSY
360    VARDECL(celt_word32_t, mask);
361 #endif
362    int shortBlocks=0;
363    int transient_time;
364    float transient_gain;
365    float maxR;
366    const int C = CHANNELS(st->mode);
367    SAVE_STACK;
368
369    if (check_mode(st->mode) != CELT_OK)
370       return CELT_INVALID_MODE;
371
372    N = st->block_size;
373    N4 = (N-st->overlap)>>1;
374    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
375
376    CELT_COPY(in, st->in_mem, C*st->overlap);
377    for (c=0;c<C;c++)
378    {
379       const celt_int16_t * restrict pcmp = pcm+c;
380       celt_sig_t * restrict inp = in+C*st->overlap+c;
381       for (i=0;i<N;i++)
382       {
383          /* Apply pre-emphasis */
384          celt_sig_t tmp = SHL32(EXTEND32(*pcmp), SIG_SHIFT);
385          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),1));
386          st->preemph_memE[c] = *pcmp;
387          inp += C;
388          pcmp += C;
389       }
390    }
391    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
392    
393    transient_time = transient_analysis(in, N+st->overlap, C, &maxR);
394    if (maxR > 30)
395    {
396       float gain_1;
397       ec_enc_bits(&st->enc, 1, 1);
398       if (maxR < 30)
399       {
400          transient_time = 16;
401          transient_gain = 1;
402          ec_enc_bits(&st->enc, 0, 2);
403       } else if (maxR < 100)
404       {
405          transient_gain = 2;
406          ec_enc_bits(&st->enc, 1, 2);
407       } else if (maxR < 500)
408       {
409          transient_gain = 4;
410          ec_enc_bits(&st->enc, 2, 2);
411       } else
412       {
413          transient_gain = 8;
414          ec_enc_bits(&st->enc, 3, 2);
415       }
416       ec_enc_uint(&st->enc, transient_time, N+st->overlap);
417       for (c=0;c<C;c++)
418          for (i=0;i<16;i++)
419             in[C*(transient_time+i-16)+c] /= 1+gainWindow[i]*(transient_gain-1);
420       gain_1 = 1./transient_gain;
421       for (c=0;c<C;c++)
422          for (i=transient_time;i<N+st->overlap;i++)
423             in[C*i+c] *= gain_1;
424       shortBlocks = 1;
425    } else {
426       ec_enc_bits(&st->enc, 0, 1);
427       transient_time = -1;
428       transient_gain = 1;
429       shortBlocks = 0;
430    }
431    /* Pitch analysis: we do it early to save on the peak stack space */
432    if (!shortBlocks)
433       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
434
435    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
436    
437    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
438    /* Compute MDCTs */
439    compute_mdcts(st->mode, shortBlocks, in, freq);
440
441 #ifdef EXP_PSY
442    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
443    for (i=0;i<N;i++)
444       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
445    for (c=1;c<C;c++)
446       for (i=0;i<N;i++)
447          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
448
449    ALLOC(mask, N, celt_sig_t);
450    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
451
452    /* Invert and stretch the mask to length of X 
453       For some reason, I get better results by using the sqrt instead,
454       although there's no valid reason to. Must investigate further */
455    for (i=0;i<C*N;i++)
456       mask[i] = 1/(.1+mask[i]);
457 #endif
458    
459    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
460    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
461    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
462    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
463    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
464
465    /*printf ("%f %f\n", curr_power, pitch_power);*/
466    /*int j;
467    for (j=0;j<B*N;j++)
468       printf ("%f ", X[j]);
469    for (j=0;j<B*N;j++)
470       printf ("%f ", P[j]);
471    printf ("\n");*/
472
473    /* Band normalisation */
474    compute_band_energies(st->mode, freq, bandE);
475    normalise_bands(st->mode, freq, X, bandE);
476    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
477    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
478
479    /* Compute MDCTs of the pitch part */
480    if (!shortBlocks)
481       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
482
483    {
484       /* Normalise the pitch vector as well (discard the energies) */
485       VARDECL(celt_ener_t, bandEp);
486       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
487       compute_band_energies(st->mode, freq, bandEp);
488       normalise_bands(st->mode, freq, P, bandEp);
489       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
490    }
491    curr_power = bandE[0]+bandE[1]+bandE[2];
492    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
493    if (!shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
494    {
495       /* Simulates intensity stereo */
496       /*for (i=30;i<N*B;i++)
497          X[i*C+1] = P[i*C+1] = 0;*/
498
499       /* Pitch prediction */
500       compute_pitch_gain(st->mode, X, P, gains);
501       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
502       if (has_pitch)
503          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
504    } else {
505       /* No pitch, so we just pretend we found a gain of zero */
506       for (i=0;i<st->mode->nbPBands;i++)
507          gains[i] = 0;
508       ec_enc_bits(&st->enc, 0, 7);
509       for (i=0;i<C*N;i++)
510          P[i] = 0;
511    }
512    quant_energy(st->mode, bandE, st->oldBandE, 20*C+nbCompressedBytes*8/5, st->mode->prob, &st->enc);
513
514    ALLOC(stereo_mode, st->mode->nbEBands, int);
515    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
516
517    pitch_quant_bands(st->mode, P, gains);
518
519    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
520
521    /* Residual quantisation */
522    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, nbCompressedBytes*8, shortBlocks, &st->enc);
523    
524    if (C==2)
525    {
526       renormalise_bands(st->mode, X);
527    }
528    /* Synthesis */
529    denormalise_bands(st->mode, X, freq, bandE);
530
531
532    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
533
534    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_gain, st->out_mem);
535    /* De-emphasis and put everything back at the right place in the synthesis history */
536 #ifndef SHORTCUTS
537    for (c=0;c<C;c++)
538    {
539       int j;
540       celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
541       celt_int16_t * restrict pcmp = pcm+c;
542       for (j=0;j<N;j++)
543       {
544          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
545          st->preemph_memD[c] = tmp;
546          *pcmp = SIG2INT16(tmp);
547          pcmp += C;
548          outp += C;
549       }
550    }
551 #endif
552    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
553       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
554    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
555    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
556    {
557       int val = 0;
558       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
559       {
560          ec_enc_uint(&st->enc, val, 2);
561          val = 1-val;
562       }
563    }
564    ec_enc_done(&st->enc);
565    {
566       unsigned char *data;
567       int nbBytes = ec_byte_bytes(&st->buf);
568       if (nbBytes > nbCompressedBytes)
569       {
570          celt_warning_int ("got too many bytes:", nbBytes);
571          RESTORE_STACK;
572          return CELT_INTERNAL_ERROR;
573       }
574       /*printf ("%d\n", *nbBytes);*/
575       data = ec_byte_get_buffer(&st->buf);
576       for (i=0;i<nbBytes;i++)
577          compressed[i] = data[i];
578       for (;i<nbCompressedBytes;i++)
579          compressed[i] = 0;
580    }
581    /* Reset the packing for the next encoding */
582    ec_byte_reset(&st->buf);
583    ec_enc_init(&st->enc,&st->buf);
584
585    RESTORE_STACK;
586    return nbCompressedBytes;
587 }
588
589
590 /****************************************************************************/
591 /*                                                                          */
592 /*                                DECODER                                   */
593 /*                                                                          */
594 /****************************************************************************/
595
596
597 /** Decoder state 
598  @brief Decoder state
599  */
600 struct CELTDecoder {
601    const CELTMode *mode;
602    int frame_size;
603    int block_size;
604    int overlap;
605
606    ec_byte_buffer buf;
607    ec_enc         enc;
608
609    celt_sig_t * restrict preemph_memD;
610
611    celt_sig_t *out_mem;
612
613    celt_word16_t *oldBandE;
614    
615    int last_pitch_index;
616 };
617
618 CELTDecoder *celt_decoder_create(const CELTMode *mode)
619 {
620    int N, C;
621    CELTDecoder *st;
622
623    if (check_mode(mode) != CELT_OK)
624       return NULL;
625
626    N = mode->mdctSize;
627    C = CHANNELS(mode);
628    st = celt_alloc(sizeof(CELTDecoder));
629    
630    st->mode = mode;
631    st->frame_size = N;
632    st->block_size = N;
633    st->overlap = mode->overlap;
634
635    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
636    
637    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
638
639    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
640
641    st->last_pitch_index = 0;
642    return st;
643 }
644
645 void celt_decoder_destroy(CELTDecoder *st)
646 {
647    if (st == NULL)
648    {
649       celt_warning("NULL passed to celt_encoder_destroy");
650       return;
651    }
652    if (check_mode(st->mode) != CELT_OK)
653       return;
654
655
656    celt_free(st->out_mem);
657    
658    celt_free(st->oldBandE);
659    
660    celt_free(st->preemph_memD);
661
662    celt_free(st);
663 }
664
665 /** Handles lost packets by just copying past data with the same offset as the last
666     pitch period */
667 static void celt_decode_lost(CELTDecoder * restrict st, short * restrict pcm)
668 {
669    int c, N;
670    int pitch_index;
671    int i, len;
672    VARDECL(celt_sig_t, freq);
673    const int C = CHANNELS(st->mode);
674    int offset;
675    SAVE_STACK;
676    N = st->block_size;
677    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
678    
679    len = N+st->mode->overlap;
680 #if 0
681    pitch_index = st->last_pitch_index;
682    
683    /* Use the pitch MDCT as the "guessed" signal */
684    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
685
686 #else
687    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
688    pitch_index = MAX_PERIOD-len-pitch_index;
689    offset = MAX_PERIOD-pitch_index;
690    while (offset+len >= MAX_PERIOD)
691       offset -= pitch_index;
692    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
693    for (i=0;i<N;i++)
694       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
695 #endif
696    
697    
698    
699    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
700    /* Compute inverse MDCTs */
701    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
702
703    for (c=0;c<C;c++)
704    {
705       int j;
706       for (j=0;j<N;j++)
707       {
708          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
709                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
710          st->preemph_memD[c] = tmp;
711          pcm[C*j+c] = SIG2INT16(tmp);
712       }
713    }
714    RESTORE_STACK;
715 }
716
717 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
718 {
719    int c, N, N4;
720    int has_pitch;
721    int pitch_index;
722    ec_dec dec;
723    ec_byte_buffer buf;
724    VARDECL(celt_sig_t, freq);
725    VARDECL(celt_norm_t, X);
726    VARDECL(celt_norm_t, P);
727    VARDECL(celt_ener_t, bandE);
728    VARDECL(celt_pgain_t, gains);
729    VARDECL(int, stereo_mode);
730    int shortBlocks;
731    int transient_time;
732    float transient_gain;
733    const int C = CHANNELS(st->mode);
734    SAVE_STACK;
735
736    if (check_mode(st->mode) != CELT_OK)
737       return CELT_INVALID_MODE;
738
739    N = st->block_size;
740    N4 = (N-st->overlap)>>1;
741
742    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
743    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
744    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
745    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
746    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
747    
748    if (check_mode(st->mode) != CELT_OK)
749    {
750       RESTORE_STACK;
751       return CELT_INVALID_MODE;
752    }
753    if (data == NULL)
754    {
755       celt_decode_lost(st, pcm);
756       RESTORE_STACK;
757       return 0;
758    }
759    
760    ec_byte_readinit(&buf,data,len);
761    ec_dec_init(&dec,&buf);
762    
763    shortBlocks = ec_dec_bits(&dec, 1);
764    if (shortBlocks)
765    {
766       int gainid = ec_dec_bits(&dec, 2);
767       switch(gainid) {
768          case 0:
769             transient_gain = 1;
770             break;
771          case 1:
772             transient_gain = 2;
773             break;
774          case 2:
775             transient_gain = 4;
776             break;
777          case 3:
778          default:
779             transient_gain = 8;
780             break;
781       }
782       transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
783    } else {
784       transient_time = -1;
785       transient_gain = 1;
786    }
787    /* Get the pitch gains */
788    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
789    
790    /* Get the pitch index */
791    if (has_pitch)
792    {
793       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
794       st->last_pitch_index = pitch_index;
795    } else {
796       /* FIXME: We could be more intelligent here and just not compute the MDCT */
797       pitch_index = 0;
798    }
799
800    /* Get band energies */
801    unquant_energy(st->mode, bandE, st->oldBandE, 20*C+len*8/5, st->mode->prob, &dec);
802
803    /* Pitch MDCT */
804    compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
805
806    {
807       VARDECL(celt_ener_t, bandEp);
808       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
809       compute_band_energies(st->mode, freq, bandEp);
810       normalise_bands(st->mode, freq, P, bandEp);
811    }
812
813    ALLOC(stereo_mode, st->mode->nbEBands, int);
814    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
815    /* Apply pitch gains */
816    pitch_quant_bands(st->mode, P, gains);
817
818    /* Decode fixed codebook and merge with pitch */
819    unquant_bands(st->mode, X, P, bandE, stereo_mode, len*8, shortBlocks, &dec);
820
821    if (C==2)
822    {
823       renormalise_bands(st->mode, X);
824    }
825    /* Synthesis */
826    denormalise_bands(st->mode, X, freq, bandE);
827
828
829    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
830    /* Compute inverse MDCTs */
831    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_gain, st->out_mem);
832
833    for (c=0;c<C;c++)
834    {
835       int j;
836       const celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
837       celt_int16_t * restrict pcmp = pcm+c;
838       for (j=0;j<N;j++)
839       {
840          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
841          st->preemph_memD[c] = tmp;
842          *pcmp = SIG2INT16(tmp);
843          pcmp += C;
844          outp += C;
845       }
846    }
847
848    {
849       unsigned int val = 0;
850       while (ec_dec_tell(&dec, 0) < len*8)
851       {
852          if (ec_dec_uint(&dec, 2) != val)
853          {
854             celt_warning("decode error");
855             RESTORE_STACK;
856             return CELT_CORRUPTED_DATA;
857          }
858          val = 1-val;
859       }
860    }
861
862    RESTORE_STACK;
863    return 0;
864    /*printf ("\n");*/
865 }
866