More fixed-point work on the time window -- including conversion of the gain
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52
53 static const celt_word16_t preemph = QCONST16(0.8f,15);
54
55 static const float gainWindow[16] = {
56    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
57    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
58
59 #ifdef FIXED_POINT
60 static const celt_word16_t transientWindow[16] = {
61      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
62    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
63 #else
64 static const float transientWindow[16] = {
65    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
66    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
67 #endif
68
69 /** Encoder state 
70  @brief Encoder state
71  */
72 struct CELTEncoder {
73    const CELTMode *mode;     /**< Mode used by the encoder */
74    int frame_size;
75    int block_size;
76    int overlap;
77    int channels;
78    
79    ec_byte_buffer buf;
80    ec_enc         enc;
81
82    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
83    celt_sig_t    * restrict preemph_memD;
84
85    celt_sig_t *in_mem;
86    celt_sig_t *out_mem;
87
88    celt_word16_t *oldBandE;
89 #ifdef EXP_PSY
90    celt_word16_t *psy_mem;
91    struct PsyDecay psy;
92 #endif
93 };
94
95 CELTEncoder *celt_encoder_create(const CELTMode *mode)
96 {
97    int N, C;
98    CELTEncoder *st;
99
100    if (check_mode(mode) != CELT_OK)
101       return NULL;
102
103    N = mode->mdctSize;
104    C = mode->nbChannels;
105    st = celt_alloc(sizeof(CELTEncoder));
106    
107    st->mode = mode;
108    st->frame_size = N;
109    st->block_size = N;
110    st->overlap = mode->overlap;
111
112    ec_byte_writeinit(&st->buf);
113    ec_enc_init(&st->enc,&st->buf);
114
115    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
116    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
117
118    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
119
120    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));;
121    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
122
123 #ifdef EXP_PSY
124    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
125    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
126 #endif
127
128    return st;
129 }
130
131 void celt_encoder_destroy(CELTEncoder *st)
132 {
133    if (st == NULL)
134    {
135       celt_warning("NULL passed to celt_encoder_destroy");
136       return;
137    }
138    if (check_mode(st->mode) != CELT_OK)
139       return;
140
141    ec_byte_writeclear(&st->buf);
142
143    celt_free(st->in_mem);
144    celt_free(st->out_mem);
145    
146    celt_free(st->oldBandE);
147    
148    celt_free(st->preemph_memE);
149    celt_free(st->preemph_memD);
150    
151 #ifdef EXP_PSY
152    celt_free (st->psy_mem);
153    psydecay_clear(&st->psy);
154 #endif
155    
156    celt_free(st);
157 }
158
159 static inline celt_int16_t SIG2INT16(celt_sig_t x)
160 {
161    x = PSHR32(x, SIG_SHIFT);
162    x = MAX32(x, -32768);
163    x = MIN32(x, 32767);
164 #ifdef FIXED_POINT
165    return EXTRACT16(x);
166 #else
167    return (celt_int16_t)floor(.5+x);
168 #endif
169 }
170
171 static int transient_analysis(celt_word32_t *in, int len, int C, float *r)
172 {
173    int c, i, n;
174    float ratio, maxN, maxD;
175    float x[len];
176    float begin[len], end[len];
177    
178    for (i=0;i<len;i++)
179       x[i] = in[C*i];
180    for (c=1;c<C;c++)
181    {
182       for (i=0;i<len;i++)
183          x[i] = x[i] + in[C*i+c];
184    }
185    begin[0] = x[0]*x[0];
186    for (i=1;i<len;i++)
187       begin[i] = begin[i-1]+x[i]*x[i];
188    end[len-1] = x[len-1]*x[len-1];
189    for (i=len-2;i>=0;i--)
190       end[i] = end[i+1] + x[i]*x[i];
191    maxD = VERY_LARGE32;
192    maxN = 0;
193    n = -1;
194    for (i=8;i<len-8;i++)
195    {
196       float num, den;
197       num = end[i]*i;
198       den = (1000+begin[i])*(len-i)+.01*end[i]*len;
199       if ((num*maxD > den*maxN) && (end[i] > .05*begin[i]))
200       {
201          maxN = num;
202          maxD = den;
203          n = i;
204       }
205    }
206    ratio = (end[n]*n)/((100+begin[n])*(len-n));
207    if (n<32)
208    {
209       n = -1;
210       ratio = 0;
211    }
212    *r = ratio;
213    return n;
214 }
215
216 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
217 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
218 {
219    const int C = CHANNELS(mode);
220    if (C==1 && !shortBlocks)
221    {
222       const mdct_lookup *lookup = MDCT(mode);
223       const int overlap = OVERLAP(mode);
224       mdct_forward(lookup, in, out, mode->window, overlap);
225    } else if (!shortBlocks) {
226       const mdct_lookup *lookup = MDCT(mode);
227       const int overlap = OVERLAP(mode);
228       const int N = FRAMESIZE(mode);
229       int c;
230       VARDECL(celt_word32_t, x);
231       VARDECL(celt_word32_t, tmp);
232       SAVE_STACK;
233       ALLOC(x, N+overlap, celt_word32_t);
234       ALLOC(tmp, N, celt_word32_t);
235       for (c=0;c<C;c++)
236       {
237          int j;
238          for (j=0;j<N+overlap;j++)
239             x[j] = in[C*j+c];
240          mdct_forward(lookup, x, tmp, mode->window, overlap);
241          /* Interleaving the sub-frames */
242          for (j=0;j<N;j++)
243             out[C*j+c] = tmp[j];
244       }
245       RESTORE_STACK;
246    } else {
247       const mdct_lookup *lookup = &mode->shortMdct;
248       const int overlap = mode->shortMdctSize;
249       const int N = mode->shortMdctSize;
250       int b, c;
251       VARDECL(celt_word32_t, x);
252       VARDECL(celt_word32_t, tmp);
253       SAVE_STACK;
254       ALLOC(x, N+overlap, celt_word32_t);
255       ALLOC(tmp, N, celt_word32_t);
256       for (c=0;c<C;c++)
257       {
258          int B = mode->nbShortMdcts;
259          for (b=0;b<B;b++)
260          {
261             int j;
262             for (j=0;j<N+overlap;j++)
263                x[j] = in[C*(b*N+j)+c];
264             mdct_forward(lookup, x, tmp, mode->window, overlap);
265             /* Interleaving the sub-frames */
266             for (j=0;j<N;j++)
267                out[C*(j*B+b)+c] = tmp[j];
268          }
269       }
270       RESTORE_STACK;
271    }
272 }
273
274 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
275 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
276 {
277    int c, N4;
278    const int C = CHANNELS(mode);
279    const int N = FRAMESIZE(mode);
280    const int overlap = OVERLAP(mode);
281    N4 = (N-overlap)>>1;
282    for (c=0;c<C;c++)
283    {
284       int j;
285       if (transient_shift==0 && C==1 && !shortBlocks) {
286          const mdct_lookup *lookup = MDCT(mode);
287          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
288       } else if (!shortBlocks) {
289          const mdct_lookup *lookup = MDCT(mode);
290          VARDECL(celt_word32_t, x);
291          VARDECL(celt_word32_t, tmp);
292          SAVE_STACK;
293          ALLOC(x, 2*N, celt_word32_t);
294          ALLOC(tmp, N, celt_word32_t);
295          /* De-interleaving the sub-frames */
296          for (j=0;j<N;j++)
297             tmp[j] = X[C*j+c];
298          /* Prevents problems from the imdct doing the overlap-add */
299          CELT_MEMSET(x+N4, 0, overlap);
300          mdct_backward(lookup, tmp, x, mode->window, overlap);
301          if (transient_shift > 0)
302          {
303 #ifdef FIXED_POINT
304             for (j=0;j<16;j++)
305                x[N4+transient_time+j-16] *= 1-gainWindow[j]+gainWindow[j]*(1<<transient_shift);
306             for (j=transient_time;j<N+overlap;j++)
307                x[N4+j] = SHL32(x[N4+j], transient_shift);
308 #else
309             for (j=0;j<16;j++)
310                x[N4+transient_time+j-16] *= 1+gainWindow[j]*((1<<transient_shift)-1);
311             for (j=transient_time;j<N+overlap;j++)
312                x[N4+j] *= 1<<transient_shift;
313 #endif
314          }
315          /* The first and last part would need to be set to zero if we actually
316          wanted to use them. */
317          for (j=0;j<overlap;j++)
318             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
319          for (j=0;j<overlap;j++)
320             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
321          for (j=0;j<2*N4;j++)
322             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
323          RESTORE_STACK;
324       } else {
325          int b;
326          const int N2 = mode->shortMdctSize;
327          const int B = mode->nbShortMdcts;
328          const mdct_lookup *lookup = &mode->shortMdct;
329          VARDECL(celt_word32_t, x);
330          VARDECL(celt_word32_t, tmp);
331          SAVE_STACK;
332          ALLOC(x, 2*N, celt_word32_t);
333          ALLOC(tmp, N, celt_word32_t);
334          /* Prevents problems from the imdct doing the overlap-add */
335          CELT_MEMSET(x+N4, 0, overlap);
336          for (b=0;b<B;b++)
337          {
338             /* De-interleaving the sub-frames */
339             for (j=0;j<N2;j++)
340                tmp[j] = X[C*(j*B+b)+c];
341             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
342          }
343          if (transient_shift > 0)
344          {
345 #ifdef FIXED_POINT
346             for (j=0;j<16;j++)
347                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
348             for (j=transient_time;j<N+overlap;j++)
349                x[N4+j] = SHL32(x[N4+j], transient_shift);
350 #else
351             for (j=0;j<16;j++)
352                x[N4+transient_time+j-16] *= 1+gainWindow[j]*((1<<transient_shift)-1);
353             for (j=transient_time;j<N+overlap;j++)
354                x[N4+j] *= 1<<transient_shift;
355 #endif
356          }
357          /* The first and last part would need to be set to zero if we actually
358          wanted to use them. */
359          for (j=0;j<overlap;j++)
360             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
361          for (j=0;j<overlap;j++)
362             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
363          for (j=0;j<2*N4;j++)
364             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
365          RESTORE_STACK;
366       }
367    }
368 }
369
370 int celt_encode(CELTEncoder * restrict st, celt_int16_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
371 {
372    int i, c, N, N4;
373    int has_pitch;
374    int pitch_index;
375    celt_word32_t curr_power, pitch_power;
376    VARDECL(celt_sig_t, in);
377    VARDECL(celt_sig_t, freq);
378    VARDECL(celt_norm_t, X);
379    VARDECL(celt_norm_t, P);
380    VARDECL(celt_ener_t, bandE);
381    VARDECL(celt_pgain_t, gains);
382    VARDECL(int, stereo_mode);
383 #ifdef EXP_PSY
384    VARDECL(celt_word32_t, mask);
385 #endif
386    int shortBlocks=0;
387    int transient_time;
388    int transient_shift;
389    float maxR;
390    const int C = CHANNELS(st->mode);
391    SAVE_STACK;
392
393    if (check_mode(st->mode) != CELT_OK)
394       return CELT_INVALID_MODE;
395
396    N = st->block_size;
397    N4 = (N-st->overlap)>>1;
398    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
399
400    CELT_COPY(in, st->in_mem, C*st->overlap);
401    for (c=0;c<C;c++)
402    {
403       const celt_int16_t * restrict pcmp = pcm+c;
404       celt_sig_t * restrict inp = in+C*st->overlap+c;
405       for (i=0;i<N;i++)
406       {
407          /* Apply pre-emphasis */
408          celt_sig_t tmp = SHL32(EXTEND32(*pcmp), SIG_SHIFT);
409          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),1));
410          st->preemph_memE[c] = *pcmp;
411          inp += C;
412          pcmp += C;
413       }
414    }
415    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
416    
417    transient_time = transient_analysis(in, N+st->overlap, C, &maxR);
418    if (maxR > 30)
419    {
420       float gain_1;
421       ec_enc_bits(&st->enc, 1, 1);
422       if (maxR < 30)
423       {
424          transient_shift = 0;
425       } else if (maxR < 100)
426       {
427          transient_shift = 1;
428       } else if (maxR < 500)
429       {
430          transient_shift = 2;
431       } else
432       {
433          transient_shift = 3;
434       }
435       ec_enc_bits(&st->enc, transient_shift, 2);
436       if (transient_shift)
437          ec_enc_uint(&st->enc, transient_time, N+st->overlap);
438       if (transient_shift)
439       {
440          for (c=0;c<C;c++)
441             for (i=0;i<16;i++)
442                in[C*(transient_time+i-16)+c] /= 1+gainWindow[i]*((1<<transient_shift)-1);
443 #ifdef FIXED_POINT
444          for (c=0;c<C;c++)
445             for (i=transient_time;i<N+st->overlap;i++)
446                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
447 #else
448          gain_1 = 1./(1<<transient_shift);
449          for (c=0;c<C;c++)
450             for (i=transient_time;i<N+st->overlap;i++)
451                in[C*i+c] *= gain_1;
452 #endif
453       }
454       shortBlocks = 1;
455    } else {
456       ec_enc_bits(&st->enc, 0, 1);
457       transient_time = -1;
458       transient_shift = 0;
459       shortBlocks = 0;
460    }
461    /* Pitch analysis: we do it early to save on the peak stack space */
462    if (!shortBlocks)
463       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
464
465    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
466    
467    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
468    /* Compute MDCTs */
469    compute_mdcts(st->mode, shortBlocks, in, freq);
470
471 #ifdef EXP_PSY
472    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
473    for (i=0;i<N;i++)
474       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
475    for (c=1;c<C;c++)
476       for (i=0;i<N;i++)
477          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
478
479    ALLOC(mask, N, celt_sig_t);
480    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
481
482    /* Invert and stretch the mask to length of X 
483       For some reason, I get better results by using the sqrt instead,
484       although there's no valid reason to. Must investigate further */
485    for (i=0;i<C*N;i++)
486       mask[i] = 1/(.1+mask[i]);
487 #endif
488    
489    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
490    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
491    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
492    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
493    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
494
495    /*printf ("%f %f\n", curr_power, pitch_power);*/
496    /*int j;
497    for (j=0;j<B*N;j++)
498       printf ("%f ", X[j]);
499    for (j=0;j<B*N;j++)
500       printf ("%f ", P[j]);
501    printf ("\n");*/
502
503    /* Band normalisation */
504    compute_band_energies(st->mode, freq, bandE);
505    normalise_bands(st->mode, freq, X, bandE);
506    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
507    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
508
509    /* Compute MDCTs of the pitch part */
510    if (!shortBlocks)
511       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
512
513    {
514       /* Normalise the pitch vector as well (discard the energies) */
515       VARDECL(celt_ener_t, bandEp);
516       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
517       compute_band_energies(st->mode, freq, bandEp);
518       normalise_bands(st->mode, freq, P, bandEp);
519       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
520    }
521    curr_power = bandE[0]+bandE[1]+bandE[2];
522    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
523    if (!shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
524    {
525       /* Simulates intensity stereo */
526       /*for (i=30;i<N*B;i++)
527          X[i*C+1] = P[i*C+1] = 0;*/
528
529       /* Pitch prediction */
530       compute_pitch_gain(st->mode, X, P, gains);
531       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
532       if (has_pitch)
533          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
534    } else {
535       /* No pitch, so we just pretend we found a gain of zero */
536       for (i=0;i<st->mode->nbPBands;i++)
537          gains[i] = 0;
538       ec_enc_bits(&st->enc, 0, 7);
539       for (i=0;i<C*N;i++)
540          P[i] = 0;
541    }
542    quant_energy(st->mode, bandE, st->oldBandE, 20*C+nbCompressedBytes*8/5, st->mode->prob, &st->enc);
543
544    ALLOC(stereo_mode, st->mode->nbEBands, int);
545    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
546
547    pitch_quant_bands(st->mode, P, gains);
548
549    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
550
551    /* Residual quantisation */
552    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, nbCompressedBytes*8, shortBlocks, &st->enc);
553    
554    if (C==2)
555    {
556       renormalise_bands(st->mode, X);
557    }
558    /* Synthesis */
559    denormalise_bands(st->mode, X, freq, bandE);
560
561
562    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
563
564    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
565    /* De-emphasis and put everything back at the right place in the synthesis history */
566 #ifndef SHORTCUTS
567    for (c=0;c<C;c++)
568    {
569       int j;
570       celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
571       celt_int16_t * restrict pcmp = pcm+c;
572       for (j=0;j<N;j++)
573       {
574          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
575          st->preemph_memD[c] = tmp;
576          *pcmp = SIG2INT16(tmp);
577          pcmp += C;
578          outp += C;
579       }
580    }
581 #endif
582    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
583       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
584    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
585    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
586    {
587       int val = 0;
588       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
589       {
590          ec_enc_uint(&st->enc, val, 2);
591          val = 1-val;
592       }
593    }
594    ec_enc_done(&st->enc);
595    {
596       unsigned char *data;
597       int nbBytes = ec_byte_bytes(&st->buf);
598       if (nbBytes > nbCompressedBytes)
599       {
600          celt_warning_int ("got too many bytes:", nbBytes);
601          RESTORE_STACK;
602          return CELT_INTERNAL_ERROR;
603       }
604       /*printf ("%d\n", *nbBytes);*/
605       data = ec_byte_get_buffer(&st->buf);
606       for (i=0;i<nbBytes;i++)
607          compressed[i] = data[i];
608       for (;i<nbCompressedBytes;i++)
609          compressed[i] = 0;
610    }
611    /* Reset the packing for the next encoding */
612    ec_byte_reset(&st->buf);
613    ec_enc_init(&st->enc,&st->buf);
614
615    RESTORE_STACK;
616    return nbCompressedBytes;
617 }
618
619
620 /****************************************************************************/
621 /*                                                                          */
622 /*                                DECODER                                   */
623 /*                                                                          */
624 /****************************************************************************/
625
626
627 /** Decoder state 
628  @brief Decoder state
629  */
630 struct CELTDecoder {
631    const CELTMode *mode;
632    int frame_size;
633    int block_size;
634    int overlap;
635
636    ec_byte_buffer buf;
637    ec_enc         enc;
638
639    celt_sig_t * restrict preemph_memD;
640
641    celt_sig_t *out_mem;
642
643    celt_word16_t *oldBandE;
644    
645    int last_pitch_index;
646 };
647
648 CELTDecoder *celt_decoder_create(const CELTMode *mode)
649 {
650    int N, C;
651    CELTDecoder *st;
652
653    if (check_mode(mode) != CELT_OK)
654       return NULL;
655
656    N = mode->mdctSize;
657    C = CHANNELS(mode);
658    st = celt_alloc(sizeof(CELTDecoder));
659    
660    st->mode = mode;
661    st->frame_size = N;
662    st->block_size = N;
663    st->overlap = mode->overlap;
664
665    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
666    
667    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
668
669    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
670
671    st->last_pitch_index = 0;
672    return st;
673 }
674
675 void celt_decoder_destroy(CELTDecoder *st)
676 {
677    if (st == NULL)
678    {
679       celt_warning("NULL passed to celt_encoder_destroy");
680       return;
681    }
682    if (check_mode(st->mode) != CELT_OK)
683       return;
684
685
686    celt_free(st->out_mem);
687    
688    celt_free(st->oldBandE);
689    
690    celt_free(st->preemph_memD);
691
692    celt_free(st);
693 }
694
695 /** Handles lost packets by just copying past data with the same offset as the last
696     pitch period */
697 static void celt_decode_lost(CELTDecoder * restrict st, short * restrict pcm)
698 {
699    int c, N;
700    int pitch_index;
701    int i, len;
702    VARDECL(celt_sig_t, freq);
703    const int C = CHANNELS(st->mode);
704    int offset;
705    SAVE_STACK;
706    N = st->block_size;
707    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
708    
709    len = N+st->mode->overlap;
710 #if 0
711    pitch_index = st->last_pitch_index;
712    
713    /* Use the pitch MDCT as the "guessed" signal */
714    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
715
716 #else
717    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
718    pitch_index = MAX_PERIOD-len-pitch_index;
719    offset = MAX_PERIOD-pitch_index;
720    while (offset+len >= MAX_PERIOD)
721       offset -= pitch_index;
722    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
723    for (i=0;i<N;i++)
724       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
725 #endif
726    
727    
728    
729    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
730    /* Compute inverse MDCTs */
731    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
732
733    for (c=0;c<C;c++)
734    {
735       int j;
736       for (j=0;j<N;j++)
737       {
738          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
739                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
740          st->preemph_memD[c] = tmp;
741          pcm[C*j+c] = SIG2INT16(tmp);
742       }
743    }
744    RESTORE_STACK;
745 }
746
747 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
748 {
749    int c, N, N4;
750    int has_pitch;
751    int pitch_index;
752    ec_dec dec;
753    ec_byte_buffer buf;
754    VARDECL(celt_sig_t, freq);
755    VARDECL(celt_norm_t, X);
756    VARDECL(celt_norm_t, P);
757    VARDECL(celt_ener_t, bandE);
758    VARDECL(celt_pgain_t, gains);
759    VARDECL(int, stereo_mode);
760    int shortBlocks;
761    int transient_time;
762    int transient_shift;
763    const int C = CHANNELS(st->mode);
764    SAVE_STACK;
765
766    if (check_mode(st->mode) != CELT_OK)
767       return CELT_INVALID_MODE;
768
769    N = st->block_size;
770    N4 = (N-st->overlap)>>1;
771
772    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
773    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
774    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
775    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
776    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
777    
778    if (check_mode(st->mode) != CELT_OK)
779    {
780       RESTORE_STACK;
781       return CELT_INVALID_MODE;
782    }
783    if (data == NULL)
784    {
785       celt_decode_lost(st, pcm);
786       RESTORE_STACK;
787       return 0;
788    }
789    
790    ec_byte_readinit(&buf,data,len);
791    ec_dec_init(&dec,&buf);
792    
793    shortBlocks = ec_dec_bits(&dec, 1);
794    if (shortBlocks)
795    {
796       transient_shift = ec_dec_bits(&dec, 2);
797       if (transient_shift)
798          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
799       else
800          transient_time = 0;
801    } else {
802       transient_time = -1;
803       transient_shift = 0;
804    }
805    /* Get the pitch gains */
806    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
807    
808    /* Get the pitch index */
809    if (has_pitch)
810    {
811       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
812       st->last_pitch_index = pitch_index;
813    } else {
814       /* FIXME: We could be more intelligent here and just not compute the MDCT */
815       pitch_index = 0;
816    }
817
818    /* Get band energies */
819    unquant_energy(st->mode, bandE, st->oldBandE, 20*C+len*8/5, st->mode->prob, &dec);
820
821    /* Pitch MDCT */
822    compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
823
824    {
825       VARDECL(celt_ener_t, bandEp);
826       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
827       compute_band_energies(st->mode, freq, bandEp);
828       normalise_bands(st->mode, freq, P, bandEp);
829    }
830
831    ALLOC(stereo_mode, st->mode->nbEBands, int);
832    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
833    /* Apply pitch gains */
834    pitch_quant_bands(st->mode, P, gains);
835
836    /* Decode fixed codebook and merge with pitch */
837    unquant_bands(st->mode, X, P, bandE, stereo_mode, len*8, shortBlocks, &dec);
838
839    if (C==2)
840    {
841       renormalise_bands(st->mode, X);
842    }
843    /* Synthesis */
844    denormalise_bands(st->mode, X, freq, bandE);
845
846
847    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
848    /* Compute inverse MDCTs */
849    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
850
851    for (c=0;c<C;c++)
852    {
853       int j;
854       const celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
855       celt_int16_t * restrict pcmp = pcm+c;
856       for (j=0;j<N;j++)
857       {
858          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
859          st->preemph_memD[c] = tmp;
860          *pcmp = SIG2INT16(tmp);
861          pcmp += C;
862          outp += C;
863       }
864    }
865
866    {
867       unsigned int val = 0;
868       while (ec_dec_tell(&dec, 0) < len*8)
869       {
870          if (ec_dec_uint(&dec, 2) != val)
871          {
872             celt_warning("decode error");
873             RESTORE_STACK;
874             return CELT_CORRUPTED_DATA;
875          }
876          val = 1-val;
877       }
878    }
879
880    RESTORE_STACK;
881    return 0;
882    /*printf ("\n");*/
883 }
884