fixed-point: conversion of pre-echo avoidance now complete.
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53
54 static const celt_word16_t preemph = QCONST16(0.8f,15);
55
56 #ifdef FIXED_POINT
57 static const celt_word16_t transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
63    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
64 #endif
65
66 /** Encoder state 
67  @brief Encoder state
68  */
69 struct CELTEncoder {
70    const CELTMode *mode;     /**< Mode used by the encoder */
71    int frame_size;
72    int block_size;
73    int overlap;
74    int channels;
75    
76    ec_byte_buffer buf;
77    ec_enc         enc;
78
79    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
80    celt_sig_t    * restrict preemph_memD;
81
82    celt_sig_t *in_mem;
83    celt_sig_t *out_mem;
84
85    celt_word16_t *oldBandE;
86 #ifdef EXP_PSY
87    celt_word16_t *psy_mem;
88    struct PsyDecay psy;
89 #endif
90 };
91
92 CELTEncoder *celt_encoder_create(const CELTMode *mode)
93 {
94    int N, C;
95    CELTEncoder *st;
96
97    if (check_mode(mode) != CELT_OK)
98       return NULL;
99
100    N = mode->mdctSize;
101    C = mode->nbChannels;
102    st = celt_alloc(sizeof(CELTEncoder));
103    
104    st->mode = mode;
105    st->frame_size = N;
106    st->block_size = N;
107    st->overlap = mode->overlap;
108
109    ec_byte_writeinit(&st->buf);
110    ec_enc_init(&st->enc,&st->buf);
111
112    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
113    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
114
115    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
116
117    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));;
118    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
119
120 #ifdef EXP_PSY
121    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
122    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
123 #endif
124
125    return st;
126 }
127
128 void celt_encoder_destroy(CELTEncoder *st)
129 {
130    if (st == NULL)
131    {
132       celt_warning("NULL passed to celt_encoder_destroy");
133       return;
134    }
135    if (check_mode(st->mode) != CELT_OK)
136       return;
137
138    ec_byte_writeclear(&st->buf);
139
140    celt_free(st->in_mem);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148 #ifdef EXP_PSY
149    celt_free (st->psy_mem);
150    psydecay_clear(&st->psy);
151 #endif
152    
153    celt_free(st);
154 }
155
156 static inline celt_int16_t SIG2INT16(celt_sig_t x)
157 {
158    x = PSHR32(x, SIG_SHIFT);
159    x = MAX32(x, -32768);
160    x = MIN32(x, 32767);
161 #ifdef FIXED_POINT
162    return EXTRACT16(x);
163 #else
164    return (celt_int16_t)floor(.5+x);
165 #endif
166 }
167 #ifdef FIXED_POINT
168 static int ratio_compare(celt_word32_t num1, celt_word32_t den1, celt_word32_t num2, celt_word32_t den2)
169 {
170    int shift = celt_zlog2(MAX32(num1, num2));
171    if (shift > 14)
172    {
173       num1 = SHR32(num1, shift-14);
174       num2 = SHR32(num2, shift-14);
175    }
176    shift = celt_zlog2(MAX32(den1, den2));
177    if (shift > 14)
178    {
179       den1 = SHR32(den1, shift-14);
180       den2 = SHR32(den2, shift-14);
181    }
182    return MULT16_16(EXTRACT16(num1),EXTRACT16(den2)) > MULT16_16(EXTRACT16(den1),EXTRACT16(num2));
183 }
184 #else
185 static int ratio_compare(celt_word32_t num1, celt_word32_t den1, celt_word32_t num2, celt_word32_t den2)
186 {
187    return num1*den2 > den1*num2;
188 }
189 #endif
190
191 static int transient_analysis(celt_word32_t *in, int len, int C, celt_word32_t *r)
192 {
193    int c, i, n;
194    celt_word32_t ratio;
195    /* FIXME: Remove the floats here */
196    celt_word32_t maxN, maxD;
197    VARDECL(celt_word32_t, begin);
198    SAVE_STACK;
199    ALLOC(begin, len, celt_word32_t);
200    
201    for (i=0;i<len;i++)
202       begin[i] = EXTEND32(ABS16(SHR32(in[C*i],SIG_SHIFT)));
203    for (c=1;c<C;c++)
204    {
205       for (i=0;i<len;i++)
206          begin[i] = ADD32(begin[i], EXTEND32(ABS16(SHR32(in[C*i+c],SIG_SHIFT))));
207    }
208    for (i=1;i<len;i++)
209       begin[i] = begin[i-1]+begin[i];
210
211    maxD = VERY_LARGE32;
212    maxN = 0;
213    n = -1;
214    for (i=8;i<len-8;i++)
215    {
216       celt_word32_t endi;
217       celt_word32_t num, den;
218       endi = begin[len-1]-begin[i];
219       num = endi*i;
220       den = (30+begin[i])*(len-i)+MULT16_32_Q15(QCONST16(.1f,15),endi)*len;
221       if (ratio_compare(num, den, maxN, maxD) && (endi > MULT16_32_Q15(QCONST16(.05f,15),begin[i])))
222       {
223          maxN = num;
224          maxD = den;
225          n = i;
226       }
227    }
228    ratio = DIV32((begin[len-1]-begin[n])*n,(10+begin[n])*(len-n));
229    if (n<32)
230    {
231       n = -1;
232       ratio = 0;
233    }
234    if (ratio < 0)
235       ratio = 0;
236    if (ratio > 1000)
237       ratio = 1000;
238    *r = ratio*ratio;
239    RESTORE_STACK;
240    return n;
241 }
242
243 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
244 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
245 {
246    const int C = CHANNELS(mode);
247    if (C==1 && !shortBlocks)
248    {
249       const mdct_lookup *lookup = MDCT(mode);
250       const int overlap = OVERLAP(mode);
251       mdct_forward(lookup, in, out, mode->window, overlap);
252    } else if (!shortBlocks) {
253       const mdct_lookup *lookup = MDCT(mode);
254       const int overlap = OVERLAP(mode);
255       const int N = FRAMESIZE(mode);
256       int c;
257       VARDECL(celt_word32_t, x);
258       VARDECL(celt_word32_t, tmp);
259       SAVE_STACK;
260       ALLOC(x, N+overlap, celt_word32_t);
261       ALLOC(tmp, N, celt_word32_t);
262       for (c=0;c<C;c++)
263       {
264          int j;
265          for (j=0;j<N+overlap;j++)
266             x[j] = in[C*j+c];
267          mdct_forward(lookup, x, tmp, mode->window, overlap);
268          /* Interleaving the sub-frames */
269          for (j=0;j<N;j++)
270             out[C*j+c] = tmp[j];
271       }
272       RESTORE_STACK;
273    } else {
274       const mdct_lookup *lookup = &mode->shortMdct;
275       const int overlap = mode->shortMdctSize;
276       const int N = mode->shortMdctSize;
277       int b, c;
278       VARDECL(celt_word32_t, x);
279       VARDECL(celt_word32_t, tmp);
280       SAVE_STACK;
281       ALLOC(x, N+overlap, celt_word32_t);
282       ALLOC(tmp, N, celt_word32_t);
283       for (c=0;c<C;c++)
284       {
285          int B = mode->nbShortMdcts;
286          for (b=0;b<B;b++)
287          {
288             int j;
289             for (j=0;j<N+overlap;j++)
290                x[j] = in[C*(b*N+j)+c];
291             mdct_forward(lookup, x, tmp, mode->window, overlap);
292             /* Interleaving the sub-frames */
293             for (j=0;j<N;j++)
294                out[C*(j*B+b)+c] = tmp[j];
295          }
296       }
297       RESTORE_STACK;
298    }
299 }
300
301 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
302 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
303 {
304    int c, N4;
305    const int C = CHANNELS(mode);
306    const int N = FRAMESIZE(mode);
307    const int overlap = OVERLAP(mode);
308    N4 = (N-overlap)>>1;
309    for (c=0;c<C;c++)
310    {
311       int j;
312       if (transient_shift==0 && C==1 && !shortBlocks) {
313          const mdct_lookup *lookup = MDCT(mode);
314          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
315       } else if (!shortBlocks) {
316          const mdct_lookup *lookup = MDCT(mode);
317          VARDECL(celt_word32_t, x);
318          VARDECL(celt_word32_t, tmp);
319          SAVE_STACK;
320          ALLOC(x, 2*N, celt_word32_t);
321          ALLOC(tmp, N, celt_word32_t);
322          /* De-interleaving the sub-frames */
323          for (j=0;j<N;j++)
324             tmp[j] = X[C*j+c];
325          /* Prevents problems from the imdct doing the overlap-add */
326          CELT_MEMSET(x+N4, 0, overlap);
327          mdct_backward(lookup, tmp, x, mode->window, overlap);
328          celt_assert(transient_shift == 0);
329          /* The first and last part would need to be set to zero if we actually
330             wanted to use them. */
331          for (j=0;j<overlap;j++)
332             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
333          for (j=0;j<overlap;j++)
334             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
335          for (j=0;j<2*N4;j++)
336             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
337          RESTORE_STACK;
338       } else {
339          int b;
340          const int N2 = mode->shortMdctSize;
341          const int B = mode->nbShortMdcts;
342          const mdct_lookup *lookup = &mode->shortMdct;
343          VARDECL(celt_word32_t, x);
344          VARDECL(celt_word32_t, tmp);
345          SAVE_STACK;
346          ALLOC(x, 2*N, celt_word32_t);
347          ALLOC(tmp, N, celt_word32_t);
348          /* Prevents problems from the imdct doing the overlap-add */
349          CELT_MEMSET(x+N4, 0, overlap);
350          for (b=0;b<B;b++)
351          {
352             /* De-interleaving the sub-frames */
353             for (j=0;j<N2;j++)
354                tmp[j] = X[C*(j*B+b)+c];
355             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
356          }
357          if (transient_shift > 0)
358          {
359 #ifdef FIXED_POINT
360             for (j=0;j<16;j++)
361                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
362             for (j=transient_time;j<N+overlap;j++)
363                x[N4+j] = SHL32(x[N4+j], transient_shift);
364 #else
365             for (j=0;j<16;j++)
366                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
367             for (j=transient_time;j<N+overlap;j++)
368                x[N4+j] *= 1<<transient_shift;
369 #endif
370          }
371          /* The first and last part would need to be set to zero if we actually
372          wanted to use them. */
373          for (j=0;j<overlap;j++)
374             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
375          for (j=0;j<overlap;j++)
376             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
377          for (j=0;j<2*N4;j++)
378             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
379          RESTORE_STACK;
380       }
381    }
382 }
383
384 int celt_encode(CELTEncoder * restrict st, celt_int16_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
385 {
386    int i, c, N, N4;
387    int has_pitch;
388    int pitch_index;
389    celt_word32_t curr_power, pitch_power;
390    VARDECL(celt_sig_t, in);
391    VARDECL(celt_sig_t, freq);
392    VARDECL(celt_norm_t, X);
393    VARDECL(celt_norm_t, P);
394    VARDECL(celt_ener_t, bandE);
395    VARDECL(celt_pgain_t, gains);
396    VARDECL(int, stereo_mode);
397 #ifdef EXP_PSY
398    VARDECL(celt_word32_t, mask);
399 #endif
400    int shortBlocks=0;
401    int transient_time;
402    int transient_shift;
403    celt_word32_t maxR;
404    const int C = CHANNELS(st->mode);
405    SAVE_STACK;
406
407    if (check_mode(st->mode) != CELT_OK)
408       return CELT_INVALID_MODE;
409
410    N = st->block_size;
411    N4 = (N-st->overlap)>>1;
412    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
413
414    CELT_COPY(in, st->in_mem, C*st->overlap);
415    for (c=0;c<C;c++)
416    {
417       const celt_int16_t * restrict pcmp = pcm+c;
418       celt_sig_t * restrict inp = in+C*st->overlap+c;
419       for (i=0;i<N;i++)
420       {
421          /* Apply pre-emphasis */
422          celt_sig_t tmp = SHL32(EXTEND32(*pcmp), SIG_SHIFT);
423          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),1));
424          st->preemph_memE[c] = *pcmp;
425          inp += C;
426          pcmp += C;
427       }
428    }
429    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
430    
431    transient_time = transient_analysis(in, N+st->overlap, C, &maxR);
432    if (maxR > 30)
433    {
434 #ifndef FIXED_POINT
435       float gain_1;
436 #endif
437       ec_enc_bits(&st->enc, 1, 1);
438       if (maxR < 30)
439       {
440          transient_shift = 0;
441       } else if (maxR < 100)
442       {
443          transient_shift = 1;
444       } else if (maxR < 500)
445       {
446          transient_shift = 2;
447       } else
448       {
449          transient_shift = 3;
450       }
451       ec_enc_bits(&st->enc, transient_shift, 2);
452       if (transient_shift)
453          ec_enc_uint(&st->enc, transient_time, N+st->overlap);
454       if (transient_shift)
455       {
456 #ifdef FIXED_POINT
457          for (c=0;c<C;c++)
458             for (i=0;i<16;i++)
459                in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
460          for (c=0;c<C;c++)
461             for (i=transient_time;i<N+st->overlap;i++)
462                in[C*i+c] = SHR32(in[C*i+c], transient_shift);
463 #else
464          for (c=0;c<C;c++)
465             for (i=0;i<16;i++)
466                in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
467          gain_1 = 1./(1<<transient_shift);
468          for (c=0;c<C;c++)
469             for (i=transient_time;i<N+st->overlap;i++)
470                in[C*i+c] *= gain_1;
471 #endif
472       }
473       shortBlocks = 1;
474    } else {
475       ec_enc_bits(&st->enc, 0, 1);
476       transient_time = -1;
477       transient_shift = 0;
478       shortBlocks = 0;
479    }
480    /* Pitch analysis: we do it early to save on the peak stack space */
481    if (!shortBlocks)
482       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
483
484    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
485    
486    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
487    /* Compute MDCTs */
488    compute_mdcts(st->mode, shortBlocks, in, freq);
489
490 #ifdef EXP_PSY
491    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
492    for (i=0;i<N;i++)
493       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
494    for (c=1;c<C;c++)
495       for (i=0;i<N;i++)
496          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
497
498    ALLOC(mask, N, celt_sig_t);
499    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
500
501    /* Invert and stretch the mask to length of X 
502       For some reason, I get better results by using the sqrt instead,
503       although there's no valid reason to. Must investigate further */
504    for (i=0;i<C*N;i++)
505       mask[i] = 1/(.1+mask[i]);
506 #endif
507    
508    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
509    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
510    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
511    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
512    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
513
514    /*printf ("%f %f\n", curr_power, pitch_power);*/
515    /*int j;
516    for (j=0;j<B*N;j++)
517       printf ("%f ", X[j]);
518    for (j=0;j<B*N;j++)
519       printf ("%f ", P[j]);
520    printf ("\n");*/
521
522    /* Band normalisation */
523    compute_band_energies(st->mode, freq, bandE);
524    normalise_bands(st->mode, freq, X, bandE);
525    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
526    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
527
528    /* Compute MDCTs of the pitch part */
529    if (!shortBlocks)
530       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
531
532    {
533       /* Normalise the pitch vector as well (discard the energies) */
534       VARDECL(celt_ener_t, bandEp);
535       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
536       compute_band_energies(st->mode, freq, bandEp);
537       normalise_bands(st->mode, freq, P, bandEp);
538       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
539    }
540    curr_power = bandE[0]+bandE[1]+bandE[2];
541    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
542    if (!shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
543    {
544       /* Simulates intensity stereo */
545       /*for (i=30;i<N*B;i++)
546          X[i*C+1] = P[i*C+1] = 0;*/
547
548       /* Pitch prediction */
549       compute_pitch_gain(st->mode, X, P, gains);
550       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
551       if (has_pitch)
552          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
553    } else {
554       /* No pitch, so we just pretend we found a gain of zero */
555       for (i=0;i<st->mode->nbPBands;i++)
556          gains[i] = 0;
557       ec_enc_bits(&st->enc, 0, 7);
558       for (i=0;i<C*N;i++)
559          P[i] = 0;
560    }
561    quant_energy(st->mode, bandE, st->oldBandE, 20*C+nbCompressedBytes*8/5, st->mode->prob, &st->enc);
562
563    ALLOC(stereo_mode, st->mode->nbEBands, int);
564    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
565
566    pitch_quant_bands(st->mode, P, gains);
567
568    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
569
570    /* Residual quantisation */
571    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, nbCompressedBytes*8, shortBlocks, &st->enc);
572    
573    if (C==2)
574    {
575       renormalise_bands(st->mode, X);
576    }
577    /* Synthesis */
578    denormalise_bands(st->mode, X, freq, bandE);
579
580
581    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
582
583    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
584    /* De-emphasis and put everything back at the right place in the synthesis history */
585 #ifndef SHORTCUTS
586    for (c=0;c<C;c++)
587    {
588       int j;
589       celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
590       celt_int16_t * restrict pcmp = pcm+c;
591       for (j=0;j<N;j++)
592       {
593          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
594          st->preemph_memD[c] = tmp;
595          *pcmp = SIG2INT16(tmp);
596          pcmp += C;
597          outp += C;
598       }
599    }
600 #endif
601    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
602       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
603    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
604    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
605    {
606       int val = 0;
607       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
608       {
609          ec_enc_uint(&st->enc, val, 2);
610          val = 1-val;
611       }
612    }
613    ec_enc_done(&st->enc);
614    {
615       unsigned char *data;
616       int nbBytes = ec_byte_bytes(&st->buf);
617       if (nbBytes > nbCompressedBytes)
618       {
619          celt_warning_int ("got too many bytes:", nbBytes);
620          RESTORE_STACK;
621          return CELT_INTERNAL_ERROR;
622       }
623       /*printf ("%d\n", *nbBytes);*/
624       data = ec_byte_get_buffer(&st->buf);
625       for (i=0;i<nbBytes;i++)
626          compressed[i] = data[i];
627       for (;i<nbCompressedBytes;i++)
628          compressed[i] = 0;
629    }
630    /* Reset the packing for the next encoding */
631    ec_byte_reset(&st->buf);
632    ec_enc_init(&st->enc,&st->buf);
633
634    RESTORE_STACK;
635    return nbCompressedBytes;
636 }
637
638
639 /****************************************************************************/
640 /*                                                                          */
641 /*                                DECODER                                   */
642 /*                                                                          */
643 /****************************************************************************/
644
645
646 /** Decoder state 
647  @brief Decoder state
648  */
649 struct CELTDecoder {
650    const CELTMode *mode;
651    int frame_size;
652    int block_size;
653    int overlap;
654
655    ec_byte_buffer buf;
656    ec_enc         enc;
657
658    celt_sig_t * restrict preemph_memD;
659
660    celt_sig_t *out_mem;
661
662    celt_word16_t *oldBandE;
663    
664    int last_pitch_index;
665 };
666
667 CELTDecoder *celt_decoder_create(const CELTMode *mode)
668 {
669    int N, C;
670    CELTDecoder *st;
671
672    if (check_mode(mode) != CELT_OK)
673       return NULL;
674
675    N = mode->mdctSize;
676    C = CHANNELS(mode);
677    st = celt_alloc(sizeof(CELTDecoder));
678    
679    st->mode = mode;
680    st->frame_size = N;
681    st->block_size = N;
682    st->overlap = mode->overlap;
683
684    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
685    
686    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
687
688    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
689
690    st->last_pitch_index = 0;
691    return st;
692 }
693
694 void celt_decoder_destroy(CELTDecoder *st)
695 {
696    if (st == NULL)
697    {
698       celt_warning("NULL passed to celt_encoder_destroy");
699       return;
700    }
701    if (check_mode(st->mode) != CELT_OK)
702       return;
703
704
705    celt_free(st->out_mem);
706    
707    celt_free(st->oldBandE);
708    
709    celt_free(st->preemph_memD);
710
711    celt_free(st);
712 }
713
714 /** Handles lost packets by just copying past data with the same offset as the last
715     pitch period */
716 static void celt_decode_lost(CELTDecoder * restrict st, short * restrict pcm)
717 {
718    int c, N;
719    int pitch_index;
720    int i, len;
721    VARDECL(celt_sig_t, freq);
722    const int C = CHANNELS(st->mode);
723    int offset;
724    SAVE_STACK;
725    N = st->block_size;
726    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
727    
728    len = N+st->mode->overlap;
729 #if 0
730    pitch_index = st->last_pitch_index;
731    
732    /* Use the pitch MDCT as the "guessed" signal */
733    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
734
735 #else
736    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
737    pitch_index = MAX_PERIOD-len-pitch_index;
738    offset = MAX_PERIOD-pitch_index;
739    while (offset+len >= MAX_PERIOD)
740       offset -= pitch_index;
741    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
742    for (i=0;i<N;i++)
743       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
744 #endif
745    
746    
747    
748    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
749    /* Compute inverse MDCTs */
750    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
751
752    for (c=0;c<C;c++)
753    {
754       int j;
755       for (j=0;j<N;j++)
756       {
757          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
758                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
759          st->preemph_memD[c] = tmp;
760          pcm[C*j+c] = SIG2INT16(tmp);
761       }
762    }
763    RESTORE_STACK;
764 }
765
766 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
767 {
768    int c, N, N4;
769    int has_pitch;
770    int pitch_index;
771    ec_dec dec;
772    ec_byte_buffer buf;
773    VARDECL(celt_sig_t, freq);
774    VARDECL(celt_norm_t, X);
775    VARDECL(celt_norm_t, P);
776    VARDECL(celt_ener_t, bandE);
777    VARDECL(celt_pgain_t, gains);
778    VARDECL(int, stereo_mode);
779    int shortBlocks;
780    int transient_time;
781    int transient_shift;
782    const int C = CHANNELS(st->mode);
783    SAVE_STACK;
784
785    if (check_mode(st->mode) != CELT_OK)
786       return CELT_INVALID_MODE;
787
788    N = st->block_size;
789    N4 = (N-st->overlap)>>1;
790
791    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
792    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
793    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
794    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
795    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
796    
797    if (check_mode(st->mode) != CELT_OK)
798    {
799       RESTORE_STACK;
800       return CELT_INVALID_MODE;
801    }
802    if (data == NULL)
803    {
804       celt_decode_lost(st, pcm);
805       RESTORE_STACK;
806       return 0;
807    }
808    
809    ec_byte_readinit(&buf,data,len);
810    ec_dec_init(&dec,&buf);
811    
812    shortBlocks = ec_dec_bits(&dec, 1);
813    if (shortBlocks)
814    {
815       transient_shift = ec_dec_bits(&dec, 2);
816       if (transient_shift)
817          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
818       else
819          transient_time = 0;
820    } else {
821       transient_time = -1;
822       transient_shift = 0;
823    }
824    /* Get the pitch gains */
825    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
826    
827    /* Get the pitch index */
828    if (has_pitch)
829    {
830       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
831       st->last_pitch_index = pitch_index;
832    } else {
833       /* FIXME: We could be more intelligent here and just not compute the MDCT */
834       pitch_index = 0;
835    }
836
837    /* Get band energies */
838    unquant_energy(st->mode, bandE, st->oldBandE, 20*C+len*8/5, st->mode->prob, &dec);
839
840    /* Pitch MDCT */
841    compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
842
843    {
844       VARDECL(celt_ener_t, bandEp);
845       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
846       compute_band_energies(st->mode, freq, bandEp);
847       normalise_bands(st->mode, freq, P, bandEp);
848    }
849
850    ALLOC(stereo_mode, st->mode->nbEBands, int);
851    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
852    /* Apply pitch gains */
853    pitch_quant_bands(st->mode, P, gains);
854
855    /* Decode fixed codebook and merge with pitch */
856    unquant_bands(st->mode, X, P, bandE, stereo_mode, len*8, shortBlocks, &dec);
857
858    if (C==2)
859    {
860       renormalise_bands(st->mode, X);
861    }
862    /* Synthesis */
863    denormalise_bands(st->mode, X, freq, bandE);
864
865
866    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
867    /* Compute inverse MDCTs */
868    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
869
870    for (c=0;c<C;c++)
871    {
872       int j;
873       const celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
874       celt_int16_t * restrict pcmp = pcm+c;
875       for (j=0;j<N;j++)
876       {
877          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
878          st->preemph_memD[c] = tmp;
879          *pcmp = SIG2INT16(tmp);
880          pcmp += C;
881          outp += C;
882       }
883    }
884
885    {
886       unsigned int val = 0;
887       while (ec_dec_tell(&dec, 0) < len*8)
888       {
889          if (ec_dec_uint(&dec, 2) != val)
890          {
891             celt_warning("decode error");
892             RESTORE_STACK;
893             return CELT_CORRUPTED_DATA;
894          }
895          val = 1-val;
896       }
897    }
898
899    RESTORE_STACK;
900    return 0;
901    /*printf ("\n");*/
902 }
903