Completed the separation of coarse and fine energy quantisation
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53
54 static const celt_word16_t preemph = QCONST16(0.8f,15);
55
56 #ifdef FIXED_POINT
57 static const celt_word16_t transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
63    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
64 #endif
65
66 /** Encoder state 
67  @brief Encoder state
68  */
69 struct CELTEncoder {
70    const CELTMode *mode;     /**< Mode used by the encoder */
71    int frame_size;
72    int block_size;
73    int overlap;
74    int channels;
75    
76    ec_byte_buffer buf;
77    ec_enc         enc;
78
79    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
80    celt_sig_t    * restrict preemph_memD;
81
82    celt_sig_t *in_mem;
83    celt_sig_t *out_mem;
84
85    celt_word16_t *oldBandE;
86 #ifdef EXP_PSY
87    celt_word16_t *psy_mem;
88    struct PsyDecay psy;
89 #endif
90 };
91
92 CELTEncoder *celt_encoder_create(const CELTMode *mode)
93 {
94    int N, C;
95    CELTEncoder *st;
96
97    if (check_mode(mode) != CELT_OK)
98       return NULL;
99
100    N = mode->mdctSize;
101    C = mode->nbChannels;
102    st = celt_alloc(sizeof(CELTEncoder));
103    
104    st->mode = mode;
105    st->frame_size = N;
106    st->block_size = N;
107    st->overlap = mode->overlap;
108
109    ec_byte_writeinit(&st->buf);
110    ec_enc_init(&st->enc,&st->buf);
111
112    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
113    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
114
115    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
116
117    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));;
118    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
119
120 #ifdef EXP_PSY
121    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
122    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
123 #endif
124
125    return st;
126 }
127
128 void celt_encoder_destroy(CELTEncoder *st)
129 {
130    if (st == NULL)
131    {
132       celt_warning("NULL passed to celt_encoder_destroy");
133       return;
134    }
135    if (check_mode(st->mode) != CELT_OK)
136       return;
137
138    ec_byte_writeclear(&st->buf);
139
140    celt_free(st->in_mem);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148 #ifdef EXP_PSY
149    celt_free (st->psy_mem);
150    psydecay_clear(&st->psy);
151 #endif
152    
153    celt_free(st);
154 }
155
156 static inline celt_int16_t SIG2INT16(celt_sig_t x)
157 {
158    x = PSHR32(x, SIG_SHIFT);
159    x = MAX32(x, -32768);
160    x = MIN32(x, 32767);
161 #ifdef FIXED_POINT
162    return EXTRACT16(x);
163 #else
164    return (celt_int16_t)floor(.5+x);
165 #endif
166 }
167
168 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
169 {
170    int c, i, n;
171    celt_word32_t ratio;
172    /* FIXME: Remove the floats here */
173    VARDECL(celt_word32_t, begin);
174    SAVE_STACK;
175    ALLOC(begin, len, celt_word32_t);
176    for (i=0;i<len;i++)
177       begin[i] = EXTEND32(ABS16(SHR32(in[C*i],SIG_SHIFT)));
178    for (c=1;c<C;c++)
179    {
180       for (i=0;i<len;i++)
181          begin[i] = MAX32(begin[i], EXTEND32(ABS16(SHR32(in[C*i+c],SIG_SHIFT))));
182    }
183    for (i=1;i<len;i++)
184       begin[i] = MAX32(begin[i-1],begin[i]);
185    n = -1;
186    for (i=8;i<len-8;i++)
187    {
188       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
189          n=i;
190    }
191    if (n<32)
192    {
193       n = -1;
194       ratio = 0;
195    } else {
196       ratio = DIV32(begin[len-1],1+begin[n-16]);
197    }
198    /*printf ("%d %f\n", n, ratio*ratio);*/
199    if (ratio < 0)
200       ratio = 0;
201    if (ratio > 1000)
202       ratio = 1000;
203    ratio *= ratio;
204    if (ratio < 50)
205       *transient_shift = 0;
206    else if (ratio < 256)
207       *transient_shift = 1;
208    else if (ratio < 4096)
209       *transient_shift = 2;
210    else
211       *transient_shift = 3;
212    *transient_time = n;
213    
214    RESTORE_STACK;
215    return ratio > 20;
216 }
217
218 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
219 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
220 {
221    const int C = CHANNELS(mode);
222    if (C==1 && !shortBlocks)
223    {
224       const mdct_lookup *lookup = MDCT(mode);
225       const int overlap = OVERLAP(mode);
226       mdct_forward(lookup, in, out, mode->window, overlap);
227    } else if (!shortBlocks) {
228       const mdct_lookup *lookup = MDCT(mode);
229       const int overlap = OVERLAP(mode);
230       const int N = FRAMESIZE(mode);
231       int c;
232       VARDECL(celt_word32_t, x);
233       VARDECL(celt_word32_t, tmp);
234       SAVE_STACK;
235       ALLOC(x, N+overlap, celt_word32_t);
236       ALLOC(tmp, N, celt_word32_t);
237       for (c=0;c<C;c++)
238       {
239          int j;
240          for (j=0;j<N+overlap;j++)
241             x[j] = in[C*j+c];
242          mdct_forward(lookup, x, tmp, mode->window, overlap);
243          /* Interleaving the sub-frames */
244          for (j=0;j<N;j++)
245             out[C*j+c] = tmp[j];
246       }
247       RESTORE_STACK;
248    } else {
249       const mdct_lookup *lookup = &mode->shortMdct;
250       const int overlap = mode->shortMdctSize;
251       const int N = mode->shortMdctSize;
252       int b, c;
253       VARDECL(celt_word32_t, x);
254       VARDECL(celt_word32_t, tmp);
255       SAVE_STACK;
256       ALLOC(x, N+overlap, celt_word32_t);
257       ALLOC(tmp, N, celt_word32_t);
258       for (c=0;c<C;c++)
259       {
260          int B = mode->nbShortMdcts;
261          for (b=0;b<B;b++)
262          {
263             int j;
264             for (j=0;j<N+overlap;j++)
265                x[j] = in[C*(b*N+j)+c];
266             mdct_forward(lookup, x, tmp, mode->window, overlap);
267             /* Interleaving the sub-frames */
268             for (j=0;j<N;j++)
269                out[C*(j*B+b)+c] = tmp[j];
270          }
271       }
272       RESTORE_STACK;
273    }
274 }
275
276 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
277 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
278 {
279    int c, N4;
280    const int C = CHANNELS(mode);
281    const int N = FRAMESIZE(mode);
282    const int overlap = OVERLAP(mode);
283    N4 = (N-overlap)>>1;
284    for (c=0;c<C;c++)
285    {
286       int j;
287       if (transient_shift==0 && C==1 && !shortBlocks) {
288          const mdct_lookup *lookup = MDCT(mode);
289          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
290       } else if (!shortBlocks) {
291          const mdct_lookup *lookup = MDCT(mode);
292          VARDECL(celt_word32_t, x);
293          VARDECL(celt_word32_t, tmp);
294          SAVE_STACK;
295          ALLOC(x, 2*N, celt_word32_t);
296          ALLOC(tmp, N, celt_word32_t);
297          /* De-interleaving the sub-frames */
298          for (j=0;j<N;j++)
299             tmp[j] = X[C*j+c];
300          /* Prevents problems from the imdct doing the overlap-add */
301          CELT_MEMSET(x+N4, 0, overlap);
302          mdct_backward(lookup, tmp, x, mode->window, overlap);
303          celt_assert(transient_shift == 0);
304          /* The first and last part would need to be set to zero if we actually
305             wanted to use them. */
306          for (j=0;j<overlap;j++)
307             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
308          for (j=0;j<overlap;j++)
309             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
310          for (j=0;j<2*N4;j++)
311             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
312          RESTORE_STACK;
313       } else {
314          int b;
315          const int N2 = mode->shortMdctSize;
316          const int B = mode->nbShortMdcts;
317          const mdct_lookup *lookup = &mode->shortMdct;
318          VARDECL(celt_word32_t, x);
319          VARDECL(celt_word32_t, tmp);
320          SAVE_STACK;
321          ALLOC(x, 2*N, celt_word32_t);
322          ALLOC(tmp, N, celt_word32_t);
323          /* Prevents problems from the imdct doing the overlap-add */
324          CELT_MEMSET(x+N4, 0, overlap);
325          for (b=0;b<B;b++)
326          {
327             /* De-interleaving the sub-frames */
328             for (j=0;j<N2;j++)
329                tmp[j] = X[C*(j*B+b)+c];
330             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
331          }
332          if (transient_shift > 0)
333          {
334 #ifdef FIXED_POINT
335             for (j=0;j<16;j++)
336                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
337             for (j=transient_time;j<N+overlap;j++)
338                x[N4+j] = SHL32(x[N4+j], transient_shift);
339 #else
340             for (j=0;j<16;j++)
341                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
342             for (j=transient_time;j<N+overlap;j++)
343                x[N4+j] *= 1<<transient_shift;
344 #endif
345          }
346          /* The first and last part would need to be set to zero if we actually
347          wanted to use them. */
348          for (j=0;j<overlap;j++)
349             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
350          for (j=0;j<overlap;j++)
351             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
352          for (j=0;j<2*N4;j++)
353             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
354          RESTORE_STACK;
355       }
356    }
357 }
358
359 int celt_encode(CELTEncoder * restrict st, celt_int16_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
360 {
361    int i, c, N, N4;
362    int has_pitch;
363    int pitch_index;
364    int bits;
365    celt_word32_t curr_power, pitch_power;
366    VARDECL(celt_sig_t, in);
367    VARDECL(celt_sig_t, freq);
368    VARDECL(celt_norm_t, X);
369    VARDECL(celt_norm_t, P);
370    VARDECL(celt_ener_t, bandE);
371    VARDECL(celt_pgain_t, gains);
372    VARDECL(int, stereo_mode);
373    VARDECL(celt_int16_t, fine_quant);
374    VARDECL(celt_word16_t, error);
375 #ifdef EXP_PSY
376    VARDECL(celt_word32_t, mask);
377 #endif
378    int shortBlocks=0;
379    int transient_time;
380    int transient_shift;
381    const int C = CHANNELS(st->mode);
382    SAVE_STACK;
383
384    if (check_mode(st->mode) != CELT_OK)
385       return CELT_INVALID_MODE;
386
387    N = st->block_size;
388    N4 = (N-st->overlap)>>1;
389    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
390
391    CELT_COPY(in, st->in_mem, C*st->overlap);
392    for (c=0;c<C;c++)
393    {
394       const celt_int16_t * restrict pcmp = pcm+c;
395       celt_sig_t * restrict inp = in+C*st->overlap+c;
396       for (i=0;i<N;i++)
397       {
398          /* Apply pre-emphasis */
399          celt_sig_t tmp = SHL32(EXTEND32(*pcmp), SIG_SHIFT);
400          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),1));
401          st->preemph_memE[c] = *pcmp;
402          inp += C;
403          pcmp += C;
404       }
405    }
406    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
407    
408    if (st->mode->nbShortMdcts > 1)
409    {
410       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
411       {
412 #ifndef FIXED_POINT
413          float gain_1;
414 #endif
415          ec_enc_bits(&st->enc, 1, 1);
416          ec_enc_bits(&st->enc, transient_shift, 2);
417          if (transient_shift)
418             ec_enc_uint(&st->enc, transient_time, N+st->overlap);
419          if (transient_shift)
420          {
421 #ifdef FIXED_POINT
422             for (c=0;c<C;c++)
423                for (i=0;i<16;i++)
424                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
425             for (c=0;c<C;c++)
426                for (i=transient_time;i<N+st->overlap;i++)
427                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
428 #else
429             for (c=0;c<C;c++)
430                for (i=0;i<16;i++)
431                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
432             gain_1 = 1./(1<<transient_shift);
433             for (c=0;c<C;c++)
434                for (i=transient_time;i<N+st->overlap;i++)
435                   in[C*i+c] *= gain_1;
436 #endif
437          }
438          shortBlocks = 1;
439       } else {
440          ec_enc_bits(&st->enc, 0, 1);
441          transient_time = -1;
442          transient_shift = 0;
443          shortBlocks = 0;
444       }
445    } else {
446       transient_time = -1;
447       transient_shift = 0;
448       shortBlocks = 0;
449    }
450    /* Pitch analysis: we do it early to save on the peak stack space */
451    if (!shortBlocks)
452       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
453
454    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
455    
456    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
457    /* Compute MDCTs */
458    compute_mdcts(st->mode, shortBlocks, in, freq);
459
460 #ifdef EXP_PSY
461    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
462    for (i=0;i<N;i++)
463       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
464    for (c=1;c<C;c++)
465       for (i=0;i<N;i++)
466          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
467
468    ALLOC(mask, N, celt_sig_t);
469    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
470
471    /* Invert and stretch the mask to length of X 
472       For some reason, I get better results by using the sqrt instead,
473       although there's no valid reason to. Must investigate further */
474    for (i=0;i<C*N;i++)
475       mask[i] = 1/(.1+mask[i]);
476 #endif
477    
478    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
479    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
480    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
481    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
482    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
483
484    /*printf ("%f %f\n", curr_power, pitch_power);*/
485    /*int j;
486    for (j=0;j<B*N;j++)
487       printf ("%f ", X[j]);
488    for (j=0;j<B*N;j++)
489       printf ("%f ", P[j]);
490    printf ("\n");*/
491
492    /* Band normalisation */
493    compute_band_energies(st->mode, freq, bandE);
494    normalise_bands(st->mode, freq, X, bandE);
495    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
496    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
497
498    /* Compute MDCTs of the pitch part */
499    if (!shortBlocks)
500       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
501
502    {
503       /* Normalise the pitch vector as well (discard the energies) */
504       VARDECL(celt_ener_t, bandEp);
505       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
506       compute_band_energies(st->mode, freq, bandEp);
507       normalise_bands(st->mode, freq, P, bandEp);
508       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
509    }
510    curr_power = bandE[0]+bandE[1]+bandE[2];
511    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
512    if (!shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
513    {
514       /* Simulates intensity stereo */
515       /*for (i=30;i<N*B;i++)
516          X[i*C+1] = P[i*C+1] = 0;*/
517
518       /* Pitch prediction */
519       compute_pitch_gain(st->mode, X, P, gains);
520       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
521       if (has_pitch)
522          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
523    } else {
524       /* No pitch, so we just pretend we found a gain of zero */
525       for (i=0;i<st->mode->nbPBands;i++)
526          gains[i] = 0;
527       ec_enc_bits(&st->enc, 0, 7);
528       for (i=0;i<C*N;i++)
529          P[i] = 0;
530    }
531
532    ALLOC(fine_quant, st->mode->nbEBands, celt_int16_t);
533    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
534    bits = ec_enc_tell(&st->enc, 0);
535    quant_coarse_energy(st->mode, bandE, st->oldBandE, 20*C+nbCompressedBytes*8, st->mode->prob, error, &st->enc);
536    compute_fine_allocation(st->mode, fine_quant, (20*C+nbCompressedBytes*8/5-(ec_enc_tell(&st->enc, 0)-bits))/C);
537    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &st->enc);
538
539    ALLOC(stereo_mode, st->mode->nbEBands, int);
540    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
541
542    pitch_quant_bands(st->mode, P, gains);
543
544    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
545
546    /* Residual quantisation */
547    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, nbCompressedBytes*8, shortBlocks, &st->enc);
548    
549    if (C==2)
550    {
551       renormalise_bands(st->mode, X);
552    }
553    /* Synthesis */
554    denormalise_bands(st->mode, X, freq, bandE);
555
556
557    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
558
559    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
560    /* De-emphasis and put everything back at the right place in the synthesis history */
561 #ifndef SHORTCUTS
562    for (c=0;c<C;c++)
563    {
564       int j;
565       celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
566       celt_int16_t * restrict pcmp = pcm+c;
567       for (j=0;j<N;j++)
568       {
569          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
570          st->preemph_memD[c] = tmp;
571          *pcmp = SIG2INT16(tmp);
572          pcmp += C;
573          outp += C;
574       }
575    }
576 #endif
577    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
578       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
579    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
580    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
581    {
582       int val = 0;
583       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
584       {
585          ec_enc_uint(&st->enc, val, 2);
586          val = 1-val;
587       }
588    }
589    ec_enc_done(&st->enc);
590    {
591       unsigned char *data;
592       int nbBytes = ec_byte_bytes(&st->buf);
593       if (nbBytes > nbCompressedBytes)
594       {
595          celt_warning_int ("got too many bytes:", nbBytes);
596          RESTORE_STACK;
597          return CELT_INTERNAL_ERROR;
598       }
599       /*printf ("%d\n", *nbBytes);*/
600       data = ec_byte_get_buffer(&st->buf);
601       for (i=0;i<nbBytes;i++)
602          compressed[i] = data[i];
603       for (;i<nbCompressedBytes;i++)
604          compressed[i] = 0;
605    }
606    /* Reset the packing for the next encoding */
607    ec_byte_reset(&st->buf);
608    ec_enc_init(&st->enc,&st->buf);
609
610    RESTORE_STACK;
611    return nbCompressedBytes;
612 }
613
614
615 /****************************************************************************/
616 /*                                                                          */
617 /*                                DECODER                                   */
618 /*                                                                          */
619 /****************************************************************************/
620
621
622 /** Decoder state 
623  @brief Decoder state
624  */
625 struct CELTDecoder {
626    const CELTMode *mode;
627    int frame_size;
628    int block_size;
629    int overlap;
630
631    ec_byte_buffer buf;
632    ec_enc         enc;
633
634    celt_sig_t * restrict preemph_memD;
635
636    celt_sig_t *out_mem;
637
638    celt_word16_t *oldBandE;
639    
640    int last_pitch_index;
641 };
642
643 CELTDecoder *celt_decoder_create(const CELTMode *mode)
644 {
645    int N, C;
646    CELTDecoder *st;
647
648    if (check_mode(mode) != CELT_OK)
649       return NULL;
650
651    N = mode->mdctSize;
652    C = CHANNELS(mode);
653    st = celt_alloc(sizeof(CELTDecoder));
654    
655    st->mode = mode;
656    st->frame_size = N;
657    st->block_size = N;
658    st->overlap = mode->overlap;
659
660    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
661    
662    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
663
664    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
665
666    st->last_pitch_index = 0;
667    return st;
668 }
669
670 void celt_decoder_destroy(CELTDecoder *st)
671 {
672    if (st == NULL)
673    {
674       celt_warning("NULL passed to celt_encoder_destroy");
675       return;
676    }
677    if (check_mode(st->mode) != CELT_OK)
678       return;
679
680
681    celt_free(st->out_mem);
682    
683    celt_free(st->oldBandE);
684    
685    celt_free(st->preemph_memD);
686
687    celt_free(st);
688 }
689
690 /** Handles lost packets by just copying past data with the same offset as the last
691     pitch period */
692 static void celt_decode_lost(CELTDecoder * restrict st, short * restrict pcm)
693 {
694    int c, N;
695    int pitch_index;
696    int i, len;
697    VARDECL(celt_sig_t, freq);
698    const int C = CHANNELS(st->mode);
699    int offset;
700    SAVE_STACK;
701    N = st->block_size;
702    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
703    
704    len = N+st->mode->overlap;
705 #if 0
706    pitch_index = st->last_pitch_index;
707    
708    /* Use the pitch MDCT as the "guessed" signal */
709    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
710
711 #else
712    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
713    pitch_index = MAX_PERIOD-len-pitch_index;
714    offset = MAX_PERIOD-pitch_index;
715    while (offset+len >= MAX_PERIOD)
716       offset -= pitch_index;
717    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
718    for (i=0;i<N;i++)
719       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
720 #endif
721    
722    
723    
724    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
725    /* Compute inverse MDCTs */
726    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
727
728    for (c=0;c<C;c++)
729    {
730       int j;
731       for (j=0;j<N;j++)
732       {
733          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
734                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
735          st->preemph_memD[c] = tmp;
736          pcm[C*j+c] = SIG2INT16(tmp);
737       }
738    }
739    RESTORE_STACK;
740 }
741
742 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
743 {
744    int c, N, N4;
745    int has_pitch;
746    int pitch_index;
747    int bits;
748    ec_dec dec;
749    ec_byte_buffer buf;
750    VARDECL(celt_sig_t, freq);
751    VARDECL(celt_norm_t, X);
752    VARDECL(celt_norm_t, P);
753    VARDECL(celt_ener_t, bandE);
754    VARDECL(celt_pgain_t, gains);
755    VARDECL(int, stereo_mode);
756    VARDECL(celt_int16_t, fine_quant);
757
758    int shortBlocks;
759    int transient_time;
760    int transient_shift;
761    const int C = CHANNELS(st->mode);
762    SAVE_STACK;
763
764    if (check_mode(st->mode) != CELT_OK)
765       return CELT_INVALID_MODE;
766
767    N = st->block_size;
768    N4 = (N-st->overlap)>>1;
769
770    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
771    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
772    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
773    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
774    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
775    
776    if (check_mode(st->mode) != CELT_OK)
777    {
778       RESTORE_STACK;
779       return CELT_INVALID_MODE;
780    }
781    if (data == NULL)
782    {
783       celt_decode_lost(st, pcm);
784       RESTORE_STACK;
785       return 0;
786    }
787    
788    ec_byte_readinit(&buf,data,len);
789    ec_dec_init(&dec,&buf);
790    
791    if (st->mode->nbShortMdcts > 1)
792    {
793       shortBlocks = ec_dec_bits(&dec, 1);
794       if (shortBlocks)
795       {
796          transient_shift = ec_dec_bits(&dec, 2);
797          if (transient_shift)
798             transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
799          else
800             transient_time = 0;
801       } else {
802          transient_time = -1;
803          transient_shift = 0;
804       }
805    } else {
806       shortBlocks = 0;
807       transient_time = -1;
808       transient_shift = 0;
809    }
810    /* Get the pitch gains */
811    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
812    
813    /* Get the pitch index */
814    if (has_pitch)
815    {
816       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
817       st->last_pitch_index = pitch_index;
818    } else {
819       /* FIXME: We could be more intelligent here and just not compute the MDCT */
820       pitch_index = 0;
821    }
822
823    ALLOC(fine_quant, st->mode->nbEBands, celt_int16_t);
824    bits = ec_dec_tell(&dec, 0);
825    /* Get band energies */
826    unquant_coarse_energy(st->mode, bandE, st->oldBandE, 20*C+len*8, st->mode->prob, &dec);
827    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);
828    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
829
830    /* Pitch MDCT */
831    compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
832
833    {
834       VARDECL(celt_ener_t, bandEp);
835       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
836       compute_band_energies(st->mode, freq, bandEp);
837       normalise_bands(st->mode, freq, P, bandEp);
838    }
839
840    ALLOC(stereo_mode, st->mode->nbEBands, int);
841    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
842    /* Apply pitch gains */
843    pitch_quant_bands(st->mode, P, gains);
844
845    /* Decode fixed codebook and merge with pitch */
846    unquant_bands(st->mode, X, P, bandE, stereo_mode, len*8, shortBlocks, &dec);
847
848    if (C==2)
849    {
850       renormalise_bands(st->mode, X);
851    }
852    /* Synthesis */
853    denormalise_bands(st->mode, X, freq, bandE);
854
855
856    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
857    /* Compute inverse MDCTs */
858    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
859
860    for (c=0;c<C;c++)
861    {
862       int j;
863       const celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
864       celt_int16_t * restrict pcmp = pcm+c;
865       for (j=0;j<N;j++)
866       {
867          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
868          st->preemph_memD[c] = tmp;
869          *pcmp = SIG2INT16(tmp);
870          pcmp += C;
871          outp += C;
872       }
873    }
874
875    {
876       unsigned int val = 0;
877       while (ec_dec_tell(&dec, 0) < len*8)
878       {
879          if (ec_dec_uint(&dec, 2) != val)
880          {
881             celt_warning("decode error");
882             RESTORE_STACK;
883             return CELT_CORRUPTED_DATA;
884          }
885          val = 1-val;
886       }
887    }
888
889    RESTORE_STACK;
890    return 0;
891    /*printf ("\n");*/
892 }
893