869d8a80f62c4968065d5f5e23a68c35db11ee78
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53 #include "float_cast.h"
54 #include <stdarg.h>
55
56 static const celt_word16_t preemph = QCONST16(0.8f,15);
57
58 #ifdef FIXED_POINT
59 static const celt_word16_t transientWindow[16] = {
60      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
61    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
62 #else
63 static const float transientWindow[16] = {
64    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
65    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
66 #endif
67
68    
69 /** Encoder state 
70  @brief Encoder state
71  */
72 struct CELTEncoder {
73    const CELTMode *mode;     /**< Mode used by the encoder */
74    int frame_size;
75    int block_size;
76    int overlap;
77    int channels;
78    
79    int pitch_enabled;
80
81    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
82    celt_sig_t    * restrict preemph_memD;
83
84    celt_sig_t *in_mem;
85    celt_sig_t *out_mem;
86
87    celt_word16_t *oldBandE;
88 #ifdef EXP_PSY
89    celt_word16_t *psy_mem;
90    struct PsyDecay psy;
91 #endif
92 };
93
94 CELTEncoder *celt_encoder_create(const CELTMode *mode)
95 {
96    int N, C;
97    CELTEncoder *st;
98
99    if (check_mode(mode) != CELT_OK)
100       return NULL;
101
102    N = mode->mdctSize;
103    C = mode->nbChannels;
104    st = celt_alloc(sizeof(CELTEncoder));
105    
106    st->mode = mode;
107    st->frame_size = N;
108    st->block_size = N;
109    st->overlap = mode->overlap;
110
111    st->pitch_enabled = 1;
112
113    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
114    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
115
116    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
117
118    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
119    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
120
121 #ifdef EXP_PSY
122    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
123    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
124 #endif
125
126    return st;
127 }
128
129 void celt_encoder_destroy(CELTEncoder *st)
130 {
131    if (st == NULL)
132    {
133       celt_warning("NULL passed to celt_encoder_destroy");
134       return;
135    }
136    if (check_mode(st->mode) != CELT_OK)
137       return;
138
139    celt_free(st->in_mem);
140    celt_free(st->out_mem);
141    
142    celt_free(st->oldBandE);
143    
144    celt_free(st->preemph_memE);
145    celt_free(st->preemph_memD);
146    
147 #ifdef EXP_PSY
148    celt_free (st->psy_mem);
149    psydecay_clear(&st->psy);
150 #endif
151    
152    celt_free(st);
153 }
154
155 static inline celt_int16_t FLOAT2INT16(float x)
156 {
157    x = x*32768.;
158    x = MAX32(x, -32768);
159    x = MIN32(x, 32767);
160    return (celt_int16_t)float2int(x);
161 }
162
163 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
164 {
165 #ifdef FIXED_POINT
166    x = PSHR32(x, SIG_SHIFT);
167    x = MAX32(x, -32768);
168    x = MIN32(x, 32767);
169    return EXTRACT16(x);
170 #else
171    return (celt_word16_t)x;
172 #endif
173 }
174
175 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
176 {
177    int c, i, n;
178    celt_word32_t ratio;
179    /* FIXME: Remove the floats here */
180    VARDECL(celt_word32_t, begin);
181    SAVE_STACK;
182    ALLOC(begin, len, celt_word32_t);
183    for (i=0;i<len;i++)
184       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
185    for (c=1;c<C;c++)
186    {
187       for (i=0;i<len;i++)
188          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
189    }
190    for (i=1;i<len;i++)
191       begin[i] = MAX32(begin[i-1],begin[i]);
192    n = -1;
193    for (i=8;i<len-8;i++)
194    {
195       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
196          n=i;
197    }
198    if (n<32)
199    {
200       n = -1;
201       ratio = 0;
202    } else {
203       ratio = DIV32(begin[len-1],1+begin[n-16]);
204    }
205    /*printf ("%d %f\n", n, ratio*ratio);*/
206    if (ratio < 0)
207       ratio = 0;
208    if (ratio > 1000)
209       ratio = 1000;
210    ratio *= ratio;
211    if (ratio < 50)
212       *transient_shift = 0;
213    else if (ratio < 256)
214       *transient_shift = 1;
215    else if (ratio < 4096)
216       *transient_shift = 2;
217    else
218       *transient_shift = 3;
219    *transient_time = n;
220    
221    RESTORE_STACK;
222    return ratio > 20;
223 }
224
225 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
226 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
227 {
228    const int C = CHANNELS(mode);
229    if (C==1 && !shortBlocks)
230    {
231       const mdct_lookup *lookup = MDCT(mode);
232       const int overlap = OVERLAP(mode);
233       mdct_forward(lookup, in, out, mode->window, overlap);
234    } else if (!shortBlocks) {
235       const mdct_lookup *lookup = MDCT(mode);
236       const int overlap = OVERLAP(mode);
237       const int N = FRAMESIZE(mode);
238       int c;
239       VARDECL(celt_word32_t, x);
240       VARDECL(celt_word32_t, tmp);
241       SAVE_STACK;
242       ALLOC(x, N+overlap, celt_word32_t);
243       ALLOC(tmp, N, celt_word32_t);
244       for (c=0;c<C;c++)
245       {
246          int j;
247          for (j=0;j<N+overlap;j++)
248             x[j] = in[C*j+c];
249          mdct_forward(lookup, x, tmp, mode->window, overlap);
250          /* Interleaving the sub-frames */
251          for (j=0;j<N;j++)
252             out[C*j+c] = tmp[j];
253       }
254       RESTORE_STACK;
255    } else {
256       const mdct_lookup *lookup = &mode->shortMdct;
257       const int overlap = mode->overlap;
258       const int N = mode->shortMdctSize;
259       int b, c;
260       VARDECL(celt_word32_t, x);
261       VARDECL(celt_word32_t, tmp);
262       SAVE_STACK;
263       ALLOC(x, N+overlap, celt_word32_t);
264       ALLOC(tmp, N, celt_word32_t);
265       for (c=0;c<C;c++)
266       {
267          int B = mode->nbShortMdcts;
268          for (b=0;b<B;b++)
269          {
270             int j;
271             for (j=0;j<N+overlap;j++)
272                x[j] = in[C*(b*N+j)+c];
273             mdct_forward(lookup, x, tmp, mode->window, overlap);
274             /* Interleaving the sub-frames */
275             for (j=0;j<N;j++)
276                out[C*(j*B+b)+c] = tmp[j];
277          }
278       }
279       RESTORE_STACK;
280    }
281 }
282
283 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
284 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
285 {
286    int c, N4;
287    const int C = CHANNELS(mode);
288    const int N = FRAMESIZE(mode);
289    const int overlap = OVERLAP(mode);
290    N4 = (N-overlap)>>1;
291    for (c=0;c<C;c++)
292    {
293       int j;
294       if (transient_shift==0 && C==1 && !shortBlocks) {
295          const mdct_lookup *lookup = MDCT(mode);
296          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
297       } else if (!shortBlocks) {
298          const mdct_lookup *lookup = MDCT(mode);
299          VARDECL(celt_word32_t, x);
300          VARDECL(celt_word32_t, tmp);
301          SAVE_STACK;
302          ALLOC(x, 2*N, celt_word32_t);
303          ALLOC(tmp, N, celt_word32_t);
304          /* De-interleaving the sub-frames */
305          for (j=0;j<N;j++)
306             tmp[j] = X[C*j+c];
307          /* Prevents problems from the imdct doing the overlap-add */
308          CELT_MEMSET(x+N4, 0, N);
309          mdct_backward(lookup, tmp, x, mode->window, overlap);
310          celt_assert(transient_shift == 0);
311          /* The first and last part would need to be set to zero if we actually
312             wanted to use them. */
313          for (j=0;j<overlap;j++)
314             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
315          for (j=0;j<overlap;j++)
316             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
317          for (j=0;j<2*N4;j++)
318             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
319          RESTORE_STACK;
320       } else {
321          int b;
322          const int N2 = mode->shortMdctSize;
323          const int B = mode->nbShortMdcts;
324          const mdct_lookup *lookup = &mode->shortMdct;
325          VARDECL(celt_word32_t, x);
326          VARDECL(celt_word32_t, tmp);
327          SAVE_STACK;
328          ALLOC(x, 2*N, celt_word32_t);
329          ALLOC(tmp, N, celt_word32_t);
330          /* Prevents problems from the imdct doing the overlap-add */
331          CELT_MEMSET(x+N4, 0, N2);
332          for (b=0;b<B;b++)
333          {
334             /* De-interleaving the sub-frames */
335             for (j=0;j<N2;j++)
336                tmp[j] = X[C*(j*B+b)+c];
337             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
338          }
339          if (transient_shift > 0)
340          {
341 #ifdef FIXED_POINT
342             for (j=0;j<16;j++)
343                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
344             for (j=transient_time;j<N+overlap;j++)
345                x[N4+j] = SHL32(x[N4+j], transient_shift);
346 #else
347             for (j=0;j<16;j++)
348                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
349             for (j=transient_time;j<N+overlap;j++)
350                x[N4+j] *= 1<<transient_shift;
351 #endif
352          }
353          /* The first and last part would need to be set to zero if we actually
354          wanted to use them. */
355          for (j=0;j<overlap;j++)
356             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
357          for (j=0;j<overlap;j++)
358             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
359          for (j=0;j<2*N4;j++)
360             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
361          RESTORE_STACK;
362       }
363    }
364 }
365
366 #ifdef FIXED_POINT
367 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
368 {
369 #else
370 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
371 {
372 #endif
373    int i, c, N, N4;
374    int has_pitch;
375    int pitch_index;
376    int bits;
377    int has_fold=1;
378    ec_byte_buffer buf;
379    ec_enc         enc;
380    celt_word32_t curr_power, pitch_power=0;
381    VARDECL(celt_sig_t, in);
382    VARDECL(celt_sig_t, freq);
383    VARDECL(celt_norm_t, X);
384    VARDECL(celt_norm_t, P);
385    VARDECL(celt_ener_t, bandE);
386    VARDECL(celt_pgain_t, gains);
387    VARDECL(int, stereo_mode);
388    VARDECL(int, fine_quant);
389    VARDECL(celt_word16_t, error);
390    VARDECL(int, pulses);
391    VARDECL(int, offsets);
392 #ifdef EXP_PSY
393    VARDECL(celt_word32_t, mask);
394    VARDECL(celt_word32_t, tonality);
395    VARDECL(celt_word32_t, bandM);
396    VARDECL(celt_ener_t, bandN);
397 #endif
398    int shortBlocks=0;
399    int transient_time;
400    int transient_shift;
401    const int C = CHANNELS(st->mode);
402    SAVE_STACK;
403
404    if (check_mode(st->mode) != CELT_OK)
405       return CELT_INVALID_MODE;
406
407    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
408    CELT_MEMSET(compressed, 0, nbCompressedBytes);
409    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
410    ec_enc_init(&enc,&buf);
411
412    N = st->block_size;
413    N4 = (N-st->overlap)>>1;
414    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
415
416    CELT_COPY(in, st->in_mem, C*st->overlap);
417    for (c=0;c<C;c++)
418    {
419       const celt_word16_t * restrict pcmp = pcm+c;
420       celt_sig_t * restrict inp = in+C*st->overlap+c;
421       for (i=0;i<N;i++)
422       {
423          /* Apply pre-emphasis */
424          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
425          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
426          st->preemph_memE[c] = SCALEIN(*pcmp);
427          inp += C;
428          pcmp += C;
429       }
430    }
431    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
432    
433    if (st->mode->nbShortMdcts > 1)
434    {
435       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
436       {
437 #ifndef FIXED_POINT
438          float gain_1;
439 #endif
440          ec_enc_bits(&enc, 0, 1); //Pitch off
441          ec_enc_bits(&enc, 1, 1); //Transient on
442          ec_enc_bits(&enc, transient_shift, 2);
443          if (transient_shift)
444             ec_enc_uint(&enc, transient_time, N+st->overlap);
445          if (transient_shift)
446          {
447 #ifdef FIXED_POINT
448             for (c=0;c<C;c++)
449                for (i=0;i<16;i++)
450                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
451             for (c=0;c<C;c++)
452                for (i=transient_time;i<N+st->overlap;i++)
453                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
454 #else
455             for (c=0;c<C;c++)
456                for (i=0;i<16;i++)
457                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
458             gain_1 = 1./(1<<transient_shift);
459             for (c=0;c<C;c++)
460                for (i=transient_time;i<N+st->overlap;i++)
461                   in[C*i+c] *= gain_1;
462 #endif
463          }
464          shortBlocks = 1;
465       } else {
466          transient_time = -1;
467          transient_shift = 0;
468          shortBlocks = 0;
469       }
470    } else {
471       transient_time = -1;
472       transient_shift = 0;
473       shortBlocks = 0;
474    }
475
476    /* Pitch analysis: we do it early to save on the peak stack space */
477 #ifdef EXP_PSY
478    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
479    {
480       VARDECL(celt_word16_t, X);
481       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
482       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
483       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
484    }
485 #else
486    if (st->pitch_enabled && !shortBlocks)
487    {
488       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
489    }
490 #endif
491    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
492    
493    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
494    /* Compute MDCTs */
495    compute_mdcts(st->mode, shortBlocks, in, freq);
496
497 #ifdef EXP_PSY
498    /*CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
499    for (i=0;i<N;i++)
500       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
501    for (c=1;c<C;c++)
502       for (i=0;i<N;i++)
503          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
504    */
505    ALLOC(mask, N, celt_sig_t);
506    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
507    /*for (i=0;i<256;i++)
508       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
509    printf ("\n");*/
510
511    /* Invert and stretch the mask to length of X 
512       For some reason, I get better results by using the sqrt instead,
513       although there's no valid reason to. Must investigate further */
514    /*for (i=0;i<C*N;i++)
515       mask[i] = 1/(.1+mask[i]);*/
516 #endif
517    
518    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
519    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
520    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
521    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
522    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
523
524    /*printf ("%f %f\n", curr_power, pitch_power);*/
525    /*int j;
526    for (j=0;j<B*N;j++)
527       printf ("%f ", X[j]);
528    for (j=0;j<B*N;j++)
529       printf ("%f ", P[j]);
530    printf ("\n");*/
531
532    /* Band normalisation */
533    compute_band_energies(st->mode, freq, bandE);
534 #ifdef EXP_PSY
535    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
536    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
537    compute_noise_energies(st->mode, freq, tonality, bandN);
538
539    /*for (i=0;i<st->mode->nbEBands;i++)
540       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
541    printf ("\n");*/
542    has_fold = 0;
543    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
544       if (bandN[i] < .4*bandE[i])
545          has_fold++;
546    /*printf ("%d\n", has_fold);*/
547    if (has_fold>=2)
548       has_fold = 0;
549    else
550       has_fold = 1;
551    for (i=0;i<N;i++)
552       mask[i] = sqrt(mask[i]);
553    compute_band_energies(st->mode, mask, bandM);
554    /*for (i=0;i<st->mode->nbEBands;i++)
555       printf ("%f %f ", bandE[i], bandM[i]);
556    printf ("\n");*/
557 #endif
558    normalise_bands(st->mode, freq, X, bandE);
559    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
560    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
561
562    /* Compute MDCTs of the pitch part */
563    if (st->pitch_enabled && !shortBlocks)
564    {
565       /* Normalise the pitch vector as well (discard the energies) */
566       VARDECL(celt_ener_t, bandEp);
567       
568       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
569       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
570       compute_band_energies(st->mode, freq, bandEp);
571       normalise_bands(st->mode, freq, P, bandEp);
572       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
573    }
574    curr_power = bandE[0]+bandE[1]+bandE[2];
575    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
576    if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
577    {
578       int id;
579       /* Simulates intensity stereo */
580       /*for (i=30;i<N*B;i++)
581          X[i*C+1] = P[i*C+1] = 0;*/
582
583       /* Pitch prediction */
584       compute_pitch_gain(st->mode, X, P, gains);
585       id = quant_pitch(gains, st->mode->nbPBands, &enc);
586       if (id != -1)
587          has_pitch = 1;
588       else
589          has_pitch = 0;
590       ec_enc_bits(&enc, has_pitch, 1); /* Pitch flag */
591       if (has_pitch)
592       {
593          ec_enc_bits(&enc, has_fold, 1); /* Folding flag */
594          ec_enc_bits(&enc, id, 7);
595          ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
596       } else if (st->mode->nbShortMdcts > 1) {
597          ec_enc_bits(&enc, 0, 1); /* Transient off */
598          has_fold = 1;
599       }
600    } else {
601       if (!shortBlocks)
602       {
603          ec_enc_bits(&enc, 0, 1); /* Pitch off */
604          if (st->mode->nbShortMdcts > 1)
605            ec_enc_bits(&enc, 0, 1); /* Transient off */
606       }
607       has_fold = 1;
608       /* No pitch, so we just pretend we found a gain of zero */
609       for (i=0;i<st->mode->nbPBands;i++)
610          gains[i] = 0;
611       for (i=0;i<C*N;i++)
612          P[i] = 0;
613    }
614
615 #ifdef STDIN_TUNING2
616    static int fine_quant[30];
617    static int pulses[30];
618    static int init=0;
619    if (!init)
620    {
621       for (i=0;i<st->mode->nbEBands;i++)
622          scanf("%d ", &fine_quant[i]);
623       for (i=0;i<st->mode->nbEBands;i++)
624          scanf("%d ", &pulses[i]);
625       init = 1;
626    }
627 #else
628    ALLOC(fine_quant, st->mode->nbEBands, int);
629    ALLOC(pulses, st->mode->nbEBands, int);
630 #endif
631    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
632    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &enc);
633    
634    ALLOC(offsets, st->mode->nbEBands, int);
635    ALLOC(stereo_mode, st->mode->nbEBands, int);
636    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
637
638    for (i=0;i<st->mode->nbEBands;i++)
639       offsets[i] = 0;
640    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
641 #ifndef STDIN_TUNING
642    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
643 #endif
644    /*for (i=0;i<st->mode->nbEBands;i++)
645       printf("%d ", fine_quant[i]);
646    for (i=0;i<st->mode->nbEBands;i++)
647       printf("%d ", pulses[i]);
648    printf ("\n");*/
649    /*bits = ec_enc_tell(&st->enc, 0);
650    compute_fine_allocation(st->mode, fine_quant, (20*C+nbCompressedBytes*8/5-(ec_enc_tell(&st->enc, 0)-bits))/C);*/
651    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
652
653    pitch_quant_bands(st->mode, P, gains);
654
655    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
656
657    /* Residual quantisation */
658    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
659    
660    if (st->pitch_enabled || optional_synthesis!=NULL)
661    {
662       if (C==2)
663          renormalise_bands(st->mode, X);
664       /* Synthesis */
665       denormalise_bands(st->mode, X, freq, bandE);
666       
667       
668       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
669       
670       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
671       /* De-emphasis and put everything back at the right place in the synthesis history */
672       if (optional_synthesis != NULL) {
673          for (c=0;c<C;c++)
674          {
675             int j;
676             for (j=0;j<N;j++)
677             {
678                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
679                                    preemph,st->preemph_memD[c]);
680                st->preemph_memD[c] = tmp;
681                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
682             }
683          }
684       }
685    }
686    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
687    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
688       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
689    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
690    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
691    {
692       int val = 0;
693       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
694       {
695          ec_enc_uint(&enc, val, 2);
696          val = 1-val;
697       }
698    }
699    ec_enc_done(&enc);
700    {
701       /*unsigned char *data;*/
702       int nbBytes = ec_byte_bytes(&buf);
703       if (nbBytes > nbCompressedBytes)
704       {
705          celt_warning_int ("got too many bytes:", nbBytes);
706          RESTORE_STACK;
707          return CELT_INTERNAL_ERROR;
708       }
709       /*printf ("%d\n", *nbBytes);*/
710       /*data = ec_byte_get_buffer(&buf);
711       for (i=0;i<nbBytes;i++)
712          compressed[i] = data[i];
713       for (i=nbBytes;i<nbCompressedBytes;i++)
714          compressed[i] = 0;*/
715    }
716
717    RESTORE_STACK;
718    return nbCompressedBytes;
719 }
720
721 #ifdef FIXED_POINT
722 #ifndef DISABLE_FLOAT_API
723 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
724 {
725    int j, ret;
726    const int C = CHANNELS(st->mode);
727    const int N = st->block_size;
728    VARDECL(celt_int16_t, in);
729    SAVE_STACK;
730    ALLOC(in, C*N, celt_int16_t);
731
732    for (j=0;j<C*N;j++)
733      in[j] = FLOAT2INT16(pcm[j]);
734
735    if (optional_synthesis != NULL) {
736      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
737    /*Converts backwards for inplace operation*/
738       for (j=0;j=C*N;j++)
739          optional_synthesis[j]=in[j]*(1/32768.);
740    } else {
741      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
742    }
743    RESTORE_STACK;
744    return ret;
745
746 }
747 #endif /*DISABLE_FLOAT_API*/
748 #else
749 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
750 {
751    int j, ret;
752    VARDECL(celt_sig_t, in);
753    const int C = CHANNELS(st->mode);
754    const int N = st->block_size;
755    SAVE_STACK;
756    ALLOC(in, C*N, celt_sig_t);
757    for (j=0;j<C*N;j++) {
758      in[j] = SCALEOUT(pcm[j]);
759    }
760
761    if (optional_synthesis != NULL) {
762       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
763       for (j=0;j<C*N;j++)
764          optional_synthesis[j] = FLOAT2INT16(in[j]);
765    } else {
766       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
767    }
768    RESTORE_STACK;
769    return ret;
770 }
771 #endif
772
773 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
774 {
775    va_list ap;
776    va_start(ap, request);
777    switch (request)
778    {
779       case CELT_SET_COMPLEXITY_REQUEST:
780       {
781          int value = va_arg(ap, int);
782          if (value<0 || value>10)
783             goto bad_arg;
784          if (value<=2)
785             st->pitch_enabled = 0;
786          else
787             st->pitch_enabled = 1;
788       }
789       break;
790       default:
791          goto bad_request;
792    }
793    va_end(ap);
794    return CELT_OK;
795 bad_arg:
796    va_end(ap);
797    return CELT_BAD_ARG;
798 bad_request:
799    va_end(ap);
800    return CELT_UNIMPLEMENTED;
801 }
802
803 /****************************************************************************/
804 /*                                                                          */
805 /*                                DECODER                                   */
806 /*                                                                          */
807 /****************************************************************************/
808
809
810 /** Decoder state 
811  @brief Decoder state
812  */
813 struct CELTDecoder {
814    const CELTMode *mode;
815    int frame_size;
816    int block_size;
817    int overlap;
818
819    ec_byte_buffer buf;
820    ec_enc         enc;
821
822    celt_sig_t * restrict preemph_memD;
823
824    celt_sig_t *out_mem;
825
826    celt_word16_t *oldBandE;
827    
828    int last_pitch_index;
829 };
830
831 CELTDecoder *celt_decoder_create(const CELTMode *mode)
832 {
833    int N, C;
834    CELTDecoder *st;
835
836    if (check_mode(mode) != CELT_OK)
837       return NULL;
838
839    N = mode->mdctSize;
840    C = CHANNELS(mode);
841    st = celt_alloc(sizeof(CELTDecoder));
842    
843    st->mode = mode;
844    st->frame_size = N;
845    st->block_size = N;
846    st->overlap = mode->overlap;
847
848    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
849    
850    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
851
852    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
853
854    st->last_pitch_index = 0;
855    return st;
856 }
857
858 void celt_decoder_destroy(CELTDecoder *st)
859 {
860    if (st == NULL)
861    {
862       celt_warning("NULL passed to celt_encoder_destroy");
863       return;
864    }
865    if (check_mode(st->mode) != CELT_OK)
866       return;
867
868
869    celt_free(st->out_mem);
870    
871    celt_free(st->oldBandE);
872    
873    celt_free(st->preemph_memD);
874
875    celt_free(st);
876 }
877
878 /** Handles lost packets by just copying past data with the same offset as the last
879     pitch period */
880 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
881 {
882    int c, N;
883    int pitch_index;
884    int i, len;
885    VARDECL(celt_sig_t, freq);
886    const int C = CHANNELS(st->mode);
887    int offset;
888    SAVE_STACK;
889    N = st->block_size;
890    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
891    
892    len = N+st->mode->overlap;
893 #if 0
894    pitch_index = st->last_pitch_index;
895    
896    /* Use the pitch MDCT as the "guessed" signal */
897    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
898
899 #else
900    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
901    pitch_index = MAX_PERIOD-len-pitch_index;
902    offset = MAX_PERIOD-pitch_index;
903    while (offset+len >= MAX_PERIOD)
904       offset -= pitch_index;
905    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
906    for (i=0;i<N;i++)
907       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
908 #endif
909    
910    
911    
912    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
913    /* Compute inverse MDCTs */
914    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
915
916    for (c=0;c<C;c++)
917    {
918       int j;
919       for (j=0;j<N;j++)
920       {
921          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
922                                 preemph,st->preemph_memD[c]);
923          st->preemph_memD[c] = tmp;
924          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
925       }
926    }
927    RESTORE_STACK;
928 }
929
930 #ifdef FIXED_POINT
931 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
932 {
933 #else
934 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
935 {
936 #endif
937    int i, c, N, N4;
938    int has_pitch, has_fold;
939    int pitch_index;
940    int bits;
941    ec_dec dec;
942    ec_byte_buffer buf;
943    VARDECL(celt_sig_t, freq);
944    VARDECL(celt_norm_t, X);
945    VARDECL(celt_norm_t, P);
946    VARDECL(celt_ener_t, bandE);
947    VARDECL(celt_pgain_t, gains);
948    VARDECL(int, stereo_mode);
949    VARDECL(int, fine_quant);
950    VARDECL(int, pulses);
951    VARDECL(int, offsets);
952
953    int shortBlocks;
954    int transient_time;
955    int transient_shift;
956    const int C = CHANNELS(st->mode);
957    SAVE_STACK;
958
959    if (check_mode(st->mode) != CELT_OK)
960       return CELT_INVALID_MODE;
961
962    N = st->block_size;
963    N4 = (N-st->overlap)>>1;
964
965    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
966    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
967    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
968    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
969    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
970    
971    if (check_mode(st->mode) != CELT_OK)
972    {
973       RESTORE_STACK;
974       return CELT_INVALID_MODE;
975    }
976    if (data == NULL)
977    {
978       celt_decode_lost(st, pcm);
979       RESTORE_STACK;
980       return 0;
981    }
982    
983    ec_byte_readinit(&buf,data,len);
984    ec_dec_init(&dec,&buf);
985    
986    has_pitch = ec_dec_bits(&dec, 1);
987    if (has_pitch)
988    {
989       has_fold = ec_dec_bits(&dec, 1);
990       shortBlocks = 0;
991    } else if (st->mode->nbShortMdcts > 1){
992       shortBlocks = ec_dec_bits(&dec, 1);
993       has_fold = 1;
994    } else {
995       shortBlocks = 0;
996       has_fold = 1;
997    }
998    if (shortBlocks)
999    {
1000       transient_shift = ec_dec_bits(&dec, 2);
1001       if (transient_shift)
1002          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
1003       else
1004          transient_time = 0;
1005    } else {
1006       transient_time = -1;
1007       transient_shift = 0;
1008    }
1009    /* Get the pitch gains */
1010    
1011    /* Get the pitch index */
1012    if (has_pitch)
1013    {
1014       has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
1015       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
1016       st->last_pitch_index = pitch_index;
1017    } else {
1018       /* FIXME: We could be more intelligent here and just not compute the MDCT */
1019       pitch_index = 0;
1020       for (i=0;i<st->mode->nbPBands;i++)
1021          gains[i] = 0;
1022    }
1023
1024    ALLOC(fine_quant, st->mode->nbEBands, int);
1025    /* Get band energies */
1026    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
1027    
1028    ALLOC(pulses, st->mode->nbEBands, int);
1029    ALLOC(offsets, st->mode->nbEBands, int);
1030    ALLOC(stereo_mode, st->mode->nbEBands, int);
1031    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
1032
1033    for (i=0;i<st->mode->nbEBands;i++)
1034       offsets[i] = 0;
1035
1036    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1037    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1038    /*bits = ec_dec_tell(&dec, 0);
1039    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1040    
1041    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1042
1043
1044    if (has_pitch) 
1045    {
1046       VARDECL(celt_ener_t, bandEp);
1047       
1048       /* Pitch MDCT */
1049       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1050       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1051       compute_band_energies(st->mode, freq, bandEp);
1052       normalise_bands(st->mode, freq, P, bandEp);
1053    } else {
1054       for (i=0;i<C*N;i++)
1055          P[i] = 0;
1056    }
1057
1058    /* Apply pitch gains */
1059    pitch_quant_bands(st->mode, P, gains);
1060
1061    /* Decode fixed codebook and merge with pitch */
1062    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, has_fold, len*8, &dec);
1063
1064    if (C==2)
1065    {
1066       renormalise_bands(st->mode, X);
1067    }
1068    /* Synthesis */
1069    denormalise_bands(st->mode, X, freq, bandE);
1070
1071
1072    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
1073    /* Compute inverse MDCTs */
1074    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1075
1076    for (c=0;c<C;c++)
1077    {
1078       int j;
1079       for (j=0;j<N;j++)
1080       {
1081          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1082                                 preemph,st->preemph_memD[c]);
1083          st->preemph_memD[c] = tmp;
1084          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1085       }
1086    }
1087
1088    {
1089       unsigned int val = 0;
1090       while (ec_dec_tell(&dec, 0) < len*8)
1091       {
1092          if (ec_dec_uint(&dec, 2) != val)
1093          {
1094             celt_warning("decode error");
1095             RESTORE_STACK;
1096             return CELT_CORRUPTED_DATA;
1097          }
1098          val = 1-val;
1099       }
1100    }
1101
1102    RESTORE_STACK;
1103    return 0;
1104    /*printf ("\n");*/
1105 }
1106
1107 #ifdef FIXED_POINT
1108 #ifndef DISABLE_FLOAT_API
1109 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1110 {
1111    int j, ret;
1112    const int C = CHANNELS(st->mode);
1113    const int N = st->block_size;
1114    VARDECL(celt_int16_t, out);
1115    SAVE_STACK;
1116    ALLOC(out, C*N, celt_int16_t);
1117
1118    ret=celt_decode(st, data, len, out);
1119
1120    for (j=0;j<C*N;j++)
1121      pcm[j]=out[j]*(1/32768.);
1122    RESTORE_STACK;
1123    return ret;
1124 }
1125 #endif /*DISABLE_FLOAT_API*/
1126 #else
1127 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1128 {
1129    int j, ret;
1130    VARDECL(celt_sig_t, out);
1131    const int C = CHANNELS(st->mode);
1132    const int N = st->block_size;
1133    SAVE_STACK;
1134    ALLOC(out, C*N, celt_sig_t);
1135
1136    ret=celt_decode_float(st, data, len, out);
1137
1138    for (j=0;j<C*N;j++)
1139      pcm[j] = FLOAT2INT16 (out[j]);
1140
1141    RESTORE_STACK;
1142    return ret;
1143 }
1144 #endif