2bdb7a703f70c79aa793a100122f294320667c00
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53 #include "float_cast.h"
54 #include <stdarg.h>
55
56 static const celt_word16_t preemph = QCONST16(0.8f,15);
57
58 #ifdef FIXED_POINT
59 static const celt_word16_t transientWindow[16] = {
60      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
61    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
62 #else
63 static const float transientWindow[16] = {
64    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
65    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
66 #endif
67
68    
69 /** Encoder state 
70  @brief Encoder state
71  */
72 struct CELTEncoder {
73    const CELTMode *mode;     /**< Mode used by the encoder */
74    int frame_size;
75    int block_size;
76    int overlap;
77    int channels;
78    
79    int pitch_enabled;
80
81    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
82    celt_sig_t    * restrict preemph_memD;
83
84    celt_sig_t *in_mem;
85    celt_sig_t *out_mem;
86
87    celt_word16_t *oldBandE;
88 #ifdef EXP_PSY
89    celt_word16_t *psy_mem;
90    struct PsyDecay psy;
91 #endif
92 };
93
94 CELTEncoder *celt_encoder_create(const CELTMode *mode)
95 {
96    int N, C;
97    CELTEncoder *st;
98
99    if (check_mode(mode) != CELT_OK)
100       return NULL;
101
102    N = mode->mdctSize;
103    C = mode->nbChannels;
104    st = celt_alloc(sizeof(CELTEncoder));
105    
106    st->mode = mode;
107    st->frame_size = N;
108    st->block_size = N;
109    st->overlap = mode->overlap;
110
111    st->pitch_enabled = 1;
112
113    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
114    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
115
116    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
117
118    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
119    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
120
121 #ifdef EXP_PSY
122    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
123    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
124 #endif
125
126    return st;
127 }
128
129 void celt_encoder_destroy(CELTEncoder *st)
130 {
131    if (st == NULL)
132    {
133       celt_warning("NULL passed to celt_encoder_destroy");
134       return;
135    }
136    if (check_mode(st->mode) != CELT_OK)
137       return;
138
139    celt_free(st->in_mem);
140    celt_free(st->out_mem);
141    
142    celt_free(st->oldBandE);
143    
144    celt_free(st->preemph_memE);
145    celt_free(st->preemph_memD);
146    
147 #ifdef EXP_PSY
148    celt_free (st->psy_mem);
149    psydecay_clear(&st->psy);
150 #endif
151    
152    celt_free(st);
153 }
154
155 static inline celt_int16_t FLOAT2INT16(float x)
156 {
157    x = x*32768.;
158    x = MAX32(x, -32768);
159    x = MIN32(x, 32767);
160    return (celt_int16_t)float2int(x);
161 }
162
163 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
164 {
165 #ifdef FIXED_POINT
166    x = PSHR32(x, SIG_SHIFT);
167    x = MAX32(x, -32768);
168    x = MIN32(x, 32767);
169    return EXTRACT16(x);
170 #else
171    return (celt_word16_t)x;
172 #endif
173 }
174
175 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
176 {
177    int c, i, n;
178    celt_word32_t ratio;
179    /* FIXME: Remove the floats here */
180    VARDECL(celt_word32_t, begin);
181    SAVE_STACK;
182    ALLOC(begin, len, celt_word32_t);
183    for (i=0;i<len;i++)
184       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
185    for (c=1;c<C;c++)
186    {
187       for (i=0;i<len;i++)
188          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
189    }
190    for (i=1;i<len;i++)
191       begin[i] = MAX32(begin[i-1],begin[i]);
192    n = -1;
193    for (i=8;i<len-8;i++)
194    {
195       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
196          n=i;
197    }
198    if (n<32)
199    {
200       n = -1;
201       ratio = 0;
202    } else {
203       ratio = DIV32(begin[len-1],1+begin[n-16]);
204    }
205    /*printf ("%d %f\n", n, ratio*ratio);*/
206    if (ratio < 0)
207       ratio = 0;
208    if (ratio > 1000)
209       ratio = 1000;
210    ratio *= ratio;
211    if (ratio < 50)
212       *transient_shift = 0;
213    else if (ratio < 256)
214       *transient_shift = 1;
215    else if (ratio < 4096)
216       *transient_shift = 2;
217    else
218       *transient_shift = 3;
219    *transient_time = n;
220    
221    RESTORE_STACK;
222    return ratio > 20;
223 }
224
225 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
226 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
227 {
228    const int C = CHANNELS(mode);
229    if (C==1 && !shortBlocks)
230    {
231       const mdct_lookup *lookup = MDCT(mode);
232       const int overlap = OVERLAP(mode);
233       mdct_forward(lookup, in, out, mode->window, overlap);
234    } else if (!shortBlocks) {
235       const mdct_lookup *lookup = MDCT(mode);
236       const int overlap = OVERLAP(mode);
237       const int N = FRAMESIZE(mode);
238       int c;
239       VARDECL(celt_word32_t, x);
240       VARDECL(celt_word32_t, tmp);
241       SAVE_STACK;
242       ALLOC(x, N+overlap, celt_word32_t);
243       ALLOC(tmp, N, celt_word32_t);
244       for (c=0;c<C;c++)
245       {
246          int j;
247          for (j=0;j<N+overlap;j++)
248             x[j] = in[C*j+c];
249          mdct_forward(lookup, x, tmp, mode->window, overlap);
250          /* Interleaving the sub-frames */
251          for (j=0;j<N;j++)
252             out[C*j+c] = tmp[j];
253       }
254       RESTORE_STACK;
255    } else {
256       const mdct_lookup *lookup = &mode->shortMdct;
257       const int overlap = mode->overlap;
258       const int N = mode->shortMdctSize;
259       int b, c;
260       VARDECL(celt_word32_t, x);
261       VARDECL(celt_word32_t, tmp);
262       SAVE_STACK;
263       ALLOC(x, N+overlap, celt_word32_t);
264       ALLOC(tmp, N, celt_word32_t);
265       for (c=0;c<C;c++)
266       {
267          int B = mode->nbShortMdcts;
268          for (b=0;b<B;b++)
269          {
270             int j;
271             for (j=0;j<N+overlap;j++)
272                x[j] = in[C*(b*N+j)+c];
273             mdct_forward(lookup, x, tmp, mode->window, overlap);
274             /* Interleaving the sub-frames */
275             for (j=0;j<N;j++)
276                out[C*(j*B+b)+c] = tmp[j];
277          }
278       }
279       RESTORE_STACK;
280    }
281 }
282
283 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
284 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
285 {
286    int c, N4;
287    const int C = CHANNELS(mode);
288    const int N = FRAMESIZE(mode);
289    const int overlap = OVERLAP(mode);
290    N4 = (N-overlap)>>1;
291    for (c=0;c<C;c++)
292    {
293       int j;
294       if (transient_shift==0 && C==1 && !shortBlocks) {
295          const mdct_lookup *lookup = MDCT(mode);
296          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
297       } else if (!shortBlocks) {
298          const mdct_lookup *lookup = MDCT(mode);
299          VARDECL(celt_word32_t, x);
300          VARDECL(celt_word32_t, tmp);
301          SAVE_STACK;
302          ALLOC(x, 2*N, celt_word32_t);
303          ALLOC(tmp, N, celt_word32_t);
304          /* De-interleaving the sub-frames */
305          for (j=0;j<N;j++)
306             tmp[j] = X[C*j+c];
307          /* Prevents problems from the imdct doing the overlap-add */
308          CELT_MEMSET(x+N4, 0, N);
309          mdct_backward(lookup, tmp, x, mode->window, overlap);
310          celt_assert(transient_shift == 0);
311          /* The first and last part would need to be set to zero if we actually
312             wanted to use them. */
313          for (j=0;j<overlap;j++)
314             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
315          for (j=0;j<overlap;j++)
316             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
317          for (j=0;j<2*N4;j++)
318             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
319          RESTORE_STACK;
320       } else {
321          int b;
322          const int N2 = mode->shortMdctSize;
323          const int B = mode->nbShortMdcts;
324          const mdct_lookup *lookup = &mode->shortMdct;
325          VARDECL(celt_word32_t, x);
326          VARDECL(celt_word32_t, tmp);
327          SAVE_STACK;
328          ALLOC(x, 2*N, celt_word32_t);
329          ALLOC(tmp, N, celt_word32_t);
330          /* Prevents problems from the imdct doing the overlap-add */
331          CELT_MEMSET(x+N4, 0, N2);
332          for (b=0;b<B;b++)
333          {
334             /* De-interleaving the sub-frames */
335             for (j=0;j<N2;j++)
336                tmp[j] = X[C*(j*B+b)+c];
337             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
338          }
339          if (transient_shift > 0)
340          {
341 #ifdef FIXED_POINT
342             for (j=0;j<16;j++)
343                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
344             for (j=transient_time;j<N+overlap;j++)
345                x[N4+j] = SHL32(x[N4+j], transient_shift);
346 #else
347             for (j=0;j<16;j++)
348                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
349             for (j=transient_time;j<N+overlap;j++)
350                x[N4+j] *= 1<<transient_shift;
351 #endif
352          }
353          /* The first and last part would need to be set to zero if we actually
354          wanted to use them. */
355          for (j=0;j<overlap;j++)
356             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
357          for (j=0;j<overlap;j++)
358             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
359          for (j=0;j<2*N4;j++)
360             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
361          RESTORE_STACK;
362       }
363    }
364 }
365
366 #ifdef FIXED_POINT
367 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
368 {
369 #else
370 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
371 {
372 #endif
373    int i, c, N, N4;
374    int has_pitch;
375    int pitch_index;
376    int bits;
377    int has_fold;
378    ec_byte_buffer buf;
379    ec_enc         enc;
380    celt_word32_t curr_power, pitch_power=0;
381    VARDECL(celt_sig_t, in);
382    VARDECL(celt_sig_t, freq);
383    VARDECL(celt_norm_t, X);
384    VARDECL(celt_norm_t, P);
385    VARDECL(celt_ener_t, bandE);
386    VARDECL(celt_pgain_t, gains);
387    VARDECL(int, stereo_mode);
388    VARDECL(int, fine_quant);
389    VARDECL(celt_word16_t, error);
390    VARDECL(int, pulses);
391    VARDECL(int, offsets);
392 #ifdef EXP_PSY
393    VARDECL(celt_word32_t, mask);
394    VARDECL(celt_word32_t, tonality);
395    VARDECL(celt_word32_t, bandM);
396    VARDECL(celt_ener_t, bandN);
397 #endif
398    int shortBlocks=0;
399    int transient_time;
400    int transient_shift;
401    const int C = CHANNELS(st->mode);
402    SAVE_STACK;
403
404    if (check_mode(st->mode) != CELT_OK)
405       return CELT_INVALID_MODE;
406
407    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
408    CELT_MEMSET(compressed, 0, nbCompressedBytes);
409    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
410    ec_enc_init(&enc,&buf);
411
412    N = st->block_size;
413    N4 = (N-st->overlap)>>1;
414    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
415
416    CELT_COPY(in, st->in_mem, C*st->overlap);
417    for (c=0;c<C;c++)
418    {
419       const celt_word16_t * restrict pcmp = pcm+c;
420       celt_sig_t * restrict inp = in+C*st->overlap+c;
421       for (i=0;i<N;i++)
422       {
423          /* Apply pre-emphasis */
424          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
425          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
426          st->preemph_memE[c] = SCALEIN(*pcmp);
427          inp += C;
428          pcmp += C;
429       }
430    }
431    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
432    
433    if (st->mode->nbShortMdcts > 1)
434    {
435       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
436       {
437 #ifndef FIXED_POINT
438          float gain_1;
439 #endif
440          ec_enc_bits(&enc, 0, 1); //Pitch off
441          ec_enc_bits(&enc, 1, 1); //Transient on
442          ec_enc_bits(&enc, transient_shift, 2);
443          if (transient_shift)
444             ec_enc_uint(&enc, transient_time, N+st->overlap);
445          if (transient_shift)
446          {
447 #ifdef FIXED_POINT
448             for (c=0;c<C;c++)
449                for (i=0;i<16;i++)
450                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
451             for (c=0;c<C;c++)
452                for (i=transient_time;i<N+st->overlap;i++)
453                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
454 #else
455             for (c=0;c<C;c++)
456                for (i=0;i<16;i++)
457                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
458             gain_1 = 1./(1<<transient_shift);
459             for (c=0;c<C;c++)
460                for (i=transient_time;i<N+st->overlap;i++)
461                   in[C*i+c] *= gain_1;
462 #endif
463          }
464          shortBlocks = 1;
465       } else {
466          transient_time = -1;
467          transient_shift = 0;
468          shortBlocks = 0;
469       }
470    } else {
471       transient_time = -1;
472       transient_shift = 0;
473       shortBlocks = 0;
474    }
475
476    /* Pitch analysis: we do it early to save on the peak stack space */
477 #ifdef EXP_PSY
478    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
479    {
480       VARDECL(celt_word16_t, X);
481       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
482       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
483       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
484    }
485 #else
486    if (st->pitch_enabled && !shortBlocks)
487    {
488       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
489    }
490 #endif
491    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
492    
493    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
494    /* Compute MDCTs */
495    compute_mdcts(st->mode, shortBlocks, in, freq);
496
497 #ifdef EXP_PSY
498    /*CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
499    for (i=0;i<N;i++)
500       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
501    for (c=1;c<C;c++)
502       for (i=0;i<N;i++)
503          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
504    */
505    ALLOC(mask, N, celt_sig_t);
506    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
507    /*for (i=0;i<256;i++)
508       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
509    printf ("\n");*/
510
511    /* Invert and stretch the mask to length of X 
512       For some reason, I get better results by using the sqrt instead,
513       although there's no valid reason to. Must investigate further */
514    /*for (i=0;i<C*N;i++)
515       mask[i] = 1/(.1+mask[i]);*/
516 #endif
517    
518    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
519    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
520    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
521    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
522    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
523
524    /*printf ("%f %f\n", curr_power, pitch_power);*/
525    /*int j;
526    for (j=0;j<B*N;j++)
527       printf ("%f ", X[j]);
528    for (j=0;j<B*N;j++)
529       printf ("%f ", P[j]);
530    printf ("\n");*/
531
532    /* Band normalisation */
533    compute_band_energies(st->mode, freq, bandE);
534 #ifdef EXP_PSY
535    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
536    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
537    compute_noise_energies(st->mode, freq, tonality, bandN);
538
539    /*for (i=0;i<st->mode->nbEBands;i++)
540       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
541    printf ("\n");*/
542    has_fold = 0;
543    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
544       if (bandN[i] < .4*bandE[i])
545          has_fold++;
546    /*printf ("%d\n", has_fold);*/
547    if (has_fold>=2)
548       has_fold = 1;
549    else
550       has_fold = 0;
551    for (i=0;i<N;i++)
552       mask[i] = sqrt(mask[i]);
553    compute_band_energies(st->mode, mask, bandM);
554    /*for (i=0;i<st->mode->nbEBands;i++)
555       printf ("%f %f ", bandE[i], bandM[i]);
556    printf ("\n");*/
557 #endif
558    normalise_bands(st->mode, freq, X, bandE);
559    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
560    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
561
562    /* Compute MDCTs of the pitch part */
563    if (st->pitch_enabled && !shortBlocks)
564    {
565       /* Normalise the pitch vector as well (discard the energies) */
566       VARDECL(celt_ener_t, bandEp);
567       
568       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
569       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
570       compute_band_energies(st->mode, freq, bandEp);
571       normalise_bands(st->mode, freq, P, bandEp);
572       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
573    }
574    curr_power = bandE[0]+bandE[1]+bandE[2];
575    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
576    if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
577    {
578       /* Simulates intensity stereo */
579       /*for (i=30;i<N*B;i++)
580          X[i*C+1] = P[i*C+1] = 0;*/
581
582       /* Pitch prediction */
583       compute_pitch_gain(st->mode, X, P, gains);
584       has_pitch = quant_pitch(gains, st->mode->nbPBands, &enc);
585       if (has_pitch)
586          ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
587       else if (st->mode->nbShortMdcts > 1)
588          ec_enc_bits(&enc, 0, 1); //Transient off
589    } else {
590       if (!shortBlocks)
591       {
592          ec_enc_bits(&enc, 0, 1); //Pitch off
593          if (st->mode->nbShortMdcts > 1)
594            ec_enc_bits(&enc, 0, 1); //Transient off
595       }
596       /* No pitch, so we just pretend we found a gain of zero */
597       for (i=0;i<st->mode->nbPBands;i++)
598          gains[i] = 0;
599       for (i=0;i<C*N;i++)
600          P[i] = 0;
601    }
602
603 #ifdef STDIN_TUNING2
604    static int fine_quant[30];
605    static int pulses[30];
606    static int init=0;
607    if (!init)
608    {
609       for (i=0;i<st->mode->nbEBands;i++)
610          scanf("%d ", &fine_quant[i]);
611       for (i=0;i<st->mode->nbEBands;i++)
612          scanf("%d ", &pulses[i]);
613       init = 1;
614    }
615 #else
616    ALLOC(fine_quant, st->mode->nbEBands, int);
617    ALLOC(pulses, st->mode->nbEBands, int);
618 #endif
619    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
620    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &enc);
621    
622    ALLOC(offsets, st->mode->nbEBands, int);
623    ALLOC(stereo_mode, st->mode->nbEBands, int);
624    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
625
626    for (i=0;i<st->mode->nbEBands;i++)
627       offsets[i] = 0;
628    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
629 #ifndef STDIN_TUNING
630    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
631 #endif
632    /*for (i=0;i<st->mode->nbEBands;i++)
633       printf("%d ", fine_quant[i]);
634    for (i=0;i<st->mode->nbEBands;i++)
635       printf("%d ", pulses[i]);
636    printf ("\n");*/
637    /*bits = ec_enc_tell(&st->enc, 0);
638    compute_fine_allocation(st->mode, fine_quant, (20*C+nbCompressedBytes*8/5-(ec_enc_tell(&st->enc, 0)-bits))/C);*/
639    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
640
641    pitch_quant_bands(st->mode, P, gains);
642
643    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
644
645    /* Residual quantisation */
646    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, nbCompressedBytes*8, &enc);
647    
648    if (st->pitch_enabled || optional_synthesis!=NULL)
649    {
650       if (C==2)
651          renormalise_bands(st->mode, X);
652       /* Synthesis */
653       denormalise_bands(st->mode, X, freq, bandE);
654       
655       
656       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
657       
658       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
659       /* De-emphasis and put everything back at the right place in the synthesis history */
660       if (optional_synthesis != NULL) {
661          for (c=0;c<C;c++)
662          {
663             int j;
664             for (j=0;j<N;j++)
665             {
666                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
667                                    preemph,st->preemph_memD[c]);
668                st->preemph_memD[c] = tmp;
669                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
670             }
671          }
672       }
673    }
674    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
675    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
676       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
677    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
678    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
679    {
680       int val = 0;
681       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
682       {
683          ec_enc_uint(&enc, val, 2);
684          val = 1-val;
685       }
686    }
687    ec_enc_done(&enc);
688    {
689       /*unsigned char *data;*/
690       int nbBytes = ec_byte_bytes(&buf);
691       if (nbBytes > nbCompressedBytes)
692       {
693          celt_warning_int ("got too many bytes:", nbBytes);
694          RESTORE_STACK;
695          return CELT_INTERNAL_ERROR;
696       }
697       /*printf ("%d\n", *nbBytes);*/
698       /*data = ec_byte_get_buffer(&buf);
699       for (i=0;i<nbBytes;i++)
700          compressed[i] = data[i];
701       for (i=nbBytes;i<nbCompressedBytes;i++)
702          compressed[i] = 0;*/
703    }
704
705    RESTORE_STACK;
706    return nbCompressedBytes;
707 }
708
709 #ifdef FIXED_POINT
710 #ifndef DISABLE_FLOAT_API
711 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
712 {
713    int j, ret;
714    const int C = CHANNELS(st->mode);
715    const int N = st->block_size;
716    VARDECL(celt_int16_t, in);
717    SAVE_STACK;
718    ALLOC(in, C*N, celt_int16_t);
719
720    for (j=0;j<C*N;j++)
721      in[j] = FLOAT2INT16(pcm[j]);
722
723    if (optional_synthesis != NULL) {
724      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
725    /*Converts backwards for inplace operation*/
726       for (j=0;j=C*N;j++)
727          optional_synthesis[j]=in[j]*(1/32768.);
728    } else {
729      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
730    }
731    RESTORE_STACK;
732    return ret;
733
734 }
735 #endif /*DISABLE_FLOAT_API*/
736 #else
737 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
738 {
739    int j, ret;
740    VARDECL(celt_sig_t, in);
741    const int C = CHANNELS(st->mode);
742    const int N = st->block_size;
743    SAVE_STACK;
744    ALLOC(in, C*N, celt_sig_t);
745    for (j=0;j<C*N;j++) {
746      in[j] = SCALEOUT(pcm[j]);
747    }
748
749    if (optional_synthesis != NULL) {
750       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
751       for (j=0;j<C*N;j++)
752          optional_synthesis[j] = FLOAT2INT16(in[j]);
753    } else {
754       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
755    }
756    RESTORE_STACK;
757    return ret;
758 }
759 #endif
760
761 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
762 {
763    va_list ap;
764    va_start(ap, request);
765    switch (request)
766    {
767       case CELT_SET_COMPLEXITY_REQUEST:
768       {
769          int value = va_arg(ap, int);
770          if (value<0 || value>10)
771             goto bad_arg;
772          if (value<=2)
773             st->pitch_enabled = 0;
774          else
775             st->pitch_enabled = 1;
776       }
777       break;
778       default:
779          goto bad_request;
780    }
781    va_end(ap);
782    return CELT_OK;
783 bad_arg:
784    va_end(ap);
785    return CELT_BAD_ARG;
786 bad_request:
787    va_end(ap);
788    return CELT_UNIMPLEMENTED;
789 }
790
791 /****************************************************************************/
792 /*                                                                          */
793 /*                                DECODER                                   */
794 /*                                                                          */
795 /****************************************************************************/
796
797
798 /** Decoder state 
799  @brief Decoder state
800  */
801 struct CELTDecoder {
802    const CELTMode *mode;
803    int frame_size;
804    int block_size;
805    int overlap;
806
807    ec_byte_buffer buf;
808    ec_enc         enc;
809
810    celt_sig_t * restrict preemph_memD;
811
812    celt_sig_t *out_mem;
813
814    celt_word16_t *oldBandE;
815    
816    int last_pitch_index;
817 };
818
819 CELTDecoder *celt_decoder_create(const CELTMode *mode)
820 {
821    int N, C;
822    CELTDecoder *st;
823
824    if (check_mode(mode) != CELT_OK)
825       return NULL;
826
827    N = mode->mdctSize;
828    C = CHANNELS(mode);
829    st = celt_alloc(sizeof(CELTDecoder));
830    
831    st->mode = mode;
832    st->frame_size = N;
833    st->block_size = N;
834    st->overlap = mode->overlap;
835
836    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
837    
838    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
839
840    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
841
842    st->last_pitch_index = 0;
843    return st;
844 }
845
846 void celt_decoder_destroy(CELTDecoder *st)
847 {
848    if (st == NULL)
849    {
850       celt_warning("NULL passed to celt_encoder_destroy");
851       return;
852    }
853    if (check_mode(st->mode) != CELT_OK)
854       return;
855
856
857    celt_free(st->out_mem);
858    
859    celt_free(st->oldBandE);
860    
861    celt_free(st->preemph_memD);
862
863    celt_free(st);
864 }
865
866 /** Handles lost packets by just copying past data with the same offset as the last
867     pitch period */
868 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
869 {
870    int c, N;
871    int pitch_index;
872    int i, len;
873    VARDECL(celt_sig_t, freq);
874    const int C = CHANNELS(st->mode);
875    int offset;
876    SAVE_STACK;
877    N = st->block_size;
878    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
879    
880    len = N+st->mode->overlap;
881 #if 0
882    pitch_index = st->last_pitch_index;
883    
884    /* Use the pitch MDCT as the "guessed" signal */
885    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
886
887 #else
888    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
889    pitch_index = MAX_PERIOD-len-pitch_index;
890    offset = MAX_PERIOD-pitch_index;
891    while (offset+len >= MAX_PERIOD)
892       offset -= pitch_index;
893    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
894    for (i=0;i<N;i++)
895       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
896 #endif
897    
898    
899    
900    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
901    /* Compute inverse MDCTs */
902    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
903
904    for (c=0;c<C;c++)
905    {
906       int j;
907       for (j=0;j<N;j++)
908       {
909          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
910                                 preemph,st->preemph_memD[c]);
911          st->preemph_memD[c] = tmp;
912          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
913       }
914    }
915    RESTORE_STACK;
916 }
917
918 #ifdef FIXED_POINT
919 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
920 {
921 #else
922 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
923 {
924 #endif
925    int i, c, N, N4;
926    int has_pitch, has_fold;
927    int pitch_index;
928    int bits;
929    ec_dec dec;
930    ec_byte_buffer buf;
931    VARDECL(celt_sig_t, freq);
932    VARDECL(celt_norm_t, X);
933    VARDECL(celt_norm_t, P);
934    VARDECL(celt_ener_t, bandE);
935    VARDECL(celt_pgain_t, gains);
936    VARDECL(int, stereo_mode);
937    VARDECL(int, fine_quant);
938    VARDECL(int, pulses);
939    VARDECL(int, offsets);
940
941    int shortBlocks;
942    int transient_time;
943    int transient_shift;
944    const int C = CHANNELS(st->mode);
945    SAVE_STACK;
946
947    if (check_mode(st->mode) != CELT_OK)
948       return CELT_INVALID_MODE;
949
950    N = st->block_size;
951    N4 = (N-st->overlap)>>1;
952
953    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
954    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
955    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
956    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
957    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
958    
959    if (check_mode(st->mode) != CELT_OK)
960    {
961       RESTORE_STACK;
962       return CELT_INVALID_MODE;
963    }
964    if (data == NULL)
965    {
966       celt_decode_lost(st, pcm);
967       RESTORE_STACK;
968       return 0;
969    }
970    
971    ec_byte_readinit(&buf,data,len);
972    ec_dec_init(&dec,&buf);
973    
974    has_pitch = ec_dec_bits(&dec, 1);
975    if (has_pitch)
976    {
977       has_fold = ec_dec_bits(&dec, 1);
978       shortBlocks = 0;
979    } else if (st->mode->nbShortMdcts > 1){
980       shortBlocks = ec_dec_bits(&dec, 1);
981       has_fold = 1;
982    } else {
983       shortBlocks = 0;
984       has_fold = 1;
985    }
986    if (shortBlocks)
987    {
988       transient_shift = ec_dec_bits(&dec, 2);
989       if (transient_shift)
990          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
991       else
992          transient_time = 0;
993    } else {
994       transient_time = -1;
995       transient_shift = 0;
996    }
997    /* Get the pitch gains */
998    
999    /* Get the pitch index */
1000    if (has_pitch)
1001    {
1002       has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
1003       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
1004       st->last_pitch_index = pitch_index;
1005    } else {
1006       /* FIXME: We could be more intelligent here and just not compute the MDCT */
1007       pitch_index = 0;
1008       for (i=0;i<st->mode->nbPBands;i++)
1009          gains[i] = 0;
1010    }
1011
1012    ALLOC(fine_quant, st->mode->nbEBands, int);
1013    /* Get band energies */
1014    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
1015    
1016    ALLOC(pulses, st->mode->nbEBands, int);
1017    ALLOC(offsets, st->mode->nbEBands, int);
1018    ALLOC(stereo_mode, st->mode->nbEBands, int);
1019    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
1020
1021    for (i=0;i<st->mode->nbEBands;i++)
1022       offsets[i] = 0;
1023
1024    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1025    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1026    /*bits = ec_dec_tell(&dec, 0);
1027    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1028    
1029    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1030
1031
1032    if (has_pitch) 
1033    {
1034       VARDECL(celt_ener_t, bandEp);
1035       
1036       /* Pitch MDCT */
1037       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1038       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1039       compute_band_energies(st->mode, freq, bandEp);
1040       normalise_bands(st->mode, freq, P, bandEp);
1041    } else {
1042       for (i=0;i<C*N;i++)
1043          P[i] = 0;
1044    }
1045
1046    /* Apply pitch gains */
1047    pitch_quant_bands(st->mode, P, gains);
1048
1049    /* Decode fixed codebook and merge with pitch */
1050    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, len*8, &dec);
1051
1052    if (C==2)
1053    {
1054       renormalise_bands(st->mode, X);
1055    }
1056    /* Synthesis */
1057    denormalise_bands(st->mode, X, freq, bandE);
1058
1059
1060    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
1061    /* Compute inverse MDCTs */
1062    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1063
1064    for (c=0;c<C;c++)
1065    {
1066       int j;
1067       for (j=0;j<N;j++)
1068       {
1069          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1070                                 preemph,st->preemph_memD[c]);
1071          st->preemph_memD[c] = tmp;
1072          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1073       }
1074    }
1075
1076    {
1077       unsigned int val = 0;
1078       while (ec_dec_tell(&dec, 0) < len*8)
1079       {
1080          if (ec_dec_uint(&dec, 2) != val)
1081          {
1082             celt_warning("decode error");
1083             RESTORE_STACK;
1084             return CELT_CORRUPTED_DATA;
1085          }
1086          val = 1-val;
1087       }
1088    }
1089
1090    RESTORE_STACK;
1091    return 0;
1092    /*printf ("\n");*/
1093 }
1094
1095 #ifdef FIXED_POINT
1096 #ifndef DISABLE_FLOAT_API
1097 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1098 {
1099    int j, ret;
1100    const int C = CHANNELS(st->mode);
1101    const int N = st->block_size;
1102    VARDECL(celt_int16_t, out);
1103    SAVE_STACK;
1104    ALLOC(out, C*N, celt_int16_t);
1105
1106    ret=celt_decode(st, data, len, out);
1107
1108    for (j=0;j<C*N;j++)
1109      pcm[j]=out[j]*(1/32768.);
1110    RESTORE_STACK;
1111    return ret;
1112 }
1113 #endif /*DISABLE_FLOAT_API*/
1114 #else
1115 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1116 {
1117    int j, ret;
1118    VARDECL(celt_sig_t, out);
1119    const int C = CHANNELS(st->mode);
1120    const int N = st->block_size;
1121    SAVE_STACK;
1122    ALLOC(out, C*N, celt_sig_t);
1123
1124    ret=celt_decode_float(st, data, len, out);
1125
1126    for (j=0;j<C*N;j++)
1127      pcm[j] = FLOAT2INT16 (out[j]);
1128
1129    RESTORE_STACK;
1130    return ret;
1131 }
1132 #endif