Added signalling bits for enabling/disabling pitch, short blocks, and folding.
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53
54 static const celt_word16_t preemph = QCONST16(0.8f,15);
55
56 #ifdef FIXED_POINT
57 static const celt_word16_t transientWindow[16] = {
58      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
59    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
60 #else
61 static const float transientWindow[16] = {
62    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
63    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
64 #endif
65
66    
67 /** Encoder state 
68  @brief Encoder state
69  */
70 struct CELTEncoder {
71    const CELTMode *mode;     /**< Mode used by the encoder */
72    int frame_size;
73    int block_size;
74    int overlap;
75    int channels;
76    
77    int pitch_enabled;
78    
79    ec_byte_buffer buf;
80    ec_enc         enc;
81
82    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
83    celt_sig_t    * restrict preemph_memD;
84
85    celt_sig_t *in_mem;
86    celt_sig_t *out_mem;
87
88    celt_word16_t *oldBandE;
89 #ifdef EXP_PSY
90    celt_word16_t *psy_mem;
91    struct PsyDecay psy;
92 #endif
93 };
94
95 CELTEncoder *celt_encoder_create(const CELTMode *mode)
96 {
97    int N, C;
98    CELTEncoder *st;
99
100    if (check_mode(mode) != CELT_OK)
101       return NULL;
102
103    N = mode->mdctSize;
104    C = mode->nbChannels;
105    st = celt_alloc(sizeof(CELTEncoder));
106    
107    st->mode = mode;
108    st->frame_size = N;
109    st->block_size = N;
110    st->overlap = mode->overlap;
111
112    st->pitch_enabled = 1;
113    
114    ec_byte_writeinit(&st->buf);
115    ec_enc_init(&st->enc,&st->buf);
116
117    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
118    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
119
120    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
121
122    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
123    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
124
125 #ifdef EXP_PSY
126    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
127    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
128 #endif
129
130    return st;
131 }
132
133 void celt_encoder_destroy(CELTEncoder *st)
134 {
135    if (st == NULL)
136    {
137       celt_warning("NULL passed to celt_encoder_destroy");
138       return;
139    }
140    if (check_mode(st->mode) != CELT_OK)
141       return;
142
143    ec_byte_writeclear(&st->buf);
144
145    celt_free(st->in_mem);
146    celt_free(st->out_mem);
147    
148    celt_free(st->oldBandE);
149    
150    celt_free(st->preemph_memE);
151    celt_free(st->preemph_memD);
152    
153 #ifdef EXP_PSY
154    celt_free (st->psy_mem);
155    psydecay_clear(&st->psy);
156 #endif
157    
158    celt_free(st);
159 }
160
161 static inline celt_int16_t FLOAT2INT16(float x)
162 {
163    x = x*32768.;
164    x = MAX32(x, -32768);
165    x = MIN32(x, 32767);
166    return (celt_int16_t)floor(.5+x);
167 }
168
169 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
170 {
171 #ifdef FIXED_POINT
172    x = PSHR32(x, SIG_SHIFT);
173    x = MAX32(x, -32768);
174    x = MIN32(x, 32767);
175    return EXTRACT16(x);
176 #else
177    return (celt_word16_t)x;
178 #endif
179 }
180
181 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
182 {
183    int c, i, n;
184    celt_word32_t ratio;
185    /* FIXME: Remove the floats here */
186    VARDECL(celt_word32_t, begin);
187    SAVE_STACK;
188    ALLOC(begin, len, celt_word32_t);
189    for (i=0;i<len;i++)
190       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
191    for (c=1;c<C;c++)
192    {
193       for (i=0;i<len;i++)
194          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
195    }
196    for (i=1;i<len;i++)
197       begin[i] = MAX32(begin[i-1],begin[i]);
198    n = -1;
199    for (i=8;i<len-8;i++)
200    {
201       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
202          n=i;
203    }
204    if (n<32)
205    {
206       n = -1;
207       ratio = 0;
208    } else {
209       ratio = DIV32(begin[len-1],1+begin[n-16]);
210    }
211    /*printf ("%d %f\n", n, ratio*ratio);*/
212    if (ratio < 0)
213       ratio = 0;
214    if (ratio > 1000)
215       ratio = 1000;
216    ratio *= ratio;
217    if (ratio < 50)
218       *transient_shift = 0;
219    else if (ratio < 256)
220       *transient_shift = 1;
221    else if (ratio < 4096)
222       *transient_shift = 2;
223    else
224       *transient_shift = 3;
225    *transient_time = n;
226    
227    RESTORE_STACK;
228    return ratio > 20;
229 }
230
231 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
232 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
233 {
234    const int C = CHANNELS(mode);
235    if (C==1 && !shortBlocks)
236    {
237       const mdct_lookup *lookup = MDCT(mode);
238       const int overlap = OVERLAP(mode);
239       mdct_forward(lookup, in, out, mode->window, overlap);
240    } else if (!shortBlocks) {
241       const mdct_lookup *lookup = MDCT(mode);
242       const int overlap = OVERLAP(mode);
243       const int N = FRAMESIZE(mode);
244       int c;
245       VARDECL(celt_word32_t, x);
246       VARDECL(celt_word32_t, tmp);
247       SAVE_STACK;
248       ALLOC(x, N+overlap, celt_word32_t);
249       ALLOC(tmp, N, celt_word32_t);
250       for (c=0;c<C;c++)
251       {
252          int j;
253          for (j=0;j<N+overlap;j++)
254             x[j] = in[C*j+c];
255          mdct_forward(lookup, x, tmp, mode->window, overlap);
256          /* Interleaving the sub-frames */
257          for (j=0;j<N;j++)
258             out[C*j+c] = tmp[j];
259       }
260       RESTORE_STACK;
261    } else {
262       const mdct_lookup *lookup = &mode->shortMdct;
263       const int overlap = mode->shortMdctSize;
264       const int N = mode->shortMdctSize;
265       int b, c;
266       VARDECL(celt_word32_t, x);
267       VARDECL(celt_word32_t, tmp);
268       SAVE_STACK;
269       ALLOC(x, N+overlap, celt_word32_t);
270       ALLOC(tmp, N, celt_word32_t);
271       for (c=0;c<C;c++)
272       {
273          int B = mode->nbShortMdcts;
274          for (b=0;b<B;b++)
275          {
276             int j;
277             for (j=0;j<N+overlap;j++)
278                x[j] = in[C*(b*N+j)+c];
279             mdct_forward(lookup, x, tmp, mode->window, overlap);
280             /* Interleaving the sub-frames */
281             for (j=0;j<N;j++)
282                out[C*(j*B+b)+c] = tmp[j];
283          }
284       }
285       RESTORE_STACK;
286    }
287 }
288
289 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
290 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
291 {
292    int c, N4;
293    const int C = CHANNELS(mode);
294    const int N = FRAMESIZE(mode);
295    const int overlap = OVERLAP(mode);
296    N4 = (N-overlap)>>1;
297    for (c=0;c<C;c++)
298    {
299       int j;
300       if (transient_shift==0 && C==1 && !shortBlocks) {
301          const mdct_lookup *lookup = MDCT(mode);
302          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
303       } else if (!shortBlocks) {
304          const mdct_lookup *lookup = MDCT(mode);
305          VARDECL(celt_word32_t, x);
306          VARDECL(celt_word32_t, tmp);
307          SAVE_STACK;
308          ALLOC(x, 2*N, celt_word32_t);
309          ALLOC(tmp, N, celt_word32_t);
310          /* De-interleaving the sub-frames */
311          for (j=0;j<N;j++)
312             tmp[j] = X[C*j+c];
313          /* Prevents problems from the imdct doing the overlap-add */
314          CELT_MEMSET(x+N4, 0, overlap);
315          mdct_backward(lookup, tmp, x, mode->window, overlap);
316          celt_assert(transient_shift == 0);
317          /* The first and last part would need to be set to zero if we actually
318             wanted to use them. */
319          for (j=0;j<overlap;j++)
320             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
321          for (j=0;j<overlap;j++)
322             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
323          for (j=0;j<2*N4;j++)
324             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
325          RESTORE_STACK;
326       } else {
327          int b;
328          const int N2 = mode->shortMdctSize;
329          const int B = mode->nbShortMdcts;
330          const mdct_lookup *lookup = &mode->shortMdct;
331          VARDECL(celt_word32_t, x);
332          VARDECL(celt_word32_t, tmp);
333          SAVE_STACK;
334          ALLOC(x, 2*N, celt_word32_t);
335          ALLOC(tmp, N, celt_word32_t);
336          /* Prevents problems from the imdct doing the overlap-add */
337          CELT_MEMSET(x+N4, 0, overlap);
338          for (b=0;b<B;b++)
339          {
340             /* De-interleaving the sub-frames */
341             for (j=0;j<N2;j++)
342                tmp[j] = X[C*(j*B+b)+c];
343             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
344          }
345          if (transient_shift > 0)
346          {
347 #ifdef FIXED_POINT
348             for (j=0;j<16;j++)
349                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
350             for (j=transient_time;j<N+overlap;j++)
351                x[N4+j] = SHL32(x[N4+j], transient_shift);
352 #else
353             for (j=0;j<16;j++)
354                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
355             for (j=transient_time;j<N+overlap;j++)
356                x[N4+j] *= 1<<transient_shift;
357 #endif
358          }
359          /* The first and last part would need to be set to zero if we actually
360          wanted to use them. */
361          for (j=0;j<overlap;j++)
362             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
363          for (j=0;j<overlap;j++)
364             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
365          for (j=0;j<2*N4;j++)
366             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
367          RESTORE_STACK;
368       }
369    }
370 }
371
372 #ifdef FIXED_POINT
373 int celt_encode(CELTEncoder * restrict st, celt_int16_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
374 {
375 #else
376 int celt_encode_float(CELTEncoder * restrict st, celt_sig_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
377 {
378 #endif
379    int i, c, N, N4;
380    int has_pitch;
381    int pitch_index;
382    int bits;
383    celt_word32_t curr_power, pitch_power=0;
384    VARDECL(celt_sig_t, in);
385    VARDECL(celt_sig_t, freq);
386    VARDECL(celt_norm_t, X);
387    VARDECL(celt_norm_t, P);
388    VARDECL(celt_ener_t, bandE);
389    VARDECL(celt_pgain_t, gains);
390    VARDECL(int, stereo_mode);
391    VARDECL(int, fine_quant);
392    VARDECL(celt_word16_t, error);
393    VARDECL(int, pulses);
394    VARDECL(int, offsets);
395 #ifdef EXP_PSY
396    VARDECL(celt_word32_t, mask);
397 #endif
398    int shortBlocks=0;
399    int transient_time;
400    int transient_shift;
401    const int C = CHANNELS(st->mode);
402    SAVE_STACK;
403
404    if (check_mode(st->mode) != CELT_OK)
405       return CELT_INVALID_MODE;
406
407    N = st->block_size;
408    N4 = (N-st->overlap)>>1;
409    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
410
411    CELT_COPY(in, st->in_mem, C*st->overlap);
412    for (c=0;c<C;c++)
413    {
414       const celt_word16_t * restrict pcmp = pcm+c;
415       celt_sig_t * restrict inp = in+C*st->overlap+c;
416       for (i=0;i<N;i++)
417       {
418          /* Apply pre-emphasis */
419          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
420          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
421          st->preemph_memE[c] = SCALEIN(*pcmp);
422          inp += C;
423          pcmp += C;
424       }
425    }
426    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
427    
428    if (st->mode->nbShortMdcts > 1)
429    {
430       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
431       {
432 #ifndef FIXED_POINT
433          float gain_1;
434 #endif
435          ec_enc_bits(&st->enc, 0, 1); //Pitch off
436          ec_enc_bits(&st->enc, 1, 1); //Transient on
437          ec_enc_bits(&st->enc, transient_shift, 2);
438          if (transient_shift)
439             ec_enc_uint(&st->enc, transient_time, N+st->overlap);
440          if (transient_shift)
441          {
442 #ifdef FIXED_POINT
443             for (c=0;c<C;c++)
444                for (i=0;i<16;i++)
445                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
446             for (c=0;c<C;c++)
447                for (i=transient_time;i<N+st->overlap;i++)
448                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
449 #else
450             for (c=0;c<C;c++)
451                for (i=0;i<16;i++)
452                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
453             gain_1 = 1./(1<<transient_shift);
454             for (c=0;c<C;c++)
455                for (i=transient_time;i<N+st->overlap;i++)
456                   in[C*i+c] *= gain_1;
457 #endif
458          }
459          shortBlocks = 1;
460       } else {
461          transient_time = -1;
462          transient_shift = 0;
463          shortBlocks = 0;
464       }
465    } else {
466       transient_time = -1;
467       transient_shift = 0;
468       shortBlocks = 0;
469    }
470    /* Pitch analysis: we do it early to save on the peak stack space */
471    if (st->pitch_enabled && !shortBlocks)
472       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
473
474    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
475    
476    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
477    /* Compute MDCTs */
478    compute_mdcts(st->mode, shortBlocks, in, freq);
479
480 #ifdef EXP_PSY
481    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
482    for (i=0;i<N;i++)
483       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
484    for (c=1;c<C;c++)
485       for (i=0;i<N;i++)
486          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
487
488    ALLOC(mask, N, celt_sig_t);
489    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
490
491    /* Invert and stretch the mask to length of X 
492       For some reason, I get better results by using the sqrt instead,
493       although there's no valid reason to. Must investigate further */
494    for (i=0;i<C*N;i++)
495       mask[i] = 1/(.1+mask[i]);
496 #endif
497    
498    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
499    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
500    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
501    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
502    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
503
504    /*printf ("%f %f\n", curr_power, pitch_power);*/
505    /*int j;
506    for (j=0;j<B*N;j++)
507       printf ("%f ", X[j]);
508    for (j=0;j<B*N;j++)
509       printf ("%f ", P[j]);
510    printf ("\n");*/
511
512    /* Band normalisation */
513    compute_band_energies(st->mode, freq, bandE);
514    normalise_bands(st->mode, freq, X, bandE);
515    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
516    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
517
518    /* Compute MDCTs of the pitch part */
519    if (st->pitch_enabled && !shortBlocks)
520    {
521       /* Normalise the pitch vector as well (discard the energies) */
522       VARDECL(celt_ener_t, bandEp);
523       
524       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
525       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
526       compute_band_energies(st->mode, freq, bandEp);
527       normalise_bands(st->mode, freq, P, bandEp);
528       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
529    }
530    curr_power = bandE[0]+bandE[1]+bandE[2];
531    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
532    if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
533    {
534       /* Simulates intensity stereo */
535       /*for (i=30;i<N*B;i++)
536          X[i*C+1] = P[i*C+1] = 0;*/
537
538       /* Pitch prediction */
539       compute_pitch_gain(st->mode, X, P, gains);
540       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
541       if (has_pitch)
542          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
543    } else {
544       if (!shortBlocks)
545       {
546          ec_enc_bits(&st->enc, 0, 1); //Pitch off
547          ec_enc_bits(&st->enc, 0, 1); //Transient off
548       }
549       /* No pitch, so we just pretend we found a gain of zero */
550       for (i=0;i<st->mode->nbPBands;i++)
551          gains[i] = 0;
552       for (i=0;i<C*N;i++)
553          P[i] = 0;
554    }
555
556 #ifdef STDIN_TUNING2
557    static int fine_quant[30];
558    static int pulses[30];
559    static int init=0;
560    if (!init)
561    {
562       for (i=0;i<st->mode->nbEBands;i++)
563          scanf("%d ", &fine_quant[i]);
564       for (i=0;i<st->mode->nbEBands;i++)
565          scanf("%d ", &pulses[i]);
566       init = 1;
567    }
568 #else
569    ALLOC(fine_quant, st->mode->nbEBands, int);
570    ALLOC(pulses, st->mode->nbEBands, int);
571 #endif
572    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
573    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &st->enc);
574    
575    ALLOC(offsets, st->mode->nbEBands, int);
576    ALLOC(stereo_mode, st->mode->nbEBands, int);
577    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
578
579    for (i=0;i<st->mode->nbEBands;i++)
580       offsets[i] = 0;
581    bits = nbCompressedBytes*8 - ec_enc_tell(&st->enc, 0) - 1;
582 #ifndef STDIN_TUNING
583    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
584 #endif
585    /*for (i=0;i<st->mode->nbEBands;i++)
586       printf("%d ", fine_quant[i]);
587    for (i=0;i<st->mode->nbEBands;i++)
588       printf("%d ", pulses[i]);
589    printf ("\n");*/
590    /*bits = ec_enc_tell(&st->enc, 0);
591    compute_fine_allocation(st->mode, fine_quant, (20*C+nbCompressedBytes*8/5-(ec_enc_tell(&st->enc, 0)-bits))/C);*/
592    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &st->enc);
593
594    pitch_quant_bands(st->mode, P, gains);
595
596    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
597
598    /* Residual quantisation */
599    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, &st->enc);
600    
601    if (C==2)
602    {
603       renormalise_bands(st->mode, X);
604    }
605    /* Synthesis */
606    denormalise_bands(st->mode, X, freq, bandE);
607
608
609    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
610
611    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
612    /* De-emphasis and put everything back at the right place in the synthesis history */
613 #ifndef SHORTCUTS
614    for (c=0;c<C;c++)
615    {
616       int j;
617       for (j=0;j<N;j++)
618       {
619          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
620                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
621          st->preemph_memD[c] = tmp;
622          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
623       }
624    }
625 #endif
626    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
627       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
628    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
629    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
630    {
631       int val = 0;
632       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
633       {
634          ec_enc_uint(&st->enc, val, 2);
635          val = 1-val;
636       }
637    }
638    ec_enc_done(&st->enc);
639    {
640       unsigned char *data;
641       int nbBytes = ec_byte_bytes(&st->buf);
642       if (nbBytes > nbCompressedBytes)
643       {
644          celt_warning_int ("got too many bytes:", nbBytes);
645          RESTORE_STACK;
646          return CELT_INTERNAL_ERROR;
647       }
648       /*printf ("%d\n", *nbBytes);*/
649       data = ec_byte_get_buffer(&st->buf);
650       for (i=0;i<nbBytes;i++)
651          compressed[i] = data[i];
652       for (;i<nbCompressedBytes;i++)
653          compressed[i] = 0;
654    }
655    /* Reset the packing for the next encoding */
656    ec_byte_reset(&st->buf);
657    ec_enc_init(&st->enc,&st->buf);
658
659    RESTORE_STACK;
660    return nbCompressedBytes;
661 }
662
663 #ifdef FIXED_POINT
664 #ifndef DISABLE_FLOAT_API
665 int celt_encode_float(CELTEncoder * restrict st, float * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
666 {
667    int j, ret;
668    const int C = CHANNELS(st->mode);
669    const int N = st->block_size;
670    ALLOC(in, C*N, celt_int16_t);
671
672    for (j=0;j<C*N;j++)
673      in[j] = FLOAT2INT16(pcm[j]);
674
675    ret=celt_encode(st,in,compressed,nbCompressedBytes);
676 #ifndef SHORTCUTS
677    /*Converts backwards for inplace operation*/
678    for (j=0;j=C*N;j++)
679      pcm[j]=in[j]*(1/32768.);
680 #endif
681    return ret;
682
683 }
684 #endif /*DISABLE_FLOAT_API*/
685 #else
686 int celt_encode(CELTEncoder * restrict st, celt_int16_t * restrict pcm, unsigned char *compressed, int nbCompressedBytes)
687 {
688    int j, ret;
689    VARDECL(celt_sig_t, in);
690    const int C = CHANNELS(st->mode);
691    const int N = st->block_size;
692
693    ALLOC(in, C*N, celt_sig_t);
694    for (j=0;j<C*N;j++) {
695      in[j] = SCALEOUT(pcm[j]);
696    }
697    ret = celt_encode_float(st,in,compressed,nbCompressedBytes);
698 #ifndef SHORTCUTS
699    for (j=0;j<C*N;j++)
700      pcm[j] = FLOAT2INT16(in[j]);
701
702 #endif
703    return ret;
704 }
705 #endif
706
707 /****************************************************************************/
708 /*                                                                          */
709 /*                                DECODER                                   */
710 /*                                                                          */
711 /****************************************************************************/
712
713
714 /** Decoder state 
715  @brief Decoder state
716  */
717 struct CELTDecoder {
718    const CELTMode *mode;
719    int frame_size;
720    int block_size;
721    int overlap;
722
723    ec_byte_buffer buf;
724    ec_enc         enc;
725
726    celt_sig_t * restrict preemph_memD;
727
728    celt_sig_t *out_mem;
729
730    celt_word16_t *oldBandE;
731    
732    int last_pitch_index;
733 };
734
735 CELTDecoder *celt_decoder_create(const CELTMode *mode)
736 {
737    int N, C;
738    CELTDecoder *st;
739
740    if (check_mode(mode) != CELT_OK)
741       return NULL;
742
743    N = mode->mdctSize;
744    C = CHANNELS(mode);
745    st = celt_alloc(sizeof(CELTDecoder));
746    
747    st->mode = mode;
748    st->frame_size = N;
749    st->block_size = N;
750    st->overlap = mode->overlap;
751
752    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
753    
754    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
755
756    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
757
758    st->last_pitch_index = 0;
759    return st;
760 }
761
762 void celt_decoder_destroy(CELTDecoder *st)
763 {
764    if (st == NULL)
765    {
766       celt_warning("NULL passed to celt_encoder_destroy");
767       return;
768    }
769    if (check_mode(st->mode) != CELT_OK)
770       return;
771
772
773    celt_free(st->out_mem);
774    
775    celt_free(st->oldBandE);
776    
777    celt_free(st->preemph_memD);
778
779    celt_free(st);
780 }
781
782 /** Handles lost packets by just copying past data with the same offset as the last
783     pitch period */
784 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
785 {
786    int c, N;
787    int pitch_index;
788    int i, len;
789    VARDECL(celt_sig_t, freq);
790    const int C = CHANNELS(st->mode);
791    int offset;
792    SAVE_STACK;
793    N = st->block_size;
794    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
795    
796    len = N+st->mode->overlap;
797 #if 0
798    pitch_index = st->last_pitch_index;
799    
800    /* Use the pitch MDCT as the "guessed" signal */
801    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
802
803 #else
804    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
805    pitch_index = MAX_PERIOD-len-pitch_index;
806    offset = MAX_PERIOD-pitch_index;
807    while (offset+len >= MAX_PERIOD)
808       offset -= pitch_index;
809    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
810    for (i=0;i<N;i++)
811       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
812 #endif
813    
814    
815    
816    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
817    /* Compute inverse MDCTs */
818    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
819
820    for (c=0;c<C;c++)
821    {
822       int j;
823       for (j=0;j<N;j++)
824       {
825          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
826                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
827          st->preemph_memD[c] = tmp;
828          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
829       }
830    }
831    RESTORE_STACK;
832 }
833
834 #ifdef FIXED_POINT
835 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
836 {
837 #else
838 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
839 {
840 #endif
841    int i, c, N, N4;
842    int has_pitch, has_fold;
843    int pitch_index;
844    int bits;
845    ec_dec dec;
846    ec_byte_buffer buf;
847    VARDECL(celt_sig_t, freq);
848    VARDECL(celt_norm_t, X);
849    VARDECL(celt_norm_t, P);
850    VARDECL(celt_ener_t, bandE);
851    VARDECL(celt_pgain_t, gains);
852    VARDECL(int, stereo_mode);
853    VARDECL(int, fine_quant);
854    VARDECL(int, pulses);
855    VARDECL(int, offsets);
856
857    int shortBlocks;
858    int transient_time;
859    int transient_shift;
860    const int C = CHANNELS(st->mode);
861    SAVE_STACK;
862
863    if (check_mode(st->mode) != CELT_OK)
864       return CELT_INVALID_MODE;
865
866    N = st->block_size;
867    N4 = (N-st->overlap)>>1;
868
869    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
870    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
871    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
872    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
873    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
874    
875    if (check_mode(st->mode) != CELT_OK)
876    {
877       RESTORE_STACK;
878       return CELT_INVALID_MODE;
879    }
880    if (data == NULL)
881    {
882       celt_decode_lost(st, pcm);
883       RESTORE_STACK;
884       return 0;
885    }
886    
887    ec_byte_readinit(&buf,data,len);
888    ec_dec_init(&dec,&buf);
889    
890    if (st->mode->nbShortMdcts > 1)
891    {
892       has_pitch = ec_dec_bits(&dec, 1);
893       if (has_pitch)
894       {
895          has_fold = ec_dec_bits(&dec, 1);
896          shortBlocks = 0;
897       } else {
898          shortBlocks = ec_dec_bits(&dec, 1);
899          has_fold = 1;
900       }
901       if (shortBlocks)
902       {
903          transient_shift = ec_dec_bits(&dec, 2);
904          if (transient_shift)
905             transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
906          else
907             transient_time = 0;
908       } else {
909          transient_time = -1;
910          transient_shift = 0;
911       }
912    } else {
913       has_pitch = ec_dec_bits(&dec, 1);
914       if (has_pitch)
915          has_fold = ec_dec_bits(&dec, 1);
916       else
917          has_fold = 1;
918       shortBlocks = 0;
919       transient_time = -1;
920       transient_shift = 0;
921    }
922    /* Get the pitch gains */
923    
924    /* Get the pitch index */
925    if (has_pitch)
926    {
927       has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
928       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
929       st->last_pitch_index = pitch_index;
930    } else {
931       /* FIXME: We could be more intelligent here and just not compute the MDCT */
932       pitch_index = 0;
933       for (i=0;i<st->mode->nbPBands;i++)
934          gains[i] = 0;
935    }
936
937    ALLOC(fine_quant, st->mode->nbEBands, int);
938    /* Get band energies */
939    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
940    
941    ALLOC(pulses, st->mode->nbEBands, int);
942    ALLOC(offsets, st->mode->nbEBands, int);
943    ALLOC(stereo_mode, st->mode->nbEBands, int);
944    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
945
946    for (i=0;i<st->mode->nbEBands;i++)
947       offsets[i] = 0;
948
949    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
950    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
951    /*bits = ec_dec_tell(&dec, 0);
952    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
953    
954    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
955
956
957    if (has_pitch) 
958    {
959       VARDECL(celt_ener_t, bandEp);
960       
961       /* Pitch MDCT */
962       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
963       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
964       compute_band_energies(st->mode, freq, bandEp);
965       normalise_bands(st->mode, freq, P, bandEp);
966    } else {
967       for (i=0;i<C*N;i++)
968          P[i] = 0;
969    }
970
971    /* Apply pitch gains */
972    pitch_quant_bands(st->mode, P, gains);
973
974    /* Decode fixed codebook and merge with pitch */
975    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, &dec);
976
977    if (C==2)
978    {
979       renormalise_bands(st->mode, X);
980    }
981    /* Synthesis */
982    denormalise_bands(st->mode, X, freq, bandE);
983
984
985    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
986    /* Compute inverse MDCTs */
987    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
988
989    for (c=0;c<C;c++)
990    {
991       int j;
992       for (j=0;j<N;j++)
993       {
994          celt_sig_t tmp = ADD32(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
995                                 MULT16_32_Q15(preemph,st->preemph_memD[c]));
996          st->preemph_memD[c] = tmp;
997          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
998       }
999    }
1000
1001    {
1002       unsigned int val = 0;
1003       while (ec_dec_tell(&dec, 0) < len*8)
1004       {
1005          if (ec_dec_uint(&dec, 2) != val)
1006          {
1007             celt_warning("decode error");
1008             RESTORE_STACK;
1009             return CELT_CORRUPTED_DATA;
1010          }
1011          val = 1-val;
1012       }
1013    }
1014
1015    RESTORE_STACK;
1016    return 0;
1017    /*printf ("\n");*/
1018 }
1019
1020 #ifdef FIXED_POINT
1021 #ifndef DISABLE_FLOAT_API
1022 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1023 {
1024    int j, ret;
1025    const int C = CHANNELS(st->mode);
1026    const int N = st->block_size;
1027    ALLOC(out, C*N, celt_int16_t);
1028
1029    ret=celt_decode(st, data, len, out);
1030
1031    for (j=0;j<C*N;j++)
1032      pcm[j]=out[j]*(1/32768.);
1033
1034    return ret;
1035 }
1036 #endif /*DISABLE_FLOAT_API*/
1037 #else
1038 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1039 {
1040    int j, ret;
1041    VARDECL(celt_sig_t, out);
1042    const int C = CHANNELS(st->mode);
1043    const int N = st->block_size;
1044
1045    ALLOC(out, C*N, celt_sig_t);
1046
1047    ret=celt_decode_float(st, data, len, out);
1048
1049    for (j=0;j<C*N;j++)
1050      pcm[j] = FLOAT2INT16 (out[j]);
1051
1052    return ret;
1053 }
1054 #endif