cleaned up transient_analysis() and replaced the algorithm with a simpler one
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53
54 static const celt_word16_t preemph = QCONST16(0.8f,15);
55
56 #ifdef FIXED_POINT
57 static const celt_word16_t transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
63    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
64 #endif
65
66 /** Encoder state 
67  @brief Encoder state
68  */
69 struct CELTEncoder {
70    const CELTMode *mode;     /**< Mode used by the encoder */
71    int frame_size;
72    int block_size;
73    int overlap;
74    int channels;
75    
76    ec_byte_buffer buf;
77    ec_enc         enc;
78
79    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
80    celt_sig_t    * restrict preemph_memD;
81
82    celt_sig_t *in_mem;
83    celt_sig_t *out_mem;
84
85    celt_word16_t *oldBandE;
86 #ifdef EXP_PSY
87    celt_word16_t *psy_mem;
88    struct PsyDecay psy;
89 #endif
90 };
91
92 CELTEncoder *celt_encoder_create(const CELTMode *mode)
93 {
94    int N, C;
95    CELTEncoder *st;
96
97    if (check_mode(mode) != CELT_OK)
98       return NULL;
99
100    N = mode->mdctSize;
101    C = mode->nbChannels;
102    st = celt_alloc(sizeof(CELTEncoder));
103    
104    st->mode = mode;
105    st->frame_size = N;
106    st->block_size = N;
107    st->overlap = mode->overlap;
108
109    ec_byte_writeinit(&st->buf);
110    ec_enc_init(&st->enc,&st->buf);
111
112    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
113    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
114
115    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
116
117    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));;
118    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
119
120 #ifdef EXP_PSY
121    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
122    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
123 #endif
124
125    return st;
126 }
127
128 void celt_encoder_destroy(CELTEncoder *st)
129 {
130    if (st == NULL)
131    {
132       celt_warning("NULL passed to celt_encoder_destroy");
133       return;
134    }
135    if (check_mode(st->mode) != CELT_OK)
136       return;
137
138    ec_byte_writeclear(&st->buf);
139
140    celt_free(st->in_mem);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148 #ifdef EXP_PSY
149    celt_free (st->psy_mem);
150    psydecay_clear(&st->psy);
151 #endif
152    
153    celt_free(st);
154 }
155
156 static inline celt_int16_t SIG2INT16(celt_sig_t x)
157 {
158    x = PSHR32(x, SIG_SHIFT);
159    x = MAX32(x, -32768);
160    x = MIN32(x, 32767);
161 #ifdef FIXED_POINT
162    return EXTRACT16(x);
163 #else
164    return (celt_int16_t)floor(.5+x);
165 #endif
166 }
167
168 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
169 {
170    int c, i, n;
171    celt_word32_t ratio;
172    /* FIXME: Remove the floats here */
173    VARDECL(celt_word32_t, begin);
174    SAVE_STACK;
175    ALLOC(begin, len, celt_word32_t);
176    for (i=0;i<len;i++)
177       begin[i] = EXTEND32(ABS16(SHR32(in[C*i],SIG_SHIFT)));
178    for (c=1;c<C;c++)
179    {
180       for (i=0;i<len;i++)
181          begin[i] = MAX32(begin[i], EXTEND32(ABS16(SHR32(in[C*i+c],SIG_SHIFT))));
182    }
183    for (i=1;i<len;i++)
184       begin[i] = MAX32(begin[i-1],begin[i]);
185    n = -1;
186    for (i=8;i<len-8;i++)
187    {
188       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
189          n=i;
190    }
191    if (n<32)
192    {
193       n = -1;
194       ratio = 0;
195    } else {
196       ratio = DIV32(begin[len-1],1+begin[n-16]);
197    }
198    /*printf ("%d %f\n", n, ratio*ratio);*/
199    if (ratio < 0)
200       ratio = 0;
201    if (ratio > 1000)
202       ratio = 1000;
203    ratio *= ratio;
204    if (ratio < 50)
205       *transient_shift = 0;
206    else if (ratio < 256)
207       *transient_shift = 1;
208    else if (ratio < 4096)
209       *transient_shift = 2;
210    else
211       *transient_shift = 3;
212    *transient_time = n;
213    
214    RESTORE_STACK;
215    return ratio > 20;
216 }
217
218 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
219 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
220 {
221    const int C = CHANNELS(mode);
222    if (C==1 && !shortBlocks)
223    {
224       const mdct_lookup *lookup = MDCT(mode);
225       const int overlap = OVERLAP(mode);
226       mdct_forward(lookup, in, out, mode->window, overlap);
227    } else if (!shortBlocks) {
228       const mdct_lookup *lookup = MDCT(mode);
229       const int overlap = OVERLAP(mode);
230       const int N = FRAMESIZE(mode);
231       int c;
232       VARDECL(celt_word32_t, x);
233       VARDECL(celt_word32_t, tmp);
234       SAVE_STACK;
235       ALLOC(x, N+overlap, celt_word32_t);
236       ALLOC(tmp, N, celt_word32_t);
237       for (c=0;c<C;c++)
238       {
239          int j;
240          for (j=0;j<N+overlap;j++)
241             x[j] = in[C*j+c];
242          mdct_forward(lookup, x, tmp, mode->window, overlap);
243          /* Interleaving the sub-frames */
244          for (j=0;j<N;j++)
245             out[C*j+c] = tmp[j];
246       }
247       RESTORE_STACK;
248    } else {
249       const mdct_lookup *lookup = &mode->shortMdct;
250       const int overlap = mode->shortMdctSize;
251       const int N = mode->shortMdctSize;
252       int b, c;
253       VARDECL(celt_word32_t, x);
254       VARDECL(celt_word32_t, tmp);
255       SAVE_STACK;
256       ALLOC(x, N+overlap, celt_word32_t);
257       ALLOC(tmp, N, celt_word32_t);
258       for (c=0;c<C;c++)
259       {
260          int B = mode->nbShortMdcts;
261          for (b=0;b<B;b++)
262          {
263             int j;
264             for (j=0;j<N+overlap;j++)
265                x[j] = in[C*(b*N+j)+c];
266             mdct_forward(lookup, x, tmp, mode->window, overlap);
267             /* Interleaving the sub-frames */
268             for (j=0;j<N;j++)
269                out[C*(j*B+b)+c] = tmp[j];
270          }
271       }
272       RESTORE_STACK;
273    }
274 }
275
276 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
277 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
278 {
279    int c, N4;
280    const int C = CHANNELS(mode);
281    const int N = FRAMESIZE(mode);
282    const int overlap = OVERLAP(mode);
283    N4 = (N-overlap)>>1;
284    for (c=0;c<C;c++)
285    {
286       int j;
287       if (transient_shift==0 && C==1 && !shortBlocks) {
288          const mdct_lookup *lookup = MDCT(mode);
289          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
290       } else if (!shortBlocks) {
291          const mdct_lookup *lookup = MDCT(mode);
292          VARDECL(celt_word32_t, x);
293          VARDECL(celt_word32_t, tmp);
294          SAVE_STACK;
295          ALLOC(x, 2*N, celt_word32_t);
296          ALLOC(tmp, N, celt_word32_t);
297          /* De-interleaving the sub-frames */
298          for (j=0;j<N;j++)
299             tmp[j] = X[C*j+c];
300          /* Prevents problems from the imdct doing the overlap-add */
301          CELT_MEMSET(x+N4, 0, overlap);
302          mdct_backward(lookup, tmp, x, mode->window, overlap);
303          celt_assert(transient_shift == 0);
304          /* The first and last part would need to be set to zero if we actually
305             wanted to use them. */
306          for (j=0;j<overlap;j++)
307             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
308          for (j=0;j<overlap;j++)
309             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
310          for (j=0;j<2*N4;j++)
311             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
312          RESTORE_STACK;
313       } else {
314          int b;
315          const int N2 = mode->shortMdctSize;
316          const int B = mode->nbShortMdcts;
317          const mdct_lookup *lookup = &mode->shortMdct;
318          VARDECL(celt_word32_t, x);
319          VARDECL(celt_word32_t, tmp);
320          SAVE_STACK;
321          ALLOC(x, 2*N, celt_word32_t);
322          ALLOC(tmp, N, celt_word32_t);
323          /* Prevents problems from the imdct doing the overlap-add */
324          CELT_MEMSET(x+N4, 0, overlap);
325          for (b=0;b<B;b++)
326          {
327             /* De-interleaving the sub-frames */
328             for (j=0;j<N2;j++)
329                tmp[j] = X[C*(j*B+b)+c];
330             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
331          }
332          if (transient_shift > 0)
333          {
334 #ifdef FIXED_POINT
335             for (j=0;j<16;j++)
336                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
337             for (j=transient_time;j<N+overlap;j++)
338                x[N4+j] = SHL32(x[N4+j], transient_shift);
339 #else
340             for (j=0;j<16;j++)
341                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
342             for (j=transient_time;j<N+overlap;j++)
343                x[N4+j] *= 1<<transient_shift;
344 #endif
345          }
346          /* The first and last part would need to be set to zero if we actually
347          wanted to use them. */
348          for (j=0;j<overlap;j++)
349             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
350          for (j=0;j<overlap;j++)
351             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
352          for (j=0;j<2*N4;j++)
353             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
354          RESTORE_STACK;
355       }
356    }
357 }
358
359 int celt_encode(CELTEncoder * restrict st, celt_int16_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
360 {
361    int i, c, N, N4;
362    int has_pitch;
363    int pitch_index;
364    celt_word32_t curr_power, pitch_power;
365    VARDECL(celt_sig_t, in);
366    VARDECL(celt_sig_t, freq);
367    VARDECL(celt_norm_t, X);
368    VARDECL(celt_norm_t, P);
369    VARDECL(celt_ener_t, bandE);
370    VARDECL(celt_pgain_t, gains);
371    VARDECL(int, stereo_mode);
372 #ifdef EXP_PSY
373    VARDECL(celt_word32_t, mask);
374 #endif
375    int shortBlocks=0;
376    int transient_time;
377    int transient_shift;
378    const int C = CHANNELS(st->mode);
379    SAVE_STACK;
380
381    if (check_mode(st->mode) != CELT_OK)
382       return CELT_INVALID_MODE;
383
384    N = st->block_size;
385    N4 = (N-st->overlap)>>1;
386    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
387
388    CELT_COPY(in, st->in_mem, C*st->overlap);
389    for (c=0;c<C;c++)
390    {
391       const celt_int16_t * restrict pcmp = pcm+c;
392       celt_sig_t * restrict inp = in+C*st->overlap+c;
393       for (i=0;i<N;i++)
394       {
395          /* Apply pre-emphasis */
396          celt_sig_t tmp = SHL32(EXTEND32(*pcmp), SIG_SHIFT);
397          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),1));
398          st->preemph_memE[c] = *pcmp;
399          inp += C;
400          pcmp += C;
401       }
402    }
403    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
404    
405    if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
406    {
407 #ifndef FIXED_POINT
408       float gain_1;
409 #endif
410       ec_enc_bits(&st->enc, 1, 1);
411       ec_enc_bits(&st->enc, transient_shift, 2);
412       if (transient_shift)
413          ec_enc_uint(&st->enc, transient_time, N+st->overlap);
414       if (transient_shift)
415       {
416 #ifdef FIXED_POINT
417          for (c=0;c<C;c++)
418             for (i=0;i<16;i++)
419                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
420          for (c=0;c<C;c++)
421             for (i=transient_time;i<N+st->overlap;i++)
422                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
423 #else
424          for (c=0;c<C;c++)
425             for (i=0;i<16;i++)
426                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
427          gain_1 = 1./(1<<transient_shift);
428          for (c=0;c<C;c++)
429             for (i=transient_time;i<N+st->overlap;i++)
430                in[C*i+c] *= gain_1;
431 #endif
432       }
433       shortBlocks = 1;
434    } else {
435       ec_enc_bits(&st->enc, 0, 1);
436       transient_time = -1;
437       transient_shift = 0;
438       shortBlocks = 0;
439    }
440    /* Pitch analysis: we do it early to save on the peak stack space */
441    if (!shortBlocks)
442       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
443
444    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
445    
446    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
447    /* Compute MDCTs */
448    compute_mdcts(st->mode, shortBlocks, in, freq);
449
450 #ifdef EXP_PSY
451    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
452    for (i=0;i<N;i++)
453       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
454    for (c=1;c<C;c++)
455       for (i=0;i<N;i++)
456          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
457
458    ALLOC(mask, N, celt_sig_t);
459    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
460
461    /* Invert and stretch the mask to length of X 
462       For some reason, I get better results by using the sqrt instead,
463       although there's no valid reason to. Must investigate further */
464    for (i=0;i<C*N;i++)
465       mask[i] = 1/(.1+mask[i]);
466 #endif
467    
468    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
469    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
470    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
471    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
472    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
473
474    /*printf ("%f %f\n", curr_power, pitch_power);*/
475    /*int j;
476    for (j=0;j<B*N;j++)
477       printf ("%f ", X[j]);
478    for (j=0;j<B*N;j++)
479       printf ("%f ", P[j]);
480    printf ("\n");*/
481
482    /* Band normalisation */
483    compute_band_energies(st->mode, freq, bandE);
484    normalise_bands(st->mode, freq, X, bandE);
485    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
486    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
487
488    /* Compute MDCTs of the pitch part */
489    if (!shortBlocks)
490       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
491
492    {
493       /* Normalise the pitch vector as well (discard the energies) */
494       VARDECL(celt_ener_t, bandEp);
495       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
496       compute_band_energies(st->mode, freq, bandEp);
497       normalise_bands(st->mode, freq, P, bandEp);
498       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
499    }
500    curr_power = bandE[0]+bandE[1]+bandE[2];
501    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
502    if (!shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
503    {
504       /* Simulates intensity stereo */
505       /*for (i=30;i<N*B;i++)
506          X[i*C+1] = P[i*C+1] = 0;*/
507
508       /* Pitch prediction */
509       compute_pitch_gain(st->mode, X, P, gains);
510       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
511       if (has_pitch)
512          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
513    } else {
514       /* No pitch, so we just pretend we found a gain of zero */
515       for (i=0;i<st->mode->nbPBands;i++)
516          gains[i] = 0;
517       ec_enc_bits(&st->enc, 0, 7);
518       for (i=0;i<C*N;i++)
519          P[i] = 0;
520    }
521    quant_energy(st->mode, bandE, st->oldBandE, 20*C+nbCompressedBytes*8/5, st->mode->prob, &st->enc);
522
523    ALLOC(stereo_mode, st->mode->nbEBands, int);
524    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
525
526    pitch_quant_bands(st->mode, P, gains);
527
528    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
529
530    /* Residual quantisation */
531    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, nbCompressedBytes*8, shortBlocks, &st->enc);
532    
533    if (C==2)
534    {
535       renormalise_bands(st->mode, X);
536    }
537    /* Synthesis */
538    denormalise_bands(st->mode, X, freq, bandE);
539
540
541    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
542
543    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
544    /* De-emphasis and put everything back at the right place in the synthesis history */
545 #ifndef SHORTCUTS
546    for (c=0;c<C;c++)
547    {
548       int j;
549       celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
550       celt_int16_t * restrict pcmp = pcm+c;
551       for (j=0;j<N;j++)
552       {
553          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
554          st->preemph_memD[c] = tmp;
555          *pcmp = SIG2INT16(tmp);
556          pcmp += C;
557          outp += C;
558       }
559    }
560 #endif
561    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
562       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
563    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
564    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
565    {
566       int val = 0;
567       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
568       {
569          ec_enc_uint(&st->enc, val, 2);
570          val = 1-val;
571       }
572    }
573    ec_enc_done(&st->enc);
574    {
575       unsigned char *data;
576       int nbBytes = ec_byte_bytes(&st->buf);
577       if (nbBytes > nbCompressedBytes)
578       {
579          celt_warning_int ("got too many bytes:", nbBytes);
580          RESTORE_STACK;
581          return CELT_INTERNAL_ERROR;
582       }
583       /*printf ("%d\n", *nbBytes);*/
584       data = ec_byte_get_buffer(&st->buf);
585       for (i=0;i<nbBytes;i++)
586          compressed[i] = data[i];
587       for (;i<nbCompressedBytes;i++)
588          compressed[i] = 0;
589    }
590    /* Reset the packing for the next encoding */
591    ec_byte_reset(&st->buf);
592    ec_enc_init(&st->enc,&st->buf);
593
594    RESTORE_STACK;
595    return nbCompressedBytes;
596 }
597
598
599 /****************************************************************************/
600 /*                                                                          */
601 /*                                DECODER                                   */
602 /*                                                                          */
603 /****************************************************************************/
604
605
606 /** Decoder state 
607  @brief Decoder state
608  */
609 struct CELTDecoder {
610    const CELTMode *mode;
611    int frame_size;
612    int block_size;
613    int overlap;
614
615    ec_byte_buffer buf;
616    ec_enc         enc;
617
618    celt_sig_t * restrict preemph_memD;
619
620    celt_sig_t *out_mem;
621
622    celt_word16_t *oldBandE;
623    
624    int last_pitch_index;
625 };
626
627 CELTDecoder *celt_decoder_create(const CELTMode *mode)
628 {
629    int N, C;
630    CELTDecoder *st;
631
632    if (check_mode(mode) != CELT_OK)
633       return NULL;
634
635    N = mode->mdctSize;
636    C = CHANNELS(mode);
637    st = celt_alloc(sizeof(CELTDecoder));
638    
639    st->mode = mode;
640    st->frame_size = N;
641    st->block_size = N;
642    st->overlap = mode->overlap;
643
644    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
645    
646    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
647
648    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
649
650    st->last_pitch_index = 0;
651    return st;
652 }
653
654 void celt_decoder_destroy(CELTDecoder *st)
655 {
656    if (st == NULL)
657    {
658       celt_warning("NULL passed to celt_encoder_destroy");
659       return;
660    }
661    if (check_mode(st->mode) != CELT_OK)
662       return;
663
664
665    celt_free(st->out_mem);
666    
667    celt_free(st->oldBandE);
668    
669    celt_free(st->preemph_memD);
670
671    celt_free(st);
672 }
673
674 /** Handles lost packets by just copying past data with the same offset as the last
675     pitch period */
676 static void celt_decode_lost(CELTDecoder * restrict st, short * restrict pcm)
677 {
678    int c, N;
679    int pitch_index;
680    int i, len;
681    VARDECL(celt_sig_t, freq);
682    const int C = CHANNELS(st->mode);
683    int offset;
684    SAVE_STACK;
685    N = st->block_size;
686    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
687    
688    len = N+st->mode->overlap;
689 #if 0
690    pitch_index = st->last_pitch_index;
691    
692    /* Use the pitch MDCT as the "guessed" signal */
693    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
694
695 #else
696    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
697    pitch_index = MAX_PERIOD-len-pitch_index;
698    offset = MAX_PERIOD-pitch_index;
699    while (offset+len >= MAX_PERIOD)
700       offset -= pitch_index;
701    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
702    for (i=0;i<N;i++)
703       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
704 #endif
705    
706    
707    
708    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
709    /* Compute inverse MDCTs */
710    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
711
712    for (c=0;c<C;c++)
713    {
714       int j;
715       for (j=0;j<N;j++)
716       {
717          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
718                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
719          st->preemph_memD[c] = tmp;
720          pcm[C*j+c] = SIG2INT16(tmp);
721       }
722    }
723    RESTORE_STACK;
724 }
725
726 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
727 {
728    int c, N, N4;
729    int has_pitch;
730    int pitch_index;
731    ec_dec dec;
732    ec_byte_buffer buf;
733    VARDECL(celt_sig_t, freq);
734    VARDECL(celt_norm_t, X);
735    VARDECL(celt_norm_t, P);
736    VARDECL(celt_ener_t, bandE);
737    VARDECL(celt_pgain_t, gains);
738    VARDECL(int, stereo_mode);
739    int shortBlocks;
740    int transient_time;
741    int transient_shift;
742    const int C = CHANNELS(st->mode);
743    SAVE_STACK;
744
745    if (check_mode(st->mode) != CELT_OK)
746       return CELT_INVALID_MODE;
747
748    N = st->block_size;
749    N4 = (N-st->overlap)>>1;
750
751    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
752    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
753    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
754    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
755    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
756    
757    if (check_mode(st->mode) != CELT_OK)
758    {
759       RESTORE_STACK;
760       return CELT_INVALID_MODE;
761    }
762    if (data == NULL)
763    {
764       celt_decode_lost(st, pcm);
765       RESTORE_STACK;
766       return 0;
767    }
768    
769    ec_byte_readinit(&buf,data,len);
770    ec_dec_init(&dec,&buf);
771    
772    shortBlocks = ec_dec_bits(&dec, 1);
773    if (shortBlocks)
774    {
775       transient_shift = ec_dec_bits(&dec, 2);
776       if (transient_shift)
777          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
778       else
779          transient_time = 0;
780    } else {
781       transient_time = -1;
782       transient_shift = 0;
783    }
784    /* Get the pitch gains */
785    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
786    
787    /* Get the pitch index */
788    if (has_pitch)
789    {
790       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
791       st->last_pitch_index = pitch_index;
792    } else {
793       /* FIXME: We could be more intelligent here and just not compute the MDCT */
794       pitch_index = 0;
795    }
796
797    /* Get band energies */
798    unquant_energy(st->mode, bandE, st->oldBandE, 20*C+len*8/5, st->mode->prob, &dec);
799
800    /* Pitch MDCT */
801    compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
802
803    {
804       VARDECL(celt_ener_t, bandEp);
805       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
806       compute_band_energies(st->mode, freq, bandEp);
807       normalise_bands(st->mode, freq, P, bandEp);
808    }
809
810    ALLOC(stereo_mode, st->mode->nbEBands, int);
811    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
812    /* Apply pitch gains */
813    pitch_quant_bands(st->mode, P, gains);
814
815    /* Decode fixed codebook and merge with pitch */
816    unquant_bands(st->mode, X, P, bandE, stereo_mode, len*8, shortBlocks, &dec);
817
818    if (C==2)
819    {
820       renormalise_bands(st->mode, X);
821    }
822    /* Synthesis */
823    denormalise_bands(st->mode, X, freq, bandE);
824
825
826    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
827    /* Compute inverse MDCTs */
828    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
829
830    for (c=0;c<C;c++)
831    {
832       int j;
833       const celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
834       celt_int16_t * restrict pcmp = pcm+c;
835       for (j=0;j<N;j++)
836       {
837          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
838          st->preemph_memD[c] = tmp;
839          *pcmp = SIG2INT16(tmp);
840          pcmp += C;
841          outp += C;
842       }
843    }
844
845    {
846       unsigned int val = 0;
847       while (ec_dec_tell(&dec, 0) < len*8)
848       {
849          if (ec_dec_uint(&dec, 2) != val)
850          {
851             celt_warning("decode error");
852             RESTORE_STACK;
853             return CELT_CORRUPTED_DATA;
854          }
855          val = 1-val;
856       }
857    }
858
859    RESTORE_STACK;
860    return 0;
861    /*printf ("\n");*/
862 }
863