Changing the allocation algorithm to better take into account the fixed cost
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53
54 static const celt_word16_t preemph = QCONST16(0.8f,15);
55
56 #ifdef FIXED_POINT
57 static const celt_word16_t transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
63    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
64 #endif
65
66 /** Encoder state 
67  @brief Encoder state
68  */
69 struct CELTEncoder {
70    const CELTMode *mode;     /**< Mode used by the encoder */
71    int frame_size;
72    int block_size;
73    int overlap;
74    int channels;
75    
76    ec_byte_buffer buf;
77    ec_enc         enc;
78
79    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
80    celt_sig_t    * restrict preemph_memD;
81
82    celt_sig_t *in_mem;
83    celt_sig_t *out_mem;
84
85    celt_word16_t *oldBandE;
86 #ifdef EXP_PSY
87    celt_word16_t *psy_mem;
88    struct PsyDecay psy;
89 #endif
90 };
91
92 CELTEncoder *celt_encoder_create(const CELTMode *mode)
93 {
94    int N, C;
95    CELTEncoder *st;
96
97    if (check_mode(mode) != CELT_OK)
98       return NULL;
99
100    N = mode->mdctSize;
101    C = mode->nbChannels;
102    st = celt_alloc(sizeof(CELTEncoder));
103    
104    st->mode = mode;
105    st->frame_size = N;
106    st->block_size = N;
107    st->overlap = mode->overlap;
108
109    ec_byte_writeinit(&st->buf);
110    ec_enc_init(&st->enc,&st->buf);
111
112    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
113    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
114
115    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
116
117    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));;
118    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
119
120 #ifdef EXP_PSY
121    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
122    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
123 #endif
124
125    return st;
126 }
127
128 void celt_encoder_destroy(CELTEncoder *st)
129 {
130    if (st == NULL)
131    {
132       celt_warning("NULL passed to celt_encoder_destroy");
133       return;
134    }
135    if (check_mode(st->mode) != CELT_OK)
136       return;
137
138    ec_byte_writeclear(&st->buf);
139
140    celt_free(st->in_mem);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148 #ifdef EXP_PSY
149    celt_free (st->psy_mem);
150    psydecay_clear(&st->psy);
151 #endif
152    
153    celt_free(st);
154 }
155
156 static inline celt_int16_t SIG2INT16(celt_sig_t x)
157 {
158    x = PSHR32(x, SIG_SHIFT);
159    x = MAX32(x, -32768);
160    x = MIN32(x, 32767);
161 #ifdef FIXED_POINT
162    return EXTRACT16(x);
163 #else
164    return (celt_int16_t)floor(.5+x);
165 #endif
166 }
167
168 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
169 {
170    int c, i, n;
171    celt_word32_t ratio;
172    /* FIXME: Remove the floats here */
173    VARDECL(celt_word32_t, begin);
174    SAVE_STACK;
175    ALLOC(begin, len, celt_word32_t);
176    for (i=0;i<len;i++)
177       begin[i] = EXTEND32(ABS16(SHR32(in[C*i],SIG_SHIFT)));
178    for (c=1;c<C;c++)
179    {
180       for (i=0;i<len;i++)
181          begin[i] = MAX32(begin[i], EXTEND32(ABS16(SHR32(in[C*i+c],SIG_SHIFT))));
182    }
183    for (i=1;i<len;i++)
184       begin[i] = MAX32(begin[i-1],begin[i]);
185    n = -1;
186    for (i=8;i<len-8;i++)
187    {
188       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
189          n=i;
190    }
191    if (n<32)
192    {
193       n = -1;
194       ratio = 0;
195    } else {
196       ratio = DIV32(begin[len-1],1+begin[n-16]);
197    }
198    /*printf ("%d %f\n", n, ratio*ratio);*/
199    if (ratio < 0)
200       ratio = 0;
201    if (ratio > 1000)
202       ratio = 1000;
203    ratio *= ratio;
204    if (ratio < 50)
205       *transient_shift = 0;
206    else if (ratio < 256)
207       *transient_shift = 1;
208    else if (ratio < 4096)
209       *transient_shift = 2;
210    else
211       *transient_shift = 3;
212    *transient_time = n;
213    
214    RESTORE_STACK;
215    return ratio > 20;
216 }
217
218 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
219 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
220 {
221    const int C = CHANNELS(mode);
222    if (C==1 && !shortBlocks)
223    {
224       const mdct_lookup *lookup = MDCT(mode);
225       const int overlap = OVERLAP(mode);
226       mdct_forward(lookup, in, out, mode->window, overlap);
227    } else if (!shortBlocks) {
228       const mdct_lookup *lookup = MDCT(mode);
229       const int overlap = OVERLAP(mode);
230       const int N = FRAMESIZE(mode);
231       int c;
232       VARDECL(celt_word32_t, x);
233       VARDECL(celt_word32_t, tmp);
234       SAVE_STACK;
235       ALLOC(x, N+overlap, celt_word32_t);
236       ALLOC(tmp, N, celt_word32_t);
237       for (c=0;c<C;c++)
238       {
239          int j;
240          for (j=0;j<N+overlap;j++)
241             x[j] = in[C*j+c];
242          mdct_forward(lookup, x, tmp, mode->window, overlap);
243          /* Interleaving the sub-frames */
244          for (j=0;j<N;j++)
245             out[C*j+c] = tmp[j];
246       }
247       RESTORE_STACK;
248    } else {
249       const mdct_lookup *lookup = &mode->shortMdct;
250       const int overlap = mode->shortMdctSize;
251       const int N = mode->shortMdctSize;
252       int b, c;
253       VARDECL(celt_word32_t, x);
254       VARDECL(celt_word32_t, tmp);
255       SAVE_STACK;
256       ALLOC(x, N+overlap, celt_word32_t);
257       ALLOC(tmp, N, celt_word32_t);
258       for (c=0;c<C;c++)
259       {
260          int B = mode->nbShortMdcts;
261          for (b=0;b<B;b++)
262          {
263             int j;
264             for (j=0;j<N+overlap;j++)
265                x[j] = in[C*(b*N+j)+c];
266             mdct_forward(lookup, x, tmp, mode->window, overlap);
267             /* Interleaving the sub-frames */
268             for (j=0;j<N;j++)
269                out[C*(j*B+b)+c] = tmp[j];
270          }
271       }
272       RESTORE_STACK;
273    }
274 }
275
276 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
277 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
278 {
279    int c, N4;
280    const int C = CHANNELS(mode);
281    const int N = FRAMESIZE(mode);
282    const int overlap = OVERLAP(mode);
283    N4 = (N-overlap)>>1;
284    for (c=0;c<C;c++)
285    {
286       int j;
287       if (transient_shift==0 && C==1 && !shortBlocks) {
288          const mdct_lookup *lookup = MDCT(mode);
289          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
290       } else if (!shortBlocks) {
291          const mdct_lookup *lookup = MDCT(mode);
292          VARDECL(celt_word32_t, x);
293          VARDECL(celt_word32_t, tmp);
294          SAVE_STACK;
295          ALLOC(x, 2*N, celt_word32_t);
296          ALLOC(tmp, N, celt_word32_t);
297          /* De-interleaving the sub-frames */
298          for (j=0;j<N;j++)
299             tmp[j] = X[C*j+c];
300          /* Prevents problems from the imdct doing the overlap-add */
301          CELT_MEMSET(x+N4, 0, overlap);
302          mdct_backward(lookup, tmp, x, mode->window, overlap);
303          celt_assert(transient_shift == 0);
304          /* The first and last part would need to be set to zero if we actually
305             wanted to use them. */
306          for (j=0;j<overlap;j++)
307             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
308          for (j=0;j<overlap;j++)
309             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
310          for (j=0;j<2*N4;j++)
311             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
312          RESTORE_STACK;
313       } else {
314          int b;
315          const int N2 = mode->shortMdctSize;
316          const int B = mode->nbShortMdcts;
317          const mdct_lookup *lookup = &mode->shortMdct;
318          VARDECL(celt_word32_t, x);
319          VARDECL(celt_word32_t, tmp);
320          SAVE_STACK;
321          ALLOC(x, 2*N, celt_word32_t);
322          ALLOC(tmp, N, celt_word32_t);
323          /* Prevents problems from the imdct doing the overlap-add */
324          CELT_MEMSET(x+N4, 0, overlap);
325          for (b=0;b<B;b++)
326          {
327             /* De-interleaving the sub-frames */
328             for (j=0;j<N2;j++)
329                tmp[j] = X[C*(j*B+b)+c];
330             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
331          }
332          if (transient_shift > 0)
333          {
334 #ifdef FIXED_POINT
335             for (j=0;j<16;j++)
336                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
337             for (j=transient_time;j<N+overlap;j++)
338                x[N4+j] = SHL32(x[N4+j], transient_shift);
339 #else
340             for (j=0;j<16;j++)
341                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
342             for (j=transient_time;j<N+overlap;j++)
343                x[N4+j] *= 1<<transient_shift;
344 #endif
345          }
346          /* The first and last part would need to be set to zero if we actually
347          wanted to use them. */
348          for (j=0;j<overlap;j++)
349             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
350          for (j=0;j<overlap;j++)
351             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
352          for (j=0;j<2*N4;j++)
353             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
354          RESTORE_STACK;
355       }
356    }
357 }
358
359 int celt_encode(CELTEncoder * restrict st, celt_int16_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
360 {
361    int i, c, N, N4;
362    int has_pitch;
363    int pitch_index;
364    int bits;
365    celt_word32_t curr_power, pitch_power;
366    VARDECL(celt_sig_t, in);
367    VARDECL(celt_sig_t, freq);
368    VARDECL(celt_norm_t, X);
369    VARDECL(celt_norm_t, P);
370    VARDECL(celt_ener_t, bandE);
371    VARDECL(celt_pgain_t, gains);
372    VARDECL(int, stereo_mode);
373    VARDECL(int, fine_quant);
374    VARDECL(celt_word16_t, error);
375    VARDECL(int, pulses);
376    VARDECL(int, offsets);
377 #ifdef EXP_PSY
378    VARDECL(celt_word32_t, mask);
379 #endif
380    int shortBlocks=0;
381    int transient_time;
382    int transient_shift;
383    const int C = CHANNELS(st->mode);
384    SAVE_STACK;
385
386    if (check_mode(st->mode) != CELT_OK)
387       return CELT_INVALID_MODE;
388
389    N = st->block_size;
390    N4 = (N-st->overlap)>>1;
391    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
392
393    CELT_COPY(in, st->in_mem, C*st->overlap);
394    for (c=0;c<C;c++)
395    {
396       const celt_int16_t * restrict pcmp = pcm+c;
397       celt_sig_t * restrict inp = in+C*st->overlap+c;
398       for (i=0;i<N;i++)
399       {
400          /* Apply pre-emphasis */
401          celt_sig_t tmp = SHL32(EXTEND32(*pcmp), SIG_SHIFT);
402          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),1));
403          st->preemph_memE[c] = *pcmp;
404          inp += C;
405          pcmp += C;
406       }
407    }
408    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
409    
410    if (st->mode->nbShortMdcts > 1)
411    {
412       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
413       {
414 #ifndef FIXED_POINT
415          float gain_1;
416 #endif
417          ec_enc_bits(&st->enc, 1, 1);
418          ec_enc_bits(&st->enc, transient_shift, 2);
419          if (transient_shift)
420             ec_enc_uint(&st->enc, transient_time, N+st->overlap);
421          if (transient_shift)
422          {
423 #ifdef FIXED_POINT
424             for (c=0;c<C;c++)
425                for (i=0;i<16;i++)
426                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
427             for (c=0;c<C;c++)
428                for (i=transient_time;i<N+st->overlap;i++)
429                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
430 #else
431             for (c=0;c<C;c++)
432                for (i=0;i<16;i++)
433                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
434             gain_1 = 1./(1<<transient_shift);
435             for (c=0;c<C;c++)
436                for (i=transient_time;i<N+st->overlap;i++)
437                   in[C*i+c] *= gain_1;
438 #endif
439          }
440          shortBlocks = 1;
441       } else {
442          ec_enc_bits(&st->enc, 0, 1);
443          transient_time = -1;
444          transient_shift = 0;
445          shortBlocks = 0;
446       }
447    } else {
448       transient_time = -1;
449       transient_shift = 0;
450       shortBlocks = 0;
451    }
452    /* Pitch analysis: we do it early to save on the peak stack space */
453    if (!shortBlocks)
454       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
455
456    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
457    
458    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
459    /* Compute MDCTs */
460    compute_mdcts(st->mode, shortBlocks, in, freq);
461
462 #ifdef EXP_PSY
463    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
464    for (i=0;i<N;i++)
465       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
466    for (c=1;c<C;c++)
467       for (i=0;i<N;i++)
468          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
469
470    ALLOC(mask, N, celt_sig_t);
471    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
472
473    /* Invert and stretch the mask to length of X 
474       For some reason, I get better results by using the sqrt instead,
475       although there's no valid reason to. Must investigate further */
476    for (i=0;i<C*N;i++)
477       mask[i] = 1/(.1+mask[i]);
478 #endif
479    
480    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
481    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
482    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
483    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
484    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
485
486    /*printf ("%f %f\n", curr_power, pitch_power);*/
487    /*int j;
488    for (j=0;j<B*N;j++)
489       printf ("%f ", X[j]);
490    for (j=0;j<B*N;j++)
491       printf ("%f ", P[j]);
492    printf ("\n");*/
493
494    /* Band normalisation */
495    compute_band_energies(st->mode, freq, bandE);
496    normalise_bands(st->mode, freq, X, bandE);
497    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
498    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
499
500    /* Compute MDCTs of the pitch part */
501    if (!shortBlocks)
502       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
503
504    {
505       /* Normalise the pitch vector as well (discard the energies) */
506       VARDECL(celt_ener_t, bandEp);
507       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
508       compute_band_energies(st->mode, freq, bandEp);
509       normalise_bands(st->mode, freq, P, bandEp);
510       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
511    }
512    curr_power = bandE[0]+bandE[1]+bandE[2];
513    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
514    if (!shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
515    {
516       /* Simulates intensity stereo */
517       /*for (i=30;i<N*B;i++)
518          X[i*C+1] = P[i*C+1] = 0;*/
519
520       /* Pitch prediction */
521       compute_pitch_gain(st->mode, X, P, gains);
522       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
523       if (has_pitch)
524          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
525    } else {
526       /* No pitch, so we just pretend we found a gain of zero */
527       for (i=0;i<st->mode->nbPBands;i++)
528          gains[i] = 0;
529       ec_enc_bits(&st->enc, 0, 7);
530       for (i=0;i<C*N;i++)
531          P[i] = 0;
532    }
533
534    ALLOC(fine_quant, st->mode->nbEBands, int);
535    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
536    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &st->enc);
537    
538    ALLOC(pulses, st->mode->nbEBands, int);
539    ALLOC(offsets, st->mode->nbEBands, int);
540    ALLOC(stereo_mode, st->mode->nbEBands, int);
541    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
542
543    for (i=0;i<st->mode->nbEBands;i++)
544       offsets[i] = 0;
545    bits = nbCompressedBytes*8 - ec_enc_tell(&st->enc, 0) - 1;
546    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
547    /*for (i=0;i<st->mode->nbEBands;i++)
548       printf("%d ", fine_quant[i]);
549    for (i=0;i<st->mode->nbEBands;i++)
550       printf("%d ", pulses[i]);
551    printf ("\n");*/
552    /*bits = ec_enc_tell(&st->enc, 0);
553    compute_fine_allocation(st->mode, fine_quant, (20*C+nbCompressedBytes*8/5-(ec_enc_tell(&st->enc, 0)-bits))/C);*/
554    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &st->enc);
555
556    pitch_quant_bands(st->mode, P, gains);
557
558    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
559
560    /* Residual quantisation */
561    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, &st->enc);
562    
563    if (C==2)
564    {
565       renormalise_bands(st->mode, X);
566    }
567    /* Synthesis */
568    denormalise_bands(st->mode, X, freq, bandE);
569
570
571    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
572
573    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
574    /* De-emphasis and put everything back at the right place in the synthesis history */
575 #ifndef SHORTCUTS
576    for (c=0;c<C;c++)
577    {
578       int j;
579       celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
580       celt_int16_t * restrict pcmp = pcm+c;
581       for (j=0;j<N;j++)
582       {
583          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
584          st->preemph_memD[c] = tmp;
585          *pcmp = SIG2INT16(tmp);
586          pcmp += C;
587          outp += C;
588       }
589    }
590 #endif
591    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
592       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
593    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
594    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
595    {
596       int val = 0;
597       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
598       {
599          ec_enc_uint(&st->enc, val, 2);
600          val = 1-val;
601       }
602    }
603    ec_enc_done(&st->enc);
604    {
605       unsigned char *data;
606       int nbBytes = ec_byte_bytes(&st->buf);
607       if (nbBytes > nbCompressedBytes)
608       {
609          celt_warning_int ("got too many bytes:", nbBytes);
610          RESTORE_STACK;
611          return CELT_INTERNAL_ERROR;
612       }
613       /*printf ("%d\n", *nbBytes);*/
614       data = ec_byte_get_buffer(&st->buf);
615       for (i=0;i<nbBytes;i++)
616          compressed[i] = data[i];
617       for (;i<nbCompressedBytes;i++)
618          compressed[i] = 0;
619    }
620    /* Reset the packing for the next encoding */
621    ec_byte_reset(&st->buf);
622    ec_enc_init(&st->enc,&st->buf);
623
624    RESTORE_STACK;
625    return nbCompressedBytes;
626 }
627
628
629 /****************************************************************************/
630 /*                                                                          */
631 /*                                DECODER                                   */
632 /*                                                                          */
633 /****************************************************************************/
634
635
636 /** Decoder state 
637  @brief Decoder state
638  */
639 struct CELTDecoder {
640    const CELTMode *mode;
641    int frame_size;
642    int block_size;
643    int overlap;
644
645    ec_byte_buffer buf;
646    ec_enc         enc;
647
648    celt_sig_t * restrict preemph_memD;
649
650    celt_sig_t *out_mem;
651
652    celt_word16_t *oldBandE;
653    
654    int last_pitch_index;
655 };
656
657 CELTDecoder *celt_decoder_create(const CELTMode *mode)
658 {
659    int N, C;
660    CELTDecoder *st;
661
662    if (check_mode(mode) != CELT_OK)
663       return NULL;
664
665    N = mode->mdctSize;
666    C = CHANNELS(mode);
667    st = celt_alloc(sizeof(CELTDecoder));
668    
669    st->mode = mode;
670    st->frame_size = N;
671    st->block_size = N;
672    st->overlap = mode->overlap;
673
674    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
675    
676    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
677
678    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
679
680    st->last_pitch_index = 0;
681    return st;
682 }
683
684 void celt_decoder_destroy(CELTDecoder *st)
685 {
686    if (st == NULL)
687    {
688       celt_warning("NULL passed to celt_encoder_destroy");
689       return;
690    }
691    if (check_mode(st->mode) != CELT_OK)
692       return;
693
694
695    celt_free(st->out_mem);
696    
697    celt_free(st->oldBandE);
698    
699    celt_free(st->preemph_memD);
700
701    celt_free(st);
702 }
703
704 /** Handles lost packets by just copying past data with the same offset as the last
705     pitch period */
706 static void celt_decode_lost(CELTDecoder * restrict st, short * restrict pcm)
707 {
708    int c, N;
709    int pitch_index;
710    int i, len;
711    VARDECL(celt_sig_t, freq);
712    const int C = CHANNELS(st->mode);
713    int offset;
714    SAVE_STACK;
715    N = st->block_size;
716    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
717    
718    len = N+st->mode->overlap;
719 #if 0
720    pitch_index = st->last_pitch_index;
721    
722    /* Use the pitch MDCT as the "guessed" signal */
723    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
724
725 #else
726    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
727    pitch_index = MAX_PERIOD-len-pitch_index;
728    offset = MAX_PERIOD-pitch_index;
729    while (offset+len >= MAX_PERIOD)
730       offset -= pitch_index;
731    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
732    for (i=0;i<N;i++)
733       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
734 #endif
735    
736    
737    
738    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
739    /* Compute inverse MDCTs */
740    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
741
742    for (c=0;c<C;c++)
743    {
744       int j;
745       for (j=0;j<N;j++)
746       {
747          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
748                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
749          st->preemph_memD[c] = tmp;
750          pcm[C*j+c] = SIG2INT16(tmp);
751       }
752    }
753    RESTORE_STACK;
754 }
755
756 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
757 {
758    int i, c, N, N4;
759    int has_pitch;
760    int pitch_index;
761    int bits;
762    ec_dec dec;
763    ec_byte_buffer buf;
764    VARDECL(celt_sig_t, freq);
765    VARDECL(celt_norm_t, X);
766    VARDECL(celt_norm_t, P);
767    VARDECL(celt_ener_t, bandE);
768    VARDECL(celt_pgain_t, gains);
769    VARDECL(int, stereo_mode);
770    VARDECL(int, fine_quant);
771    VARDECL(int, pulses);
772    VARDECL(int, offsets);
773
774    int shortBlocks;
775    int transient_time;
776    int transient_shift;
777    const int C = CHANNELS(st->mode);
778    SAVE_STACK;
779
780    if (check_mode(st->mode) != CELT_OK)
781       return CELT_INVALID_MODE;
782
783    N = st->block_size;
784    N4 = (N-st->overlap)>>1;
785
786    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
787    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
788    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
789    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
790    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
791    
792    if (check_mode(st->mode) != CELT_OK)
793    {
794       RESTORE_STACK;
795       return CELT_INVALID_MODE;
796    }
797    if (data == NULL)
798    {
799       celt_decode_lost(st, pcm);
800       RESTORE_STACK;
801       return 0;
802    }
803    
804    ec_byte_readinit(&buf,data,len);
805    ec_dec_init(&dec,&buf);
806    
807    if (st->mode->nbShortMdcts > 1)
808    {
809       shortBlocks = ec_dec_bits(&dec, 1);
810       if (shortBlocks)
811       {
812          transient_shift = ec_dec_bits(&dec, 2);
813          if (transient_shift)
814             transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
815          else
816             transient_time = 0;
817       } else {
818          transient_time = -1;
819          transient_shift = 0;
820       }
821    } else {
822       shortBlocks = 0;
823       transient_time = -1;
824       transient_shift = 0;
825    }
826    /* Get the pitch gains */
827    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
828    
829    /* Get the pitch index */
830    if (has_pitch)
831    {
832       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
833       st->last_pitch_index = pitch_index;
834    } else {
835       /* FIXME: We could be more intelligent here and just not compute the MDCT */
836       pitch_index = 0;
837    }
838
839    ALLOC(fine_quant, st->mode->nbEBands, int);
840    /* Get band energies */
841    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
842    
843    ALLOC(pulses, st->mode->nbEBands, int);
844    ALLOC(offsets, st->mode->nbEBands, int);
845    ALLOC(stereo_mode, st->mode->nbEBands, int);
846    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
847
848    for (i=0;i<st->mode->nbEBands;i++)
849       offsets[i] = 0;
850
851    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
852    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
853    /*bits = ec_dec_tell(&dec, 0);
854    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
855    
856    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
857
858    /* Pitch MDCT */
859    compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
860
861    {
862       VARDECL(celt_ener_t, bandEp);
863       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
864       compute_band_energies(st->mode, freq, bandEp);
865       normalise_bands(st->mode, freq, P, bandEp);
866    }
867
868    /* Apply pitch gains */
869    pitch_quant_bands(st->mode, P, gains);
870
871    /* Decode fixed codebook and merge with pitch */
872    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, &dec);
873
874    if (C==2)
875    {
876       renormalise_bands(st->mode, X);
877    }
878    /* Synthesis */
879    denormalise_bands(st->mode, X, freq, bandE);
880
881
882    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
883    /* Compute inverse MDCTs */
884    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
885
886    for (c=0;c<C;c++)
887    {
888       int j;
889       const celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
890       celt_int16_t * restrict pcmp = pcm+c;
891       for (j=0;j<N;j++)
892       {
893          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
894          st->preemph_memD[c] = tmp;
895          *pcmp = SIG2INT16(tmp);
896          pcmp += C;
897          outp += C;
898       }
899    }
900
901    {
902       unsigned int val = 0;
903       while (ec_dec_tell(&dec, 0) < len*8)
904       {
905          if (ec_dec_uint(&dec, 2) != val)
906          {
907             celt_warning("decode error");
908             RESTORE_STACK;
909             return CELT_CORRUPTED_DATA;
910          }
911          val = 1-val;
912       }
913    }
914
915    RESTORE_STACK;
916    return 0;
917    /*printf ("\n");*/
918 }
919