Automatically choosing the overlap based on the frame size.
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52
53 static const celt_word16_t preemph = QCONST16(0.8f,15);
54
55 static const float gainWindow[16] = {
56    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
57    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
58    
59 /** Encoder state 
60  @brief Encoder state
61  */
62 struct CELTEncoder {
63    const CELTMode *mode;     /**< Mode used by the encoder */
64    int frame_size;
65    int block_size;
66    int overlap;
67    int channels;
68    
69    ec_byte_buffer buf;
70    ec_enc         enc;
71
72    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
73    celt_sig_t    * restrict preemph_memD;
74
75    celt_sig_t *in_mem;
76    celt_sig_t *out_mem;
77
78    celt_word16_t *oldBandE;
79 #ifdef EXP_PSY
80    celt_word16_t *psy_mem;
81    struct PsyDecay psy;
82 #endif
83 };
84
85 CELTEncoder *celt_encoder_create(const CELTMode *mode)
86 {
87    int N, C;
88    CELTEncoder *st;
89
90    if (check_mode(mode) != CELT_OK)
91       return NULL;
92
93    N = mode->mdctSize;
94    C = mode->nbChannels;
95    st = celt_alloc(sizeof(CELTEncoder));
96    
97    st->mode = mode;
98    st->frame_size = N;
99    st->block_size = N;
100    st->overlap = mode->overlap;
101
102    ec_byte_writeinit(&st->buf);
103    ec_enc_init(&st->enc,&st->buf);
104
105    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
106    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
107
108    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
109
110    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));;
111    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
112
113 #ifdef EXP_PSY
114    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
115    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
116 #endif
117
118    return st;
119 }
120
121 void celt_encoder_destroy(CELTEncoder *st)
122 {
123    if (st == NULL)
124    {
125       celt_warning("NULL passed to celt_encoder_destroy");
126       return;
127    }
128    if (check_mode(st->mode) != CELT_OK)
129       return;
130
131    ec_byte_writeclear(&st->buf);
132
133    celt_free(st->in_mem);
134    celt_free(st->out_mem);
135    
136    celt_free(st->oldBandE);
137    
138    celt_free(st->preemph_memE);
139    celt_free(st->preemph_memD);
140    
141 #ifdef EXP_PSY
142    celt_free (st->psy_mem);
143    psydecay_clear(&st->psy);
144 #endif
145    
146    celt_free(st);
147 }
148
149 static inline celt_int16_t SIG2INT16(celt_sig_t x)
150 {
151    x = PSHR32(x, SIG_SHIFT);
152    x = MAX32(x, -32768);
153    x = MIN32(x, 32767);
154 #ifdef FIXED_POINT
155    return EXTRACT16(x);
156 #else
157    return (celt_int16_t)floor(.5+x);
158 #endif
159 }
160
161 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
162 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
163 {
164    const int C = CHANNELS(mode);
165    if (C==1 && !shortBlocks)
166    {
167       const mdct_lookup *lookup = MDCT(mode);
168       const int overlap = OVERLAP(mode);
169       mdct_forward(lookup, in, out, mode->window, overlap);
170    } else if (!shortBlocks) {
171       const mdct_lookup *lookup = MDCT(mode);
172       const int overlap = OVERLAP(mode);
173       const int N = FRAMESIZE(mode);
174       int c;
175       VARDECL(celt_word32_t, x);
176       VARDECL(celt_word32_t, tmp);
177       SAVE_STACK;
178       ALLOC(x, N+overlap, celt_word32_t);
179       ALLOC(tmp, N, celt_word32_t);
180       for (c=0;c<C;c++)
181       {
182          int j;
183          for (j=0;j<N+overlap;j++)
184             x[j] = in[C*j+c];
185          mdct_forward(lookup, x, tmp, mode->window, overlap);
186          /* Interleaving the sub-frames */
187          for (j=0;j<N;j++)
188             out[C*j+c] = tmp[j];
189       }
190       RESTORE_STACK;
191    } else {
192       const mdct_lookup *lookup = &mode->shortMdct;
193       const int overlap = mode->shortMdctSize;
194       const int N = mode->shortMdctSize;
195       int b, c;
196       VARDECL(celt_word32_t, x);
197       VARDECL(celt_word32_t, tmp);
198       SAVE_STACK;
199       ALLOC(x, N+overlap, celt_word32_t);
200       ALLOC(tmp, N, celt_word32_t);
201       for (c=0;c<C;c++)
202       {
203          int B = mode->nbShortMdcts;
204          for (b=0;b<B;b++)
205          {
206             int j;
207             for (j=0;j<N+overlap;j++)
208                x[j] = in[C*(b*N+j)+c];
209             mdct_forward(lookup, x, tmp, mode->window, overlap);
210             /* Interleaving the sub-frames */
211             for (j=0;j<N;j++)
212                out[C*(j*B+b)+c] = tmp[j];
213          }
214       }
215       RESTORE_STACK;
216    }
217 }
218
219 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
220 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, float transient_gain, celt_sig_t * restrict out_mem)
221 {
222    int c, N4;
223    const int C = CHANNELS(mode);
224    const int N = FRAMESIZE(mode);
225    const int overlap = OVERLAP(mode);
226    N4 = (N-overlap)>>1;
227    for (c=0;c<C;c++)
228    {
229       int j;
230       if (transient_time<0 && C==1 && !shortBlocks) {
231          const mdct_lookup *lookup = MDCT(mode);
232          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
233       } else if (!shortBlocks) {
234          const mdct_lookup *lookup = MDCT(mode);
235          VARDECL(celt_word32_t, x);
236          VARDECL(celt_word32_t, tmp);
237          SAVE_STACK;
238          ALLOC(x, 2*N, celt_word32_t);
239          ALLOC(tmp, N, celt_word32_t);
240          /* De-interleaving the sub-frames */
241          for (j=0;j<N;j++)
242             tmp[j] = X[C*j+c];
243          /* Prevents problems from the imdct doing the overlap-add */
244          CELT_MEMSET(x+N4, 0, overlap);
245          mdct_backward(lookup, tmp, x, mode->window, overlap);
246          if (transient_time >= 0)
247          {
248             for (j=0;j<16;j++)
249                x[N4+transient_time+j-16] *= 1+gainWindow[j]*(transient_gain-1);
250             for (j=transient_time;j<N+overlap;j++)
251                x[N4+j] *= transient_gain;
252          }
253          /* The first and last part would need to be set to zero if we actually
254          wanted to use them. */
255          for (j=0;j<overlap;j++)
256             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
257          for (j=0;j<overlap;j++)
258             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
259          for (j=0;j<2*N4;j++)
260             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
261          RESTORE_STACK;
262       } else {
263          int b;
264          const int N2 = mode->shortMdctSize;
265          const int B = mode->nbShortMdcts;
266          const mdct_lookup *lookup = &mode->shortMdct;
267          VARDECL(celt_word32_t, x);
268          VARDECL(celt_word32_t, tmp);
269          SAVE_STACK;
270          ALLOC(x, 2*N, celt_word32_t);
271          ALLOC(tmp, N, celt_word32_t);
272          /* Prevents problems from the imdct doing the overlap-add */
273          CELT_MEMSET(x+N4, 0, overlap);
274          for (b=0;b<B;b++)
275          {
276             /* De-interleaving the sub-frames */
277             for (j=0;j<N2;j++)
278                tmp[j] = X[C*(j*B+b)+c];
279             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
280          }
281          if (transient_time >= 0)
282          {
283             for (j=0;j<16;j++)
284                x[N4+transient_time+j-16] *= 1+gainWindow[j]*(transient_gain-1);
285             for (j=transient_time;j<N+overlap;j++)
286                x[N4+j] *= transient_gain;
287          }
288          /* The first and last part would need to be set to zero if we actually
289          wanted to use them. */
290          for (j=0;j<overlap;j++)
291             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
292          for (j=0;j<overlap;j++)
293             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
294          for (j=0;j<2*N4;j++)
295             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
296          RESTORE_STACK;
297       }
298    }
299 }
300
301 int celt_encode(CELTEncoder * restrict st, celt_int16_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
302 {
303    int i, c, N, N4;
304    int has_pitch;
305    int pitch_index;
306    celt_word32_t curr_power, pitch_power;
307    VARDECL(celt_sig_t, in);
308    VARDECL(celt_sig_t, freq);
309    VARDECL(celt_norm_t, X);
310    VARDECL(celt_norm_t, P);
311    VARDECL(celt_ener_t, bandE);
312    VARDECL(celt_pgain_t, gains);
313    VARDECL(int, stereo_mode);
314 #ifdef EXP_PSY
315    VARDECL(celt_word32_t, mask);
316 #endif
317    int shortBlocks=0;
318    int transient_time;
319    float transient_gain;
320    const int C = CHANNELS(st->mode);
321    SAVE_STACK;
322
323    if (check_mode(st->mode) != CELT_OK)
324       return CELT_INVALID_MODE;
325
326    N = st->block_size;
327    N4 = (N-st->overlap)>>1;
328    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
329
330    CELT_COPY(in, st->in_mem, C*st->overlap);
331    for (c=0;c<C;c++)
332    {
333       const celt_int16_t * restrict pcmp = pcm+c;
334       celt_sig_t * restrict inp = in+C*st->overlap+c;
335       for (i=0;i<N;i++)
336       {
337          /* Apply pre-emphasis */
338          celt_sig_t tmp = SHL32(EXTEND32(*pcmp), SIG_SHIFT);
339          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),1));
340          st->preemph_memE[c] = *pcmp;
341          inp += C;
342          pcmp += C;
343       }
344    }
345    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
346    if (1) {
347       int len = N+st->overlap;
348       float maxR, maxD;
349       float begin[C*len], end[C*len];
350       begin[0] = in[0]*in[0];
351       for (i=1;i<len*C;i++)
352          begin[i] = begin[i-1]+in[i]*in[i];
353       end[len-1] = in[len-1]*in[len-1];
354       for (i=C*len-2;i>=0;i--)
355          end[i] = end[i+1] + in[i]*in[i];
356       maxD = 0;
357       maxR = 1;
358       transient_time = -1;
359       for (i=8*C;i<C*(len-8);i++)
360       {
361          float diff = sqrt(sqrt(end[i]/(C*len-i)))-sqrt(sqrt(begin[i]/(i)));
362          float ratio = ((1000+end[i])*i)/((1000+begin[i])*(C*len-i));
363          if (diff > maxD && end[i] > .5*begin[i])
364          {
365             maxD = diff;
366             maxR = ratio;
367             transient_time = i;
368          }
369       }
370       transient_time /= C;
371       if (transient_time<32)
372       {
373          transient_time = -1;
374          maxR = 0;
375       }
376       if (maxR > 30)
377       {
378          float gain_1;
379          ec_enc_bits(&st->enc, 1, 1);
380          if (maxR < 30)
381          {
382             transient_time = 16;
383             transient_gain = 1;
384             ec_enc_bits(&st->enc, 0, 2);
385          } else if (maxR < 100)
386          {
387             transient_gain = 2;
388             ec_enc_bits(&st->enc, 1, 2);
389          } else if (maxR < 500)
390          {
391             transient_gain = 4;
392             ec_enc_bits(&st->enc, 2, 2);
393          } else
394          {
395             transient_gain = 8;
396             ec_enc_bits(&st->enc, 3, 2);
397          }
398          ec_enc_uint(&st->enc, transient_time, len);
399          for (c=0;c<C;c++)
400             for (i=0;i<16;i++)
401                in[C*(transient_time+i-16)+c] /= 1+gainWindow[i]*(transient_gain-1);
402          gain_1 = 1./transient_gain;
403          for (c=0;c<C;c++)
404             for (i=transient_time;i<len;i++)
405                in[C*i+c] *= gain_1;
406          shortBlocks = 1;
407       } else {
408          ec_enc_bits(&st->enc, 0, 1);
409          transient_time = -1;
410          transient_gain = 1;
411          shortBlocks = 0;
412       }
413    }
414    /* Pitch analysis: we do it early to save on the peak stack space */
415    if (!shortBlocks)
416       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
417
418    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
419    
420    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
421    /* Compute MDCTs */
422    compute_mdcts(st->mode, shortBlocks, in, freq);
423
424 #ifdef EXP_PSY
425    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
426    for (i=0;i<N;i++)
427       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
428    for (c=1;c<C;c++)
429       for (i=0;i<N;i++)
430          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
431
432    ALLOC(mask, N, celt_sig_t);
433    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
434
435    /* Invert and stretch the mask to length of X 
436       For some reason, I get better results by using the sqrt instead,
437       although there's no valid reason to. Must investigate further */
438    for (i=0;i<C*N;i++)
439       mask[i] = 1/(.1+mask[i]);
440 #endif
441    
442    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
443    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
444    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
445    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
446    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
447
448    /*printf ("%f %f\n", curr_power, pitch_power);*/
449    /*int j;
450    for (j=0;j<B*N;j++)
451       printf ("%f ", X[j]);
452    for (j=0;j<B*N;j++)
453       printf ("%f ", P[j]);
454    printf ("\n");*/
455
456    /* Band normalisation */
457    compute_band_energies(st->mode, freq, bandE);
458    normalise_bands(st->mode, freq, X, bandE);
459    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
460    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
461
462    /* Compute MDCTs of the pitch part */
463    if (!shortBlocks)
464       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
465
466    {
467       /* Normalise the pitch vector as well (discard the energies) */
468       VARDECL(celt_ener_t, bandEp);
469       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
470       compute_band_energies(st->mode, freq, bandEp);
471       normalise_bands(st->mode, freq, P, bandEp);
472       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
473    }
474    curr_power = bandE[0]+bandE[1]+bandE[2];
475    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
476    if (!shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
477    {
478       /* Simulates intensity stereo */
479       /*for (i=30;i<N*B;i++)
480          X[i*C+1] = P[i*C+1] = 0;*/
481
482       /* Pitch prediction */
483       compute_pitch_gain(st->mode, X, P, gains);
484       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
485       if (has_pitch)
486          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
487    } else {
488       /* No pitch, so we just pretend we found a gain of zero */
489       for (i=0;i<st->mode->nbPBands;i++)
490          gains[i] = 0;
491       ec_enc_bits(&st->enc, 0, 7);
492       for (i=0;i<C*N;i++)
493          P[i] = 0;
494    }
495    quant_energy(st->mode, bandE, st->oldBandE, 20*C+nbCompressedBytes*8/5, st->mode->prob, &st->enc);
496
497    ALLOC(stereo_mode, st->mode->nbEBands, int);
498    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
499
500    pitch_quant_bands(st->mode, P, gains);
501
502    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
503
504    /* Residual quantisation */
505    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, nbCompressedBytes*8, shortBlocks, &st->enc);
506    
507    if (C==2)
508    {
509       renormalise_bands(st->mode, X);
510    }
511    /* Synthesis */
512    denormalise_bands(st->mode, X, freq, bandE);
513
514
515    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
516
517    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_gain, st->out_mem);
518    /* De-emphasis and put everything back at the right place in the synthesis history */
519 #ifndef SHORTCUTS
520    for (c=0;c<C;c++)
521    {
522       int j;
523       celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
524       celt_int16_t * restrict pcmp = pcm+c;
525       for (j=0;j<N;j++)
526       {
527          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
528          st->preemph_memD[c] = tmp;
529          *pcmp = SIG2INT16(tmp);
530          pcmp += C;
531          outp += C;
532       }
533    }
534 #endif
535    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
536       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
537    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
538    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
539    {
540       int val = 0;
541       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
542       {
543          ec_enc_uint(&st->enc, val, 2);
544          val = 1-val;
545       }
546    }
547    ec_enc_done(&st->enc);
548    {
549       unsigned char *data;
550       int nbBytes = ec_byte_bytes(&st->buf);
551       if (nbBytes > nbCompressedBytes)
552       {
553          celt_warning_int ("got too many bytes:", nbBytes);
554          RESTORE_STACK;
555          return CELT_INTERNAL_ERROR;
556       }
557       /*printf ("%d\n", *nbBytes);*/
558       data = ec_byte_get_buffer(&st->buf);
559       for (i=0;i<nbBytes;i++)
560          compressed[i] = data[i];
561       for (;i<nbCompressedBytes;i++)
562          compressed[i] = 0;
563    }
564    /* Reset the packing for the next encoding */
565    ec_byte_reset(&st->buf);
566    ec_enc_init(&st->enc,&st->buf);
567
568    RESTORE_STACK;
569    return nbCompressedBytes;
570 }
571
572
573 /****************************************************************************/
574 /*                                                                          */
575 /*                                DECODER                                   */
576 /*                                                                          */
577 /****************************************************************************/
578
579
580 /** Decoder state 
581  @brief Decoder state
582  */
583 struct CELTDecoder {
584    const CELTMode *mode;
585    int frame_size;
586    int block_size;
587    int overlap;
588
589    ec_byte_buffer buf;
590    ec_enc         enc;
591
592    celt_sig_t * restrict preemph_memD;
593
594    celt_sig_t *out_mem;
595
596    celt_word16_t *oldBandE;
597    
598    int last_pitch_index;
599 };
600
601 CELTDecoder *celt_decoder_create(const CELTMode *mode)
602 {
603    int N, C;
604    CELTDecoder *st;
605
606    if (check_mode(mode) != CELT_OK)
607       return NULL;
608
609    N = mode->mdctSize;
610    C = CHANNELS(mode);
611    st = celt_alloc(sizeof(CELTDecoder));
612    
613    st->mode = mode;
614    st->frame_size = N;
615    st->block_size = N;
616    st->overlap = mode->overlap;
617
618    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
619    
620    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
621
622    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));;
623
624    st->last_pitch_index = 0;
625    return st;
626 }
627
628 void celt_decoder_destroy(CELTDecoder *st)
629 {
630    if (st == NULL)
631    {
632       celt_warning("NULL passed to celt_encoder_destroy");
633       return;
634    }
635    if (check_mode(st->mode) != CELT_OK)
636       return;
637
638
639    celt_free(st->out_mem);
640    
641    celt_free(st->oldBandE);
642    
643    celt_free(st->preemph_memD);
644
645    celt_free(st);
646 }
647
648 /** Handles lost packets by just copying past data with the same offset as the last
649     pitch period */
650 static void celt_decode_lost(CELTDecoder * restrict st, short * restrict pcm)
651 {
652    int c, N;
653    int pitch_index;
654    int i, len;
655    VARDECL(celt_sig_t, freq);
656    const int C = CHANNELS(st->mode);
657    int offset;
658    SAVE_STACK;
659    N = st->block_size;
660    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
661    
662    len = N+st->mode->overlap;
663 #if 0
664    pitch_index = st->last_pitch_index;
665    
666    /* Use the pitch MDCT as the "guessed" signal */
667    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
668
669 #else
670    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
671    pitch_index = MAX_PERIOD-len-pitch_index;
672    offset = MAX_PERIOD-pitch_index;
673    while (offset+len >= MAX_PERIOD)
674       offset -= pitch_index;
675    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
676    for (i=0;i<N;i++)
677       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
678 #endif
679    
680    
681    
682    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
683    /* Compute inverse MDCTs */
684    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
685
686    for (c=0;c<C;c++)
687    {
688       int j;
689       for (j=0;j<N;j++)
690       {
691          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
692                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
693          st->preemph_memD[c] = tmp;
694          pcm[C*j+c] = SIG2INT16(tmp);
695       }
696    }
697    RESTORE_STACK;
698 }
699
700 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
701 {
702    int c, N, N4;
703    int has_pitch;
704    int pitch_index;
705    ec_dec dec;
706    ec_byte_buffer buf;
707    VARDECL(celt_sig_t, freq);
708    VARDECL(celt_norm_t, X);
709    VARDECL(celt_norm_t, P);
710    VARDECL(celt_ener_t, bandE);
711    VARDECL(celt_pgain_t, gains);
712    VARDECL(int, stereo_mode);
713    int shortBlocks;
714    int transient_time;
715    float transient_gain;
716    const int C = CHANNELS(st->mode);
717    SAVE_STACK;
718
719    if (check_mode(st->mode) != CELT_OK)
720       return CELT_INVALID_MODE;
721
722    N = st->block_size;
723    N4 = (N-st->overlap)>>1;
724
725    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
726    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
727    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
728    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
729    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
730    
731    if (check_mode(st->mode) != CELT_OK)
732    {
733       RESTORE_STACK;
734       return CELT_INVALID_MODE;
735    }
736    if (data == NULL)
737    {
738       celt_decode_lost(st, pcm);
739       RESTORE_STACK;
740       return 0;
741    }
742    
743    ec_byte_readinit(&buf,data,len);
744    ec_dec_init(&dec,&buf);
745    
746    shortBlocks = ec_dec_bits(&dec, 1);
747    if (shortBlocks)
748    {
749       int gainid = ec_dec_bits(&dec, 2);
750       switch(gainid) {
751          case 0:
752             transient_gain = 1;
753             break;
754          case 1:
755             transient_gain = 2;
756             break;
757          case 2:
758             transient_gain = 4;
759             break;
760          case 3:
761          default:
762             transient_gain = 8;
763             break;
764       }
765       transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
766    } else {
767       transient_time = -1;
768       transient_gain = 1;
769    }
770    /* Get the pitch gains */
771    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
772    
773    /* Get the pitch index */
774    if (has_pitch)
775    {
776       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
777       st->last_pitch_index = pitch_index;
778    } else {
779       /* FIXME: We could be more intelligent here and just not compute the MDCT */
780       pitch_index = 0;
781    }
782
783    /* Get band energies */
784    unquant_energy(st->mode, bandE, st->oldBandE, 20*C+len*8/5, st->mode->prob, &dec);
785
786    /* Pitch MDCT */
787    compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
788
789    {
790       VARDECL(celt_ener_t, bandEp);
791       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
792       compute_band_energies(st->mode, freq, bandEp);
793       normalise_bands(st->mode, freq, P, bandEp);
794    }
795
796    ALLOC(stereo_mode, st->mode->nbEBands, int);
797    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
798    /* Apply pitch gains */
799    pitch_quant_bands(st->mode, P, gains);
800
801    /* Decode fixed codebook and merge with pitch */
802    unquant_bands(st->mode, X, P, bandE, stereo_mode, len*8, shortBlocks, &dec);
803
804    if (C==2)
805    {
806       renormalise_bands(st->mode, X);
807    }
808    /* Synthesis */
809    denormalise_bands(st->mode, X, freq, bandE);
810
811
812    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
813    /* Compute inverse MDCTs */
814    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_gain, st->out_mem);
815
816    for (c=0;c<C;c++)
817    {
818       int j;
819       const celt_sig_t * restrict outp=st->out_mem+C*(MAX_PERIOD-N)+c;
820       celt_int16_t * restrict pcmp = pcm+c;
821       for (j=0;j<N;j++)
822       {
823          celt_sig_t tmp = ADD32(*outp, MULT16_32_Q15(preemph,st->preemph_memD[c]));
824          st->preemph_memD[c] = tmp;
825          *pcmp = SIG2INT16(tmp);
826          pcmp += C;
827          outp += C;
828       }
829    }
830
831    {
832       unsigned int val = 0;
833       while (ec_dec_tell(&dec, 0) < len*8)
834       {
835          if (ec_dec_uint(&dec, 2) != val)
836          {
837             celt_warning("decode error");
838             RESTORE_STACK;
839             return CELT_CORRUPTED_DATA;
840          }
841          val = 1-val;
842       }
843    }
844
845    RESTORE_STACK;
846    return 0;
847    /*printf ("\n");*/
848 }
849