fixed-point: more work on the time window (almost done)
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53
54 static const celt_word16_t preemph = QCONST16(0.8f,15);
55
56 #ifdef FIXED_POINT
57 static const celt_word16_t transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
63    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
64 #endif
65
66 /** Encoder state 
67  @brief Encoder state
68  */
69 struct CELTEncoder {
70    const CELTMode *mode;     /**< Mode used by the encoder */
71    int frame_size;
72    int block_size;
73    int overlap;
74    int channels;
75    
76    ec_byte_buffer buf;
77    ec_enc         enc;
78
79    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
80    celt_sig_t    * restrict preemph_memD;
81
82    celt_sig_t *in_mem;
83    celt_sig_t *out_mem;
84
85    celt_word16_t *oldBandE;
86 #ifdef EXP_PSY
87    celt_word16_t *psy_mem;
88    struct PsyDecay psy;
89 #endif
90 };
91
92 CELTEncoder *celt_encoder_create(const CELTMode *mode)
93 {
94    int N, C;
95    CELTEncoder *st;
96
97    if (check_mode(mode) != CELT_OK)
98       return NULL;
99
100    N = mode->mdctSize;
101    C = mode->nbChannels;
102    st = celt_alloc(sizeof(CELTEncoder));
103    
104    st->mode = mode;
105    st->frame_size = N;
106    st->block_size = N;
107    st->overlap = mode->overlap;
108
109    ec_byte_writeinit(&st->buf);
110    ec_enc_init(&st->enc,&st->buf);
111
112    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
113    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
114
115    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
116
117    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));;
118    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
119
120 #ifdef EXP_PSY
121    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
122    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
123 #endif
124
125    return st;
126 }
127
128 void celt_encoder_destroy(CELTEncoder *st)
129 {
130    if (st == NULL)
131    {
132       celt_warning("NULL passed to celt_encoder_destroy");
133       return;
134    }
135    if (check_mode(st->mode) != CELT_OK)
136       return;
137
138    ec_byte_writeclear(&st->buf);
139
140    celt_free(st->in_mem);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148 #ifdef EXP_PSY
149    celt_free (st->psy_mem);
150    psydecay_clear(&st->psy);
151 #endif
152    
153    celt_free(st);
154 }
155
156 static inline celt_int16_t SIG2INT16(celt_sig_t x)
157 {
158    x = PSHR32(x, SIG_SHIFT);
159    x = MAX32(x, -32768);
160    x = MIN32(x, 32767);
161 #ifdef FIXED_POINT
162    return EXTRACT16(x);
163 #else
164    return (celt_int16_t)floor(.5+x);
165 #endif
166 }
167
168 static int transient_analysis(celt_word32_t *in, int len, int C, celt_word32_t *r)
169 {
170    int c, i, n;
171    celt_word32_t ratio;
172    /* FIXME: Remove the floats here */
173    float maxN, maxD;
174    VARDECL(celt_word32_t, begin);
175    SAVE_STACK;
176    ALLOC(begin, len, celt_word32_t);
177    
178    for (i=0;i<len;i++)
179       begin[i] = EXTEND32(ABS16(SHR32(in[C*i],SIG_SHIFT)));
180    for (c=1;c<C;c++)
181    {
182       for (i=0;i<len;i++)
183          begin[i] = ADD32(begin[i], EXTEND32(ABS16(SHR32(in[C*i+c],SIG_SHIFT))));
184    }
185    for (i=1;i<len;i++)
186       begin[i] = begin[i-1]+begin[i];
187
188    maxD = VERY_LARGE32;
189    maxN = 0;
190    n = -1;
191    for (i=8;i<len-8;i++)
192    {
193       celt_word32_t endi;
194       celt_word32_t num, den;
195       endi = begin[len-1]-begin[i];
196       num = endi*i;
197       den = (30+begin[i])*(len-i)+MULT16_32_Q15(QCONST16(.1f,15),endi)*len;
198       if ((num*maxD > den*maxN) && (endi > MULT16_32_Q15(QCONST16(.05f,15),begin[i])))
199       {
200          maxN = num;
201          maxD = den;
202          n = i;
203       }
204    }
205    ratio = DIV32((begin[len-1]-begin[n])*n,(10+begin[n])*(len-n));
206    if (n<32)
207    {
208       n = -1;
209       ratio = 0;
210    }
211    *r = ratio*ratio;
212    RESTORE_STACK;
213    return n;
214 }
215
216 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
217 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
218 {
219    const int C = CHANNELS(mode);
220    if (C==1 && !shortBlocks)
221    {
222       const mdct_lookup *lookup = MDCT(mode);
223       const int overlap = OVERLAP(mode);
224       mdct_forward(lookup, in, out, mode->window, overlap);
225    } else if (!shortBlocks) {
226       const mdct_lookup *lookup = MDCT(mode);
227       const int overlap = OVERLAP(mode);
228       const int N = FRAMESIZE(mode);
229       int c;
230       VARDECL(celt_word32_t, x);
231       VARDECL(celt_word32_t, tmp);
232       SAVE_STACK;
233       ALLOC(x, N+overlap, celt_word32_t);
234       ALLOC(tmp, N, celt_word32_t);
235       for (c=0;c<C;c++)
236       {
237          int j;
238          for (j=0;j<N+overlap;j++)
239             x[j] = in[C*j+c];
240          mdct_forward(lookup, x, tmp, mode->window, overlap);
241          /* Interleaving the sub-frames */
242          for (j=0;j<N;j++)
243             out[C*j+c] = tmp[j];
244       }
245       RESTORE_STACK;
246    } else {
247       const mdct_lookup *lookup = &mode->shortMdct;
248       const int overlap = mode->shortMdctSize;
249       const int N = mode->shortMdctSize;
250       int b, c;
251       VARDECL(celt_word32_t, x);
252       VARDECL(celt_word32_t, tmp);
253       SAVE_STACK;
254       ALLOC(x, N+overlap, celt_word32_t);
255       ALLOC(tmp, N, celt_word32_t);
256       for (c=0;c<C;c++)
257       {
258          int B = mode->nbShortMdcts;
259          for (b=0;b<B;b++)
260          {
261             int j;
262             for (j=0;j<N+overlap;j++)
263                x[j] = in[C*(b*N+j)+c];
264             mdct_forward(lookup, x, tmp, mode->window, overlap);
265             /* Interleaving the sub-frames */
266             for (j=0;j<N;j++)
267                out[C*(j*B+b)+c] = tmp[j];
268          }
269       }
270       RESTORE_STACK;
271    }
272 }
273
274 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
275 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
276 {
277    int c, N4;
278    const int C = CHANNELS(mode);
279    const int N = FRAMESIZE(mode);
280    const int overlap = OVERLAP(mode);
281    N4 = (N-overlap)>>1;
282    for (c=0;c<C;c++)
283    {
284       int j;
285       if (transient_shift==0 && C==1 && !shortBlocks) {
286          const mdct_lookup *lookup = MDCT(mode);
287          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
288       } else if (!shortBlocks) {
289          const mdct_lookup *lookup = MDCT(mode);
290          VARDECL(celt_word32_t, x);
291          VARDECL(celt_word32_t, tmp);
292          SAVE_STACK;
293          ALLOC(x, 2*N, celt_word32_t);
294          ALLOC(tmp, N, celt_word32_t);
295          /* De-interleaving the sub-frames */
296          for (j=0;j<N;j++)
297             tmp[j] = X[C*j+c];
298          /* Prevents problems from the imdct doing the overlap-add */
299          CELT_MEMSET(x+N4, 0, overlap);
300          mdct_backward(lookup, tmp, x, mode->window, overlap);
301          celt_assert(transient_shift == 0);
302          /* The first and last part would need to be set to zero if we actually
303             wanted to use them. */
304          for (j=0;j<overlap;j++)
305             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
306          for (j=0;j<overlap;j++)
307             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
308          for (j=0;j<2*N4;j++)
309             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
310          RESTORE_STACK;
311       } else {
312          int b;
313          const int N2 = mode->shortMdctSize;
314          const int B = mode->nbShortMdcts;
315          const mdct_lookup *lookup = &mode->shortMdct;
316          VARDECL(celt_word32_t, x);
317          VARDECL(celt_word32_t, tmp);
318          SAVE_STACK;
319          ALLOC(x, 2*N, celt_word32_t);
320          ALLOC(tmp, N, celt_word32_t);
321          /* Prevents problems from the imdct doing the overlap-add */
322          CELT_MEMSET(x+N4, 0, overlap);
323          for (b=0;b<B;b++)
324          {
325             /* De-interleaving the sub-frames */
326             for (j=0;j<N2;j++)
327                tmp[j] = X[C*(j*B+b)+c];
328             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
329          }
330          if (transient_shift > 0)
331          {
332 #ifdef FIXED_POINT
333             for (j=0;j<16;j++)
334                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
335             for (j=transient_time;j<N+overlap;j++)
336                x[N4+j] = SHL32(x[N4+j], transient_shift);
337 #else
338             for (j=0;j<16;j++)
339                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
340             for (j=transient_time;j<N+overlap;j++)
341                x[N4+j] *= 1<<transient_shift;
342 #endif
343          }
344          /* The first and last part would need to be set to zero if we actually
345          wanted to use them. */
346          for (j=0;j<overlap;j++)
347             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
348          for (j=0;j<overlap;j++)
349             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
350          for (j=0;j<2*N4;j++)
351             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
352          RESTORE_STACK;
353       }
354    }
355 }
356
357 int celt_encode(CELTEncoder * restrict st, celt_int16_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
358 {
359    int i, c, N, N4;
360    int has_pitch;
361    int pitch_index;
362    celt_word32_t curr_power, pitch_power;
363    VARDECL(celt_sig_t, in);
364    VARDECL(celt_sig_t, freq);
365    VARDECL(celt_norm_t, X);
366    VARDECL(celt_norm_t, P);
367    VARDECL(celt_ener_t, bandE);
368    VARDECL(celt_pgain_t, gains);
369    VARDECL(int, stereo_mode);
370 #ifdef EXP_PSY
371    VARDECL(celt_word32_t, mask);
372 #endif
373    int shortBlocks=0;
374    int transient_time;
375    int transient_shift;
376    celt_word32_t maxR;
377    const int C = CHANNELS(st->mode);
378    SAVE_STACK;
379
380    if (check_mode(st->mode) != CELT_OK)
381       return CELT_INVALID_MODE;
382
383    N = st->block_size;
384    N4 = (N-st->overlap)>>1;
385    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
386
387    CELT_COPY(in, st->in_mem, C*st->overlap);
388    for (c=0;c<C;c++)
389    {
390       const celt_int16_t * restrict pcmp = pcm+c;
391       celt_sig_t * restrict inp = in+C*st->overlap+c;
392       for (i=0;i<N;i++)
393       {
394          /* Apply pre-emphasis */
395          celt_sig_t tmp = SHL32(EXTEND32(*pcmp), SIG_SHIFT);
396          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),1));
397          st->preemph_memE[c] = *pcmp;
398          inp += C;
399          pcmp += C;
400       }
401    }
402    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
403    
404    transient_time = transient_analysis(in, N+st->overlap, C, &maxR);
405    if (maxR > 30)
406    {
407 #ifndef FIXED_POINT
408       float gain_1;
409 #endif
410       ec_enc_bits(&st->enc, 1, 1);
411       if (maxR < 30)
412       {
413          transient_shift = 0;
414       } else if (maxR < 100)
415       {
416          transient_shift = 1;
417       } else if (maxR < 500)
418       {
419          transient_shift = 2;
420       } else
421       {
422          transient_shift = 3;
423       }
424       ec_enc_bits(&st->enc, transient_shift, 2);
425       if (transient_shift)
426          ec_enc_uint(&st->enc, transient_time, N+st->overlap);
427       if (transient_shift)
428       {
429 #ifdef FIXED_POINT
430          for (c=0;c<C;c++)
431             for (i=0;i<16;i++)
432                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
433          for (c=0;c<C;c++)
434             for (i=transient_time;i<N+st->overlap;i++)
435                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
436 #else
437          for (c=0;c<C;c++)
438             for (i=0;i<16;i++)
439                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
440          gain_1 = 1./(1<<transient_shift);
441          for (c=0;c<C;c++)
442             for (i=transient_time;i<N+st->overlap;i++)
443                in[C*i+c] *= gain_1;
444 #endif
445       }
446       shortBlocks = 1;
447    } else {
448       ec_enc_bits(&st->enc, 0, 1);
449       transient_time = -1;
450       transient_shift = 0;
451       shortBlocks = 0;
452    }
453    /* Pitch analysis: we do it early to save on the peak stack space */
454    if (!shortBlocks)
455       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
456
457    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
458    
459    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
460    /* Compute MDCTs */
461    compute_mdcts(st->mode, shortBlocks, in, freq);
462
463 #ifdef EXP_PSY
464    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
465    for (i=0;i<N;i++)
466       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
467    for (c=1;c<C;c++)
468       for (i=0;i<N;i++)
469          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
470
471    ALLOC(mask, N, celt_sig_t);
472    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
473
474    /* Invert and stretch the mask to length of X 
475       For some reason, I get better results by using the sqrt instead,
476       although there's no valid reason to. Must investigate further */
477    for (i=0;i<C*N;i++)
478       mask[i] = 1/(.1+mask[i]);
479 #endif
480    
481    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
482    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
483    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
484    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
485    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
486
487    /*printf ("%f %f\n", curr_power, pitch_power);*/
488    /*int j;
489    for (j=0;j<B*N;j++)
490       printf ("%f ", X[j]);
491    for (j=0;j<B*N;j++)
492       printf ("%f ", P[j]);
493    printf ("\n");*/
494
495    /* Band normalisation */
496    compute_band_energies(st->mode, freq, bandE);
497    normalise_bands(st->mode, freq, X, bandE);
498    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
499    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
500
501    /* Compute MDCTs of the pitch part */
502    if (!shortBlocks)
503       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
504
505    {
506       /* Normalise the pitch vector as well (discard the energies) */
507       VARDECL(celt_ener_t, bandEp);
508       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
509       compute_band_energies(st->mode, freq, bandEp);
510       normalise_bands(st->mode, freq, P, bandEp);
511       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
512    }
513    curr_power = bandE[0]+bandE[1]+bandE[2];
514    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
515    if (!shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
516    {
517       /* Simulates intensity stereo */
518       /*for (i=30;i<N*B;i++)
519          X[i*C+1] = P[i*C+1] = 0;*/
520
521       /* Pitch prediction */
522       compute_pitch_gain(st->mode, X, P, gains);
523       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
524       if (has_pitch)
525          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
526    } else {
527       /* No pitch, so we just pretend we found a gain of zero */
528       for (i=0;i<st->mode->nbPBands;i++)
529          gains[i] = 0;
530       ec_enc_bits(&st->enc, 0, 7);
531       for (i=0;i<C*N;i++)
532          P[i] = 0;
533    }
534    quant_energy(st->mode, bandE, st->oldBandE, 20*C+nbCompressedBytes*8/5, st->mode->prob, &st->enc);
535
536    ALLOC(stereo_mode, st->mode->nbEBands, int);
537    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
538
539    pitch_quant_bands(st->mode, P, gains);
540
541    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
542
543    /* Residual quantisation */
544    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, nbCompressedBytes*8, shortBlocks, &st->enc);
545    
546    if (C==2)
547    {
548       renormalise_bands(st->mode, X);
549    }
550    /* Synthesis */
551    denormalise_bands(st->mode, X, freq, bandE);
552
553
554    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
555
556    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
557    /* De-emphasis and put everything back at the right place in the synthesis history */
558 #ifndef SHORTCUTS
559    for (c=0;c<C;c++)
560    {
561       int j;
562       celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
563       celt_int16_t * restrict pcmp = pcm+c;
564       for (j=0;j<N;j++)
565       {
566          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
567          st->preemph_memD[c] = tmp;
568          *pcmp = SIG2INT16(tmp);
569          pcmp += C;
570          outp += C;
571       }
572    }
573 #endif
574    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
575       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
576    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
577    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
578    {
579       int val = 0;
580       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
581       {
582          ec_enc_uint(&st->enc, val, 2);
583          val = 1-val;
584       }
585    }
586    ec_enc_done(&st->enc);
587    {
588       unsigned char *data;
589       int nbBytes = ec_byte_bytes(&st->buf);
590       if (nbBytes > nbCompressedBytes)
591       {
592          celt_warning_int ("got too many bytes:", nbBytes);
593          RESTORE_STACK;
594          return CELT_INTERNAL_ERROR;
595       }
596       /*printf ("%d\n", *nbBytes);*/
597       data = ec_byte_get_buffer(&st->buf);
598       for (i=0;i<nbBytes;i++)
599          compressed[i] = data[i];
600       for (;i<nbCompressedBytes;i++)
601          compressed[i] = 0;
602    }
603    /* Reset the packing for the next encoding */
604    ec_byte_reset(&st->buf);
605    ec_enc_init(&st->enc,&st->buf);
606
607    RESTORE_STACK;
608    return nbCompressedBytes;
609 }
610
611
612 /****************************************************************************/
613 /*                                                                          */
614 /*                                DECODER                                   */
615 /*                                                                          */
616 /****************************************************************************/
617
618
619 /** Decoder state 
620  @brief Decoder state
621  */
622 struct CELTDecoder {
623    const CELTMode *mode;
624    int frame_size;
625    int block_size;
626    int overlap;
627
628    ec_byte_buffer buf;
629    ec_enc         enc;
630
631    celt_sig_t * restrict preemph_memD;
632
633    celt_sig_t *out_mem;
634
635    celt_word16_t *oldBandE;
636    
637    int last_pitch_index;
638 };
639
640 CELTDecoder *celt_decoder_create(const CELTMode *mode)
641 {
642    int N, C;
643    CELTDecoder *st;
644
645    if (check_mode(mode) != CELT_OK)
646       return NULL;
647
648    N = mode->mdctSize;
649    C = CHANNELS(mode);
650    st = celt_alloc(sizeof(CELTDecoder));
651    
652    st->mode = mode;
653    st->frame_size = N;
654    st->block_size = N;
655    st->overlap = mode->overlap;
656
657    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
658    
659    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
660
661    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
662
663    st->last_pitch_index = 0;
664    return st;
665 }
666
667 void celt_decoder_destroy(CELTDecoder *st)
668 {
669    if (st == NULL)
670    {
671       celt_warning("NULL passed to celt_encoder_destroy");
672       return;
673    }
674    if (check_mode(st->mode) != CELT_OK)
675       return;
676
677
678    celt_free(st->out_mem);
679    
680    celt_free(st->oldBandE);
681    
682    celt_free(st->preemph_memD);
683
684    celt_free(st);
685 }
686
687 /** Handles lost packets by just copying past data with the same offset as the last
688     pitch period */
689 static void celt_decode_lost(CELTDecoder * restrict st, short * restrict pcm)
690 {
691    int c, N;
692    int pitch_index;
693    int i, len;
694    VARDECL(celt_sig_t, freq);
695    const int C = CHANNELS(st->mode);
696    int offset;
697    SAVE_STACK;
698    N = st->block_size;
699    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
700    
701    len = N+st->mode->overlap;
702 #if 0
703    pitch_index = st->last_pitch_index;
704    
705    /* Use the pitch MDCT as the "guessed" signal */
706    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
707
708 #else
709    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
710    pitch_index = MAX_PERIOD-len-pitch_index;
711    offset = MAX_PERIOD-pitch_index;
712    while (offset+len >= MAX_PERIOD)
713       offset -= pitch_index;
714    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
715    for (i=0;i<N;i++)
716       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
717 #endif
718    
719    
720    
721    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
722    /* Compute inverse MDCTs */
723    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
724
725    for (c=0;c<C;c++)
726    {
727       int j;
728       for (j=0;j<N;j++)
729       {
730          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
731                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
732          st->preemph_memD[c] = tmp;
733          pcm[C*j+c] = SIG2INT16(tmp);
734       }
735    }
736    RESTORE_STACK;
737 }
738
739 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
740 {
741    int c, N, N4;
742    int has_pitch;
743    int pitch_index;
744    ec_dec dec;
745    ec_byte_buffer buf;
746    VARDECL(celt_sig_t, freq);
747    VARDECL(celt_norm_t, X);
748    VARDECL(celt_norm_t, P);
749    VARDECL(celt_ener_t, bandE);
750    VARDECL(celt_pgain_t, gains);
751    VARDECL(int, stereo_mode);
752    int shortBlocks;
753    int transient_time;
754    int transient_shift;
755    const int C = CHANNELS(st->mode);
756    SAVE_STACK;
757
758    if (check_mode(st->mode) != CELT_OK)
759       return CELT_INVALID_MODE;
760
761    N = st->block_size;
762    N4 = (N-st->overlap)>>1;
763
764    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
765    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
766    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
767    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
768    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
769    
770    if (check_mode(st->mode) != CELT_OK)
771    {
772       RESTORE_STACK;
773       return CELT_INVALID_MODE;
774    }
775    if (data == NULL)
776    {
777       celt_decode_lost(st, pcm);
778       RESTORE_STACK;
779       return 0;
780    }
781    
782    ec_byte_readinit(&buf,data,len);
783    ec_dec_init(&dec,&buf);
784    
785    shortBlocks = ec_dec_bits(&dec, 1);
786    if (shortBlocks)
787    {
788       transient_shift = ec_dec_bits(&dec, 2);
789       if (transient_shift)
790          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
791       else
792          transient_time = 0;
793    } else {
794       transient_time = -1;
795       transient_shift = 0;
796    }
797    /* Get the pitch gains */
798    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
799    
800    /* Get the pitch index */
801    if (has_pitch)
802    {
803       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
804       st->last_pitch_index = pitch_index;
805    } else {
806       /* FIXME: We could be more intelligent here and just not compute the MDCT */
807       pitch_index = 0;
808    }
809
810    /* Get band energies */
811    unquant_energy(st->mode, bandE, st->oldBandE, 20*C+len*8/5, st->mode->prob, &dec);
812
813    /* Pitch MDCT */
814    compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
815
816    {
817       VARDECL(celt_ener_t, bandEp);
818       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
819       compute_band_energies(st->mode, freq, bandEp);
820       normalise_bands(st->mode, freq, P, bandEp);
821    }
822
823    ALLOC(stereo_mode, st->mode->nbEBands, int);
824    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
825    /* Apply pitch gains */
826    pitch_quant_bands(st->mode, P, gains);
827
828    /* Decode fixed codebook and merge with pitch */
829    unquant_bands(st->mode, X, P, bandE, stereo_mode, len*8, shortBlocks, &dec);
830
831    if (C==2)
832    {
833       renormalise_bands(st->mode, X);
834    }
835    /* Synthesis */
836    denormalise_bands(st->mode, X, freq, bandE);
837
838
839    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
840    /* Compute inverse MDCTs */
841    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
842
843    for (c=0;c<C;c++)
844    {
845       int j;
846       const celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
847       celt_int16_t * restrict pcmp = pcm+c;
848       for (j=0;j<N;j++)
849       {
850          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
851          st->preemph_memD[c] = tmp;
852          *pcmp = SIG2INT16(tmp);
853          pcmp += C;
854          outp += C;
855       }
856    }
857
858    {
859       unsigned int val = 0;
860       while (ec_dec_tell(&dec, 0) < len*8)
861       {
862          if (ec_dec_uint(&dec, 2) != val)
863          {
864             celt_warning("decode error");
865             RESTORE_STACK;
866             return CELT_CORRUPTED_DATA;
867          }
868          val = 1-val;
869       }
870    }
871
872    RESTORE_STACK;
873    return 0;
874    /*printf ("\n");*/
875 }
876