More fixed-point conversion of the time window.
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53
54 static const celt_word16_t preemph = QCONST16(0.8f,15);
55
56 #ifdef FIXED_POINT
57 static const celt_word16_t transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
63    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
64 #endif
65
66 /** Encoder state 
67  @brief Encoder state
68  */
69 struct CELTEncoder {
70    const CELTMode *mode;     /**< Mode used by the encoder */
71    int frame_size;
72    int block_size;
73    int overlap;
74    int channels;
75    
76    ec_byte_buffer buf;
77    ec_enc         enc;
78
79    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
80    celt_sig_t    * restrict preemph_memD;
81
82    celt_sig_t *in_mem;
83    celt_sig_t *out_mem;
84
85    celt_word16_t *oldBandE;
86 #ifdef EXP_PSY
87    celt_word16_t *psy_mem;
88    struct PsyDecay psy;
89 #endif
90 };
91
92 CELTEncoder *celt_encoder_create(const CELTMode *mode)
93 {
94    int N, C;
95    CELTEncoder *st;
96
97    if (check_mode(mode) != CELT_OK)
98       return NULL;
99
100    N = mode->mdctSize;
101    C = mode->nbChannels;
102    st = celt_alloc(sizeof(CELTEncoder));
103    
104    st->mode = mode;
105    st->frame_size = N;
106    st->block_size = N;
107    st->overlap = mode->overlap;
108
109    ec_byte_writeinit(&st->buf);
110    ec_enc_init(&st->enc,&st->buf);
111
112    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
113    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
114
115    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
116
117    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));;
118    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
119
120 #ifdef EXP_PSY
121    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
122    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
123 #endif
124
125    return st;
126 }
127
128 void celt_encoder_destroy(CELTEncoder *st)
129 {
130    if (st == NULL)
131    {
132       celt_warning("NULL passed to celt_encoder_destroy");
133       return;
134    }
135    if (check_mode(st->mode) != CELT_OK)
136       return;
137
138    ec_byte_writeclear(&st->buf);
139
140    celt_free(st->in_mem);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148 #ifdef EXP_PSY
149    celt_free (st->psy_mem);
150    psydecay_clear(&st->psy);
151 #endif
152    
153    celt_free(st);
154 }
155
156 static inline celt_int16_t SIG2INT16(celt_sig_t x)
157 {
158    x = PSHR32(x, SIG_SHIFT);
159    x = MAX32(x, -32768);
160    x = MIN32(x, 32767);
161 #ifdef FIXED_POINT
162    return EXTRACT16(x);
163 #else
164    return (celt_int16_t)floor(.5+x);
165 #endif
166 }
167
168 static int transient_analysis(celt_word32_t *in, int len, int C, float *r)
169 {
170    int c, i, n;
171    float ratio, maxN, maxD;
172    float x[len];
173    float begin[len], end[len];
174    
175    for (i=0;i<len;i++)
176       x[i] = in[C*i];
177    for (c=1;c<C;c++)
178    {
179       for (i=0;i<len;i++)
180          x[i] = x[i] + in[C*i+c];
181    }
182    begin[0] = x[0]*x[0];
183    for (i=1;i<len;i++)
184       begin[i] = begin[i-1]+x[i]*x[i];
185    end[len-1] = x[len-1]*x[len-1];
186    for (i=len-2;i>=0;i--)
187       end[i] = end[i+1] + x[i]*x[i];
188    maxD = VERY_LARGE32;
189    maxN = 0;
190    n = -1;
191    for (i=8;i<len-8;i++)
192    {
193       float num, den;
194       num = end[i]*i;
195       den = (1000+begin[i])*(len-i)+.01*end[i]*len;
196       if ((num*maxD > den*maxN) && (end[i] > .05*begin[i]))
197       {
198          maxN = num;
199          maxD = den;
200          n = i;
201       }
202    }
203    ratio = (end[n]*n)/((100+begin[n])*(len-n));
204    if (n<32)
205    {
206       n = -1;
207       ratio = 0;
208    }
209    *r = ratio;
210    return n;
211 }
212
213 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
214 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
215 {
216    const int C = CHANNELS(mode);
217    if (C==1 && !shortBlocks)
218    {
219       const mdct_lookup *lookup = MDCT(mode);
220       const int overlap = OVERLAP(mode);
221       mdct_forward(lookup, in, out, mode->window, overlap);
222    } else if (!shortBlocks) {
223       const mdct_lookup *lookup = MDCT(mode);
224       const int overlap = OVERLAP(mode);
225       const int N = FRAMESIZE(mode);
226       int c;
227       VARDECL(celt_word32_t, x);
228       VARDECL(celt_word32_t, tmp);
229       SAVE_STACK;
230       ALLOC(x, N+overlap, celt_word32_t);
231       ALLOC(tmp, N, celt_word32_t);
232       for (c=0;c<C;c++)
233       {
234          int j;
235          for (j=0;j<N+overlap;j++)
236             x[j] = in[C*j+c];
237          mdct_forward(lookup, x, tmp, mode->window, overlap);
238          /* Interleaving the sub-frames */
239          for (j=0;j<N;j++)
240             out[C*j+c] = tmp[j];
241       }
242       RESTORE_STACK;
243    } else {
244       const mdct_lookup *lookup = &mode->shortMdct;
245       const int overlap = mode->shortMdctSize;
246       const int N = mode->shortMdctSize;
247       int b, c;
248       VARDECL(celt_word32_t, x);
249       VARDECL(celt_word32_t, tmp);
250       SAVE_STACK;
251       ALLOC(x, N+overlap, celt_word32_t);
252       ALLOC(tmp, N, celt_word32_t);
253       for (c=0;c<C;c++)
254       {
255          int B = mode->nbShortMdcts;
256          for (b=0;b<B;b++)
257          {
258             int j;
259             for (j=0;j<N+overlap;j++)
260                x[j] = in[C*(b*N+j)+c];
261             mdct_forward(lookup, x, tmp, mode->window, overlap);
262             /* Interleaving the sub-frames */
263             for (j=0;j<N;j++)
264                out[C*(j*B+b)+c] = tmp[j];
265          }
266       }
267       RESTORE_STACK;
268    }
269 }
270
271 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
272 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
273 {
274    int c, N4;
275    const int C = CHANNELS(mode);
276    const int N = FRAMESIZE(mode);
277    const int overlap = OVERLAP(mode);
278    N4 = (N-overlap)>>1;
279    for (c=0;c<C;c++)
280    {
281       int j;
282       if (transient_shift==0 && C==1 && !shortBlocks) {
283          const mdct_lookup *lookup = MDCT(mode);
284          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
285       } else if (!shortBlocks) {
286          const mdct_lookup *lookup = MDCT(mode);
287          VARDECL(celt_word32_t, x);
288          VARDECL(celt_word32_t, tmp);
289          SAVE_STACK;
290          ALLOC(x, 2*N, celt_word32_t);
291          ALLOC(tmp, N, celt_word32_t);
292          /* De-interleaving the sub-frames */
293          for (j=0;j<N;j++)
294             tmp[j] = X[C*j+c];
295          /* Prevents problems from the imdct doing the overlap-add */
296          CELT_MEMSET(x+N4, 0, overlap);
297          mdct_backward(lookup, tmp, x, mode->window, overlap);
298          celt_assert(transient_shift == 0)
299          /* The first and last part would need to be set to zero if we actually
300             wanted to use them. */
301          for (j=0;j<overlap;j++)
302             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
303          for (j=0;j<overlap;j++)
304             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
305          for (j=0;j<2*N4;j++)
306             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
307          RESTORE_STACK;
308       } else {
309          int b;
310          const int N2 = mode->shortMdctSize;
311          const int B = mode->nbShortMdcts;
312          const mdct_lookup *lookup = &mode->shortMdct;
313          VARDECL(celt_word32_t, x);
314          VARDECL(celt_word32_t, tmp);
315          SAVE_STACK;
316          ALLOC(x, 2*N, celt_word32_t);
317          ALLOC(tmp, N, celt_word32_t);
318          /* Prevents problems from the imdct doing the overlap-add */
319          CELT_MEMSET(x+N4, 0, overlap);
320          for (b=0;b<B;b++)
321          {
322             /* De-interleaving the sub-frames */
323             for (j=0;j<N2;j++)
324                tmp[j] = X[C*(j*B+b)+c];
325             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
326          }
327          if (transient_shift > 0)
328          {
329 #ifdef FIXED_POINT
330             for (j=0;j<16;j++)
331                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
332             for (j=transient_time;j<N+overlap;j++)
333                x[N4+j] = SHL32(x[N4+j], transient_shift);
334 #else
335             for (j=0;j<16;j++)
336                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
337             for (j=transient_time;j<N+overlap;j++)
338                x[N4+j] *= 1<<transient_shift;
339 #endif
340          }
341          /* The first and last part would need to be set to zero if we actually
342          wanted to use them. */
343          for (j=0;j<overlap;j++)
344             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
345          for (j=0;j<overlap;j++)
346             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
347          for (j=0;j<2*N4;j++)
348             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
349          RESTORE_STACK;
350       }
351    }
352 }
353
354 int celt_encode(CELTEncoder * restrict st, celt_int16_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
355 {
356    int i, c, N, N4;
357    int has_pitch;
358    int pitch_index;
359    celt_word32_t curr_power, pitch_power;
360    VARDECL(celt_sig_t, in);
361    VARDECL(celt_sig_t, freq);
362    VARDECL(celt_norm_t, X);
363    VARDECL(celt_norm_t, P);
364    VARDECL(celt_ener_t, bandE);
365    VARDECL(celt_pgain_t, gains);
366    VARDECL(int, stereo_mode);
367 #ifdef EXP_PSY
368    VARDECL(celt_word32_t, mask);
369 #endif
370    int shortBlocks=0;
371    int transient_time;
372    int transient_shift;
373    float maxR;
374    const int C = CHANNELS(st->mode);
375    SAVE_STACK;
376
377    if (check_mode(st->mode) != CELT_OK)
378       return CELT_INVALID_MODE;
379
380    N = st->block_size;
381    N4 = (N-st->overlap)>>1;
382    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
383
384    CELT_COPY(in, st->in_mem, C*st->overlap);
385    for (c=0;c<C;c++)
386    {
387       const celt_int16_t * restrict pcmp = pcm+c;
388       celt_sig_t * restrict inp = in+C*st->overlap+c;
389       for (i=0;i<N;i++)
390       {
391          /* Apply pre-emphasis */
392          celt_sig_t tmp = SHL32(EXTEND32(*pcmp), SIG_SHIFT);
393          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),1));
394          st->preemph_memE[c] = *pcmp;
395          inp += C;
396          pcmp += C;
397       }
398    }
399    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
400    
401    transient_time = transient_analysis(in, N+st->overlap, C, &maxR);
402    if (maxR > 30)
403    {
404       float gain_1;
405       ec_enc_bits(&st->enc, 1, 1);
406       if (maxR < 30)
407       {
408          transient_shift = 0;
409       } else if (maxR < 100)
410       {
411          transient_shift = 1;
412       } else if (maxR < 500)
413       {
414          transient_shift = 2;
415       } else
416       {
417          transient_shift = 3;
418       }
419       ec_enc_bits(&st->enc, transient_shift, 2);
420       if (transient_shift)
421          ec_enc_uint(&st->enc, transient_time, N+st->overlap);
422       if (transient_shift)
423       {
424 #ifdef FIXED_POINT
425          for (c=0;c<C;c++)
426             for (i=0;i<16;i++)
427                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
428          for (c=0;c<C;c++)
429             for (i=transient_time;i<N+st->overlap;i++)
430                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
431 #else
432          for (c=0;c<C;c++)
433             for (i=0;i<16;i++)
434                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
435          gain_1 = 1./(1<<transient_shift);
436          for (c=0;c<C;c++)
437             for (i=transient_time;i<N+st->overlap;i++)
438                in[C*i+c] *= gain_1;
439 #endif
440       }
441       shortBlocks = 1;
442    } else {
443       ec_enc_bits(&st->enc, 0, 1);
444       transient_time = -1;
445       transient_shift = 0;
446       shortBlocks = 0;
447    }
448    /* Pitch analysis: we do it early to save on the peak stack space */
449    if (!shortBlocks)
450       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
451
452    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
453    
454    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
455    /* Compute MDCTs */
456    compute_mdcts(st->mode, shortBlocks, in, freq);
457
458 #ifdef EXP_PSY
459    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
460    for (i=0;i<N;i++)
461       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
462    for (c=1;c<C;c++)
463       for (i=0;i<N;i++)
464          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
465
466    ALLOC(mask, N, celt_sig_t);
467    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
468
469    /* Invert and stretch the mask to length of X 
470       For some reason, I get better results by using the sqrt instead,
471       although there's no valid reason to. Must investigate further */
472    for (i=0;i<C*N;i++)
473       mask[i] = 1/(.1+mask[i]);
474 #endif
475    
476    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
477    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
478    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
479    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
480    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
481
482    /*printf ("%f %f\n", curr_power, pitch_power);*/
483    /*int j;
484    for (j=0;j<B*N;j++)
485       printf ("%f ", X[j]);
486    for (j=0;j<B*N;j++)
487       printf ("%f ", P[j]);
488    printf ("\n");*/
489
490    /* Band normalisation */
491    compute_band_energies(st->mode, freq, bandE);
492    normalise_bands(st->mode, freq, X, bandE);
493    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
494    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
495
496    /* Compute MDCTs of the pitch part */
497    if (!shortBlocks)
498       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
499
500    {
501       /* Normalise the pitch vector as well (discard the energies) */
502       VARDECL(celt_ener_t, bandEp);
503       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
504       compute_band_energies(st->mode, freq, bandEp);
505       normalise_bands(st->mode, freq, P, bandEp);
506       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
507    }
508    curr_power = bandE[0]+bandE[1]+bandE[2];
509    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
510    if (!shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
511    {
512       /* Simulates intensity stereo */
513       /*for (i=30;i<N*B;i++)
514          X[i*C+1] = P[i*C+1] = 0;*/
515
516       /* Pitch prediction */
517       compute_pitch_gain(st->mode, X, P, gains);
518       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
519       if (has_pitch)
520          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
521    } else {
522       /* No pitch, so we just pretend we found a gain of zero */
523       for (i=0;i<st->mode->nbPBands;i++)
524          gains[i] = 0;
525       ec_enc_bits(&st->enc, 0, 7);
526       for (i=0;i<C*N;i++)
527          P[i] = 0;
528    }
529    quant_energy(st->mode, bandE, st->oldBandE, 20*C+nbCompressedBytes*8/5, st->mode->prob, &st->enc);
530
531    ALLOC(stereo_mode, st->mode->nbEBands, int);
532    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
533
534    pitch_quant_bands(st->mode, P, gains);
535
536    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
537
538    /* Residual quantisation */
539    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, nbCompressedBytes*8, shortBlocks, &st->enc);
540    
541    if (C==2)
542    {
543       renormalise_bands(st->mode, X);
544    }
545    /* Synthesis */
546    denormalise_bands(st->mode, X, freq, bandE);
547
548
549    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
550
551    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
552    /* De-emphasis and put everything back at the right place in the synthesis history */
553 #ifndef SHORTCUTS
554    for (c=0;c<C;c++)
555    {
556       int j;
557       celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
558       celt_int16_t * restrict pcmp = pcm+c;
559       for (j=0;j<N;j++)
560       {
561          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
562          st->preemph_memD[c] = tmp;
563          *pcmp = SIG2INT16(tmp);
564          pcmp += C;
565          outp += C;
566       }
567    }
568 #endif
569    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
570       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
571    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
572    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
573    {
574       int val = 0;
575       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
576       {
577          ec_enc_uint(&st->enc, val, 2);
578          val = 1-val;
579       }
580    }
581    ec_enc_done(&st->enc);
582    {
583       unsigned char *data;
584       int nbBytes = ec_byte_bytes(&st->buf);
585       if (nbBytes > nbCompressedBytes)
586       {
587          celt_warning_int ("got too many bytes:", nbBytes);
588          RESTORE_STACK;
589          return CELT_INTERNAL_ERROR;
590       }
591       /*printf ("%d\n", *nbBytes);*/
592       data = ec_byte_get_buffer(&st->buf);
593       for (i=0;i<nbBytes;i++)
594          compressed[i] = data[i];
595       for (;i<nbCompressedBytes;i++)
596          compressed[i] = 0;
597    }
598    /* Reset the packing for the next encoding */
599    ec_byte_reset(&st->buf);
600    ec_enc_init(&st->enc,&st->buf);
601
602    RESTORE_STACK;
603    return nbCompressedBytes;
604 }
605
606
607 /****************************************************************************/
608 /*                                                                          */
609 /*                                DECODER                                   */
610 /*                                                                          */
611 /****************************************************************************/
612
613
614 /** Decoder state 
615  @brief Decoder state
616  */
617 struct CELTDecoder {
618    const CELTMode *mode;
619    int frame_size;
620    int block_size;
621    int overlap;
622
623    ec_byte_buffer buf;
624    ec_enc         enc;
625
626    celt_sig_t * restrict preemph_memD;
627
628    celt_sig_t *out_mem;
629
630    celt_word16_t *oldBandE;
631    
632    int last_pitch_index;
633 };
634
635 CELTDecoder *celt_decoder_create(const CELTMode *mode)
636 {
637    int N, C;
638    CELTDecoder *st;
639
640    if (check_mode(mode) != CELT_OK)
641       return NULL;
642
643    N = mode->mdctSize;
644    C = CHANNELS(mode);
645    st = celt_alloc(sizeof(CELTDecoder));
646    
647    st->mode = mode;
648    st->frame_size = N;
649    st->block_size = N;
650    st->overlap = mode->overlap;
651
652    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
653    
654    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
655
656    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
657
658    st->last_pitch_index = 0;
659    return st;
660 }
661
662 void celt_decoder_destroy(CELTDecoder *st)
663 {
664    if (st == NULL)
665    {
666       celt_warning("NULL passed to celt_encoder_destroy");
667       return;
668    }
669    if (check_mode(st->mode) != CELT_OK)
670       return;
671
672
673    celt_free(st->out_mem);
674    
675    celt_free(st->oldBandE);
676    
677    celt_free(st->preemph_memD);
678
679    celt_free(st);
680 }
681
682 /** Handles lost packets by just copying past data with the same offset as the last
683     pitch period */
684 static void celt_decode_lost(CELTDecoder * restrict st, short * restrict pcm)
685 {
686    int c, N;
687    int pitch_index;
688    int i, len;
689    VARDECL(celt_sig_t, freq);
690    const int C = CHANNELS(st->mode);
691    int offset;
692    SAVE_STACK;
693    N = st->block_size;
694    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
695    
696    len = N+st->mode->overlap;
697 #if 0
698    pitch_index = st->last_pitch_index;
699    
700    /* Use the pitch MDCT as the "guessed" signal */
701    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
702
703 #else
704    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
705    pitch_index = MAX_PERIOD-len-pitch_index;
706    offset = MAX_PERIOD-pitch_index;
707    while (offset+len >= MAX_PERIOD)
708       offset -= pitch_index;
709    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
710    for (i=0;i<N;i++)
711       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
712 #endif
713    
714    
715    
716    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
717    /* Compute inverse MDCTs */
718    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
719
720    for (c=0;c<C;c++)
721    {
722       int j;
723       for (j=0;j<N;j++)
724       {
725          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
726                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
727          st->preemph_memD[c] = tmp;
728          pcm[C*j+c] = SIG2INT16(tmp);
729       }
730    }
731    RESTORE_STACK;
732 }
733
734 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
735 {
736    int c, N, N4;
737    int has_pitch;
738    int pitch_index;
739    ec_dec dec;
740    ec_byte_buffer buf;
741    VARDECL(celt_sig_t, freq);
742    VARDECL(celt_norm_t, X);
743    VARDECL(celt_norm_t, P);
744    VARDECL(celt_ener_t, bandE);
745    VARDECL(celt_pgain_t, gains);
746    VARDECL(int, stereo_mode);
747    int shortBlocks;
748    int transient_time;
749    int transient_shift;
750    const int C = CHANNELS(st->mode);
751    SAVE_STACK;
752
753    if (check_mode(st->mode) != CELT_OK)
754       return CELT_INVALID_MODE;
755
756    N = st->block_size;
757    N4 = (N-st->overlap)>>1;
758
759    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
760    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
761    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
762    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
763    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
764    
765    if (check_mode(st->mode) != CELT_OK)
766    {
767       RESTORE_STACK;
768       return CELT_INVALID_MODE;
769    }
770    if (data == NULL)
771    {
772       celt_decode_lost(st, pcm);
773       RESTORE_STACK;
774       return 0;
775    }
776    
777    ec_byte_readinit(&buf,data,len);
778    ec_dec_init(&dec,&buf);
779    
780    shortBlocks = ec_dec_bits(&dec, 1);
781    if (shortBlocks)
782    {
783       transient_shift = ec_dec_bits(&dec, 2);
784       if (transient_shift)
785          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
786       else
787          transient_time = 0;
788    } else {
789       transient_time = -1;
790       transient_shift = 0;
791    }
792    /* Get the pitch gains */
793    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
794    
795    /* Get the pitch index */
796    if (has_pitch)
797    {
798       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
799       st->last_pitch_index = pitch_index;
800    } else {
801       /* FIXME: We could be more intelligent here and just not compute the MDCT */
802       pitch_index = 0;
803    }
804
805    /* Get band energies */
806    unquant_energy(st->mode, bandE, st->oldBandE, 20*C+len*8/5, st->mode->prob, &dec);
807
808    /* Pitch MDCT */
809    compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
810
811    {
812       VARDECL(celt_ener_t, bandEp);
813       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
814       compute_band_energies(st->mode, freq, bandEp);
815       normalise_bands(st->mode, freq, P, bandEp);
816    }
817
818    ALLOC(stereo_mode, st->mode->nbEBands, int);
819    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
820    /* Apply pitch gains */
821    pitch_quant_bands(st->mode, P, gains);
822
823    /* Decode fixed codebook and merge with pitch */
824    unquant_bands(st->mode, X, P, bandE, stereo_mode, len*8, shortBlocks, &dec);
825
826    if (C==2)
827    {
828       renormalise_bands(st->mode, X);
829    }
830    /* Synthesis */
831    denormalise_bands(st->mode, X, freq, bandE);
832
833
834    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
835    /* Compute inverse MDCTs */
836    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
837
838    for (c=0;c<C;c++)
839    {
840       int j;
841       const celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
842       celt_int16_t * restrict pcmp = pcm+c;
843       for (j=0;j<N;j++)
844       {
845          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
846          st->preemph_memD[c] = tmp;
847          *pcmp = SIG2INT16(tmp);
848          pcmp += C;
849          outp += C;
850       }
851    }
852
853    {
854       unsigned int val = 0;
855       while (ec_dec_tell(&dec, 0) < len*8)
856       {
857          if (ec_dec_uint(&dec, 2) != val)
858          {
859             celt_warning("decode error");
860             RESTORE_STACK;
861             return CELT_CORRUPTED_DATA;
862          }
863          val = 1-val;
864       }
865    }
866
867    RESTORE_STACK;
868    return 0;
869    /*printf ("\n");*/
870 }
871