celt_encoder_ctl() is a bit more type-safe.
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53 #include "float_cast.h"
54 #include <stdarg.h>
55
56 static const celt_word16_t preemph = QCONST16(0.8f,15);
57
58 #ifdef FIXED_POINT
59 static const celt_word16_t transientWindow[16] = {
60      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
61    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
62 #else
63 static const float transientWindow[16] = {
64    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
65    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
66 #endif
67
68    
69 /** Encoder state 
70  @brief Encoder state
71  */
72 struct CELTEncoder {
73    const CELTMode *mode;     /**< Mode used by the encoder */
74    int frame_size;
75    int block_size;
76    int overlap;
77    int channels;
78    
79    int pitch_enabled;
80    
81    ec_byte_buffer buf;
82    ec_enc         enc;
83
84    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
85    celt_sig_t    * restrict preemph_memD;
86
87    celt_sig_t *in_mem;
88    celt_sig_t *out_mem;
89
90    celt_word16_t *oldBandE;
91 #ifdef EXP_PSY
92    celt_word16_t *psy_mem;
93    struct PsyDecay psy;
94 #endif
95 };
96
97 CELTEncoder *celt_encoder_create(const CELTMode *mode)
98 {
99    int N, C;
100    CELTEncoder *st;
101
102    if (check_mode(mode) != CELT_OK)
103       return NULL;
104
105    N = mode->mdctSize;
106    C = mode->nbChannels;
107    st = celt_alloc(sizeof(CELTEncoder));
108    
109    st->mode = mode;
110    st->frame_size = N;
111    st->block_size = N;
112    st->overlap = mode->overlap;
113
114    st->pitch_enabled = 1;
115    
116    ec_byte_writeinit(&st->buf);
117    ec_enc_init(&st->enc,&st->buf);
118
119    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
120    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
121
122    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
123
124    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
125    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
126
127 #ifdef EXP_PSY
128    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
129    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
130 #endif
131
132    return st;
133 }
134
135 void celt_encoder_destroy(CELTEncoder *st)
136 {
137    if (st == NULL)
138    {
139       celt_warning("NULL passed to celt_encoder_destroy");
140       return;
141    }
142    if (check_mode(st->mode) != CELT_OK)
143       return;
144
145    ec_byte_writeclear(&st->buf);
146
147    celt_free(st->in_mem);
148    celt_free(st->out_mem);
149    
150    celt_free(st->oldBandE);
151    
152    celt_free(st->preemph_memE);
153    celt_free(st->preemph_memD);
154    
155 #ifdef EXP_PSY
156    celt_free (st->psy_mem);
157    psydecay_clear(&st->psy);
158 #endif
159    
160    celt_free(st);
161 }
162
163 static inline celt_int16_t FLOAT2INT16(float x)
164 {
165    x = x*32768.;
166    x = MAX32(x, -32768);
167    x = MIN32(x, 32767);
168    return (celt_int16_t)float2int(x);
169 }
170
171 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
172 {
173 #ifdef FIXED_POINT
174    x = PSHR32(x, SIG_SHIFT);
175    x = MAX32(x, -32768);
176    x = MIN32(x, 32767);
177    return EXTRACT16(x);
178 #else
179    return (celt_word16_t)x;
180 #endif
181 }
182
183 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
184 {
185    int c, i, n;
186    celt_word32_t ratio;
187    /* FIXME: Remove the floats here */
188    VARDECL(celt_word32_t, begin);
189    SAVE_STACK;
190    ALLOC(begin, len, celt_word32_t);
191    for (i=0;i<len;i++)
192       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
193    for (c=1;c<C;c++)
194    {
195       for (i=0;i<len;i++)
196          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
197    }
198    for (i=1;i<len;i++)
199       begin[i] = MAX32(begin[i-1],begin[i]);
200    n = -1;
201    for (i=8;i<len-8;i++)
202    {
203       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
204          n=i;
205    }
206    if (n<32)
207    {
208       n = -1;
209       ratio = 0;
210    } else {
211       ratio = DIV32(begin[len-1],1+begin[n-16]);
212    }
213    /*printf ("%d %f\n", n, ratio*ratio);*/
214    if (ratio < 0)
215       ratio = 0;
216    if (ratio > 1000)
217       ratio = 1000;
218    ratio *= ratio;
219    if (ratio < 50)
220       *transient_shift = 0;
221    else if (ratio < 256)
222       *transient_shift = 1;
223    else if (ratio < 4096)
224       *transient_shift = 2;
225    else
226       *transient_shift = 3;
227    *transient_time = n;
228    
229    RESTORE_STACK;
230    return ratio > 20;
231 }
232
233 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
234 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
235 {
236    const int C = CHANNELS(mode);
237    if (C==1 && !shortBlocks)
238    {
239       const mdct_lookup *lookup = MDCT(mode);
240       const int overlap = OVERLAP(mode);
241       mdct_forward(lookup, in, out, mode->window, overlap);
242    } else if (!shortBlocks) {
243       const mdct_lookup *lookup = MDCT(mode);
244       const int overlap = OVERLAP(mode);
245       const int N = FRAMESIZE(mode);
246       int c;
247       VARDECL(celt_word32_t, x);
248       VARDECL(celt_word32_t, tmp);
249       SAVE_STACK;
250       ALLOC(x, N+overlap, celt_word32_t);
251       ALLOC(tmp, N, celt_word32_t);
252       for (c=0;c<C;c++)
253       {
254          int j;
255          for (j=0;j<N+overlap;j++)
256             x[j] = in[C*j+c];
257          mdct_forward(lookup, x, tmp, mode->window, overlap);
258          /* Interleaving the sub-frames */
259          for (j=0;j<N;j++)
260             out[C*j+c] = tmp[j];
261       }
262       RESTORE_STACK;
263    } else {
264       const mdct_lookup *lookup = &mode->shortMdct;
265       const int overlap = mode->overlap;
266       const int N = mode->shortMdctSize;
267       int b, c;
268       VARDECL(celt_word32_t, x);
269       VARDECL(celt_word32_t, tmp);
270       SAVE_STACK;
271       ALLOC(x, N+overlap, celt_word32_t);
272       ALLOC(tmp, N, celt_word32_t);
273       for (c=0;c<C;c++)
274       {
275          int B = mode->nbShortMdcts;
276          for (b=0;b<B;b++)
277          {
278             int j;
279             for (j=0;j<N+overlap;j++)
280                x[j] = in[C*(b*N+j)+c];
281             mdct_forward(lookup, x, tmp, mode->window, overlap);
282             /* Interleaving the sub-frames */
283             for (j=0;j<N;j++)
284                out[C*(j*B+b)+c] = tmp[j];
285          }
286       }
287       RESTORE_STACK;
288    }
289 }
290
291 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
292 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
293 {
294    int c, N4;
295    const int C = CHANNELS(mode);
296    const int N = FRAMESIZE(mode);
297    const int overlap = OVERLAP(mode);
298    N4 = (N-overlap)>>1;
299    for (c=0;c<C;c++)
300    {
301       int j;
302       if (transient_shift==0 && C==1 && !shortBlocks) {
303          const mdct_lookup *lookup = MDCT(mode);
304          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
305       } else if (!shortBlocks) {
306          const mdct_lookup *lookup = MDCT(mode);
307          VARDECL(celt_word32_t, x);
308          VARDECL(celt_word32_t, tmp);
309          SAVE_STACK;
310          ALLOC(x, 2*N, celt_word32_t);
311          ALLOC(tmp, N, celt_word32_t);
312          /* De-interleaving the sub-frames */
313          for (j=0;j<N;j++)
314             tmp[j] = X[C*j+c];
315          /* Prevents problems from the imdct doing the overlap-add */
316          CELT_MEMSET(x+N4, 0, N);
317          mdct_backward(lookup, tmp, x, mode->window, overlap);
318          celt_assert(transient_shift == 0);
319          /* The first and last part would need to be set to zero if we actually
320             wanted to use them. */
321          for (j=0;j<overlap;j++)
322             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
323          for (j=0;j<overlap;j++)
324             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
325          for (j=0;j<2*N4;j++)
326             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
327          RESTORE_STACK;
328       } else {
329          int b;
330          const int N2 = mode->shortMdctSize;
331          const int B = mode->nbShortMdcts;
332          const mdct_lookup *lookup = &mode->shortMdct;
333          VARDECL(celt_word32_t, x);
334          VARDECL(celt_word32_t, tmp);
335          SAVE_STACK;
336          ALLOC(x, 2*N, celt_word32_t);
337          ALLOC(tmp, N, celt_word32_t);
338          /* Prevents problems from the imdct doing the overlap-add */
339          CELT_MEMSET(x+N4, 0, N2);
340          for (b=0;b<B;b++)
341          {
342             /* De-interleaving the sub-frames */
343             for (j=0;j<N2;j++)
344                tmp[j] = X[C*(j*B+b)+c];
345             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
346          }
347          if (transient_shift > 0)
348          {
349 #ifdef FIXED_POINT
350             for (j=0;j<16;j++)
351                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
352             for (j=transient_time;j<N+overlap;j++)
353                x[N4+j] = SHL32(x[N4+j], transient_shift);
354 #else
355             for (j=0;j<16;j++)
356                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
357             for (j=transient_time;j<N+overlap;j++)
358                x[N4+j] *= 1<<transient_shift;
359 #endif
360          }
361          /* The first and last part would need to be set to zero if we actually
362          wanted to use them. */
363          for (j=0;j<overlap;j++)
364             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
365          for (j=0;j<overlap;j++)
366             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
367          for (j=0;j<2*N4;j++)
368             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
369          RESTORE_STACK;
370       }
371    }
372 }
373
374 #ifdef FIXED_POINT
375 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
376 {
377 #else
378 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
379 {
380 #endif
381    int i, c, N, N4;
382    int has_pitch;
383    int pitch_index;
384    int bits;
385    celt_word32_t curr_power, pitch_power=0;
386    VARDECL(celt_sig_t, in);
387    VARDECL(celt_sig_t, freq);
388    VARDECL(celt_norm_t, X);
389    VARDECL(celt_norm_t, P);
390    VARDECL(celt_ener_t, bandE);
391    VARDECL(celt_pgain_t, gains);
392    VARDECL(int, stereo_mode);
393    VARDECL(int, fine_quant);
394    VARDECL(celt_word16_t, error);
395    VARDECL(int, pulses);
396    VARDECL(int, offsets);
397 #ifdef EXP_PSY
398    VARDECL(celt_word32_t, mask);
399 #endif
400    int shortBlocks=0;
401    int transient_time;
402    int transient_shift;
403    const int C = CHANNELS(st->mode);
404    SAVE_STACK;
405
406    if (check_mode(st->mode) != CELT_OK)
407       return CELT_INVALID_MODE;
408
409    N = st->block_size;
410    N4 = (N-st->overlap)>>1;
411    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
412
413    CELT_COPY(in, st->in_mem, C*st->overlap);
414    for (c=0;c<C;c++)
415    {
416       const celt_word16_t * restrict pcmp = pcm+c;
417       celt_sig_t * restrict inp = in+C*st->overlap+c;
418       for (i=0;i<N;i++)
419       {
420          /* Apply pre-emphasis */
421          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
422          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
423          st->preemph_memE[c] = SCALEIN(*pcmp);
424          inp += C;
425          pcmp += C;
426       }
427    }
428    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
429    
430    if (st->mode->nbShortMdcts > 1)
431    {
432       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
433       {
434 #ifndef FIXED_POINT
435          float gain_1;
436 #endif
437          ec_enc_bits(&st->enc, 0, 1); //Pitch off
438          ec_enc_bits(&st->enc, 1, 1); //Transient on
439          ec_enc_bits(&st->enc, transient_shift, 2);
440          if (transient_shift)
441             ec_enc_uint(&st->enc, transient_time, N+st->overlap);
442          if (transient_shift)
443          {
444 #ifdef FIXED_POINT
445             for (c=0;c<C;c++)
446                for (i=0;i<16;i++)
447                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
448             for (c=0;c<C;c++)
449                for (i=transient_time;i<N+st->overlap;i++)
450                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
451 #else
452             for (c=0;c<C;c++)
453                for (i=0;i<16;i++)
454                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
455             gain_1 = 1./(1<<transient_shift);
456             for (c=0;c<C;c++)
457                for (i=transient_time;i<N+st->overlap;i++)
458                   in[C*i+c] *= gain_1;
459 #endif
460          }
461          shortBlocks = 1;
462       } else {
463          transient_time = -1;
464          transient_shift = 0;
465          shortBlocks = 0;
466       }
467    } else {
468       transient_time = -1;
469       transient_shift = 0;
470       shortBlocks = 0;
471    }
472    /* Pitch analysis: we do it early to save on the peak stack space */
473    if (st->pitch_enabled && !shortBlocks)
474       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
475
476    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
477    
478    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
479    /* Compute MDCTs */
480    compute_mdcts(st->mode, shortBlocks, in, freq);
481
482 #ifdef EXP_PSY
483    CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
484    for (i=0;i<N;i++)
485       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
486    for (c=1;c<C;c++)
487       for (i=0;i<N;i++)
488          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
489
490    ALLOC(mask, N, celt_sig_t);
491    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
492
493    /* Invert and stretch the mask to length of X 
494       For some reason, I get better results by using the sqrt instead,
495       although there's no valid reason to. Must investigate further */
496    for (i=0;i<C*N;i++)
497       mask[i] = 1/(.1+mask[i]);
498 #endif
499    
500    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
501    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
502    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
503    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
504    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
505
506    /*printf ("%f %f\n", curr_power, pitch_power);*/
507    /*int j;
508    for (j=0;j<B*N;j++)
509       printf ("%f ", X[j]);
510    for (j=0;j<B*N;j++)
511       printf ("%f ", P[j]);
512    printf ("\n");*/
513
514    /* Band normalisation */
515    compute_band_energies(st->mode, freq, bandE);
516    normalise_bands(st->mode, freq, X, bandE);
517    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
518    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
519
520    /* Compute MDCTs of the pitch part */
521    if (st->pitch_enabled && !shortBlocks)
522    {
523       /* Normalise the pitch vector as well (discard the energies) */
524       VARDECL(celt_ener_t, bandEp);
525       
526       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
527       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
528       compute_band_energies(st->mode, freq, bandEp);
529       normalise_bands(st->mode, freq, P, bandEp);
530       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
531    }
532    curr_power = bandE[0]+bandE[1]+bandE[2];
533    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
534    if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
535    {
536       /* Simulates intensity stereo */
537       /*for (i=30;i<N*B;i++)
538          X[i*C+1] = P[i*C+1] = 0;*/
539
540       /* Pitch prediction */
541       compute_pitch_gain(st->mode, X, P, gains);
542       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
543       if (has_pitch)
544          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
545       else if (st->mode->nbShortMdcts > 1)
546          ec_enc_bits(&st->enc, 0, 1); //Transient off
547    } else {
548       if (!shortBlocks)
549       {
550          ec_enc_bits(&st->enc, 0, 1); //Pitch off
551          if (st->mode->nbShortMdcts > 1)
552            ec_enc_bits(&st->enc, 0, 1); //Transient off
553       }
554       /* No pitch, so we just pretend we found a gain of zero */
555       for (i=0;i<st->mode->nbPBands;i++)
556          gains[i] = 0;
557       for (i=0;i<C*N;i++)
558          P[i] = 0;
559    }
560
561 #ifdef STDIN_TUNING2
562    static int fine_quant[30];
563    static int pulses[30];
564    static int init=0;
565    if (!init)
566    {
567       for (i=0;i<st->mode->nbEBands;i++)
568          scanf("%d ", &fine_quant[i]);
569       for (i=0;i<st->mode->nbEBands;i++)
570          scanf("%d ", &pulses[i]);
571       init = 1;
572    }
573 #else
574    ALLOC(fine_quant, st->mode->nbEBands, int);
575    ALLOC(pulses, st->mode->nbEBands, int);
576 #endif
577    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
578    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &st->enc);
579    
580    ALLOC(offsets, st->mode->nbEBands, int);
581    ALLOC(stereo_mode, st->mode->nbEBands, int);
582    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
583
584    for (i=0;i<st->mode->nbEBands;i++)
585       offsets[i] = 0;
586    bits = nbCompressedBytes*8 - ec_enc_tell(&st->enc, 0) - 1;
587 #ifndef STDIN_TUNING
588    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
589 #endif
590    /*for (i=0;i<st->mode->nbEBands;i++)
591       printf("%d ", fine_quant[i]);
592    for (i=0;i<st->mode->nbEBands;i++)
593       printf("%d ", pulses[i]);
594    printf ("\n");*/
595    /*bits = ec_enc_tell(&st->enc, 0);
596    compute_fine_allocation(st->mode, fine_quant, (20*C+nbCompressedBytes*8/5-(ec_enc_tell(&st->enc, 0)-bits))/C);*/
597    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &st->enc);
598
599    pitch_quant_bands(st->mode, P, gains);
600
601    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
602
603    /* Residual quantisation */
604    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, nbCompressedBytes*8, &st->enc);
605    
606    if (st->pitch_enabled || optional_synthesis!=NULL)
607    {
608       if (C==2)
609          renormalise_bands(st->mode, X);
610       /* Synthesis */
611       denormalise_bands(st->mode, X, freq, bandE);
612       
613       
614       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
615       
616       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
617       /* De-emphasis and put everything back at the right place in the synthesis history */
618       if (optional_synthesis != NULL) {
619          for (c=0;c<C;c++)
620          {
621             int j;
622             for (j=0;j<N;j++)
623             {
624                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
625                                    preemph,st->preemph_memD[c]);
626                st->preemph_memD[c] = tmp;
627                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
628             }
629          }
630       }
631    }
632    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
633    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
634       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
635    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
636    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
637    {
638       int val = 0;
639       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
640       {
641          ec_enc_uint(&st->enc, val, 2);
642          val = 1-val;
643       }
644    }
645    ec_enc_done(&st->enc);
646    {
647       unsigned char *data;
648       int nbBytes = ec_byte_bytes(&st->buf);
649       if (nbBytes > nbCompressedBytes)
650       {
651          celt_warning_int ("got too many bytes:", nbBytes);
652          RESTORE_STACK;
653          return CELT_INTERNAL_ERROR;
654       }
655       /*printf ("%d\n", *nbBytes);*/
656       data = ec_byte_get_buffer(&st->buf);
657       for (i=0;i<nbBytes;i++)
658          compressed[i] = data[i];
659       for (;i<nbCompressedBytes;i++)
660          compressed[i] = 0;
661    }
662    /* Reset the packing for the next encoding */
663    ec_byte_reset(&st->buf);
664    ec_enc_init(&st->enc,&st->buf);
665
666    RESTORE_STACK;
667    return nbCompressedBytes;
668 }
669
670 #ifdef FIXED_POINT
671 #ifndef DISABLE_FLOAT_API
672 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
673 {
674    int j, ret;
675    const int C = CHANNELS(st->mode);
676    const int N = st->block_size;
677    VARDECL(celt_int16_t, in);
678    SAVE_STACK;
679    ALLOC(in, C*N, celt_int16_t);
680
681    for (j=0;j<C*N;j++)
682      in[j] = FLOAT2INT16(pcm[j]);
683
684    if (optional_synthesis != NULL) {
685      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
686    /*Converts backwards for inplace operation*/
687       for (j=0;j=C*N;j++)
688          optional_synthesis[j]=in[j]*(1/32768.);
689    } else {
690      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
691    }
692    RESTORE_STACK;
693    return ret;
694
695 }
696 #endif /*DISABLE_FLOAT_API*/
697 #else
698 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
699 {
700    int j, ret;
701    VARDECL(celt_sig_t, in);
702    const int C = CHANNELS(st->mode);
703    const int N = st->block_size;
704    SAVE_STACK;
705    ALLOC(in, C*N, celt_sig_t);
706    for (j=0;j<C*N;j++) {
707      in[j] = SCALEOUT(pcm[j]);
708    }
709
710    if (optional_synthesis != NULL) {
711       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
712       for (j=0;j<C*N;j++)
713          optional_synthesis[j] = FLOAT2INT16(in[j]);
714    } else {
715       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
716    }
717    RESTORE_STACK;
718    return ret;
719 }
720 #endif
721
722 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
723 {
724    va_list ap;
725    va_start(ap, request);
726    switch (request)
727    {
728       case CELT_SET_COMPLEXITY_REQUEST:
729       {
730          int value = va_arg(ap, int);
731          if (value<0 || value>10)
732             goto bad_arg;
733          if (value<=2)
734             st->pitch_enabled = 0;
735          else
736             st->pitch_enabled = 1;
737       }
738       break;
739       default:
740          goto bad_request;
741    }
742    va_end(ap);
743    return CELT_OK;
744 bad_arg:
745    va_end(ap);
746    return CELT_BAD_ARG;
747 bad_request:
748    va_end(ap);
749    return CELT_UNIMPLEMENTED;
750 }
751
752 /****************************************************************************/
753 /*                                                                          */
754 /*                                DECODER                                   */
755 /*                                                                          */
756 /****************************************************************************/
757
758
759 /** Decoder state 
760  @brief Decoder state
761  */
762 struct CELTDecoder {
763    const CELTMode *mode;
764    int frame_size;
765    int block_size;
766    int overlap;
767
768    ec_byte_buffer buf;
769    ec_enc         enc;
770
771    celt_sig_t * restrict preemph_memD;
772
773    celt_sig_t *out_mem;
774
775    celt_word16_t *oldBandE;
776    
777    int last_pitch_index;
778 };
779
780 CELTDecoder *celt_decoder_create(const CELTMode *mode)
781 {
782    int N, C;
783    CELTDecoder *st;
784
785    if (check_mode(mode) != CELT_OK)
786       return NULL;
787
788    N = mode->mdctSize;
789    C = CHANNELS(mode);
790    st = celt_alloc(sizeof(CELTDecoder));
791    
792    st->mode = mode;
793    st->frame_size = N;
794    st->block_size = N;
795    st->overlap = mode->overlap;
796
797    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
798    
799    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
800
801    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
802
803    st->last_pitch_index = 0;
804    return st;
805 }
806
807 void celt_decoder_destroy(CELTDecoder *st)
808 {
809    if (st == NULL)
810    {
811       celt_warning("NULL passed to celt_encoder_destroy");
812       return;
813    }
814    if (check_mode(st->mode) != CELT_OK)
815       return;
816
817
818    celt_free(st->out_mem);
819    
820    celt_free(st->oldBandE);
821    
822    celt_free(st->preemph_memD);
823
824    celt_free(st);
825 }
826
827 /** Handles lost packets by just copying past data with the same offset as the last
828     pitch period */
829 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
830 {
831    int c, N;
832    int pitch_index;
833    int i, len;
834    VARDECL(celt_sig_t, freq);
835    const int C = CHANNELS(st->mode);
836    int offset;
837    SAVE_STACK;
838    N = st->block_size;
839    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
840    
841    len = N+st->mode->overlap;
842 #if 0
843    pitch_index = st->last_pitch_index;
844    
845    /* Use the pitch MDCT as the "guessed" signal */
846    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
847
848 #else
849    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
850    pitch_index = MAX_PERIOD-len-pitch_index;
851    offset = MAX_PERIOD-pitch_index;
852    while (offset+len >= MAX_PERIOD)
853       offset -= pitch_index;
854    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
855    for (i=0;i<N;i++)
856       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
857 #endif
858    
859    
860    
861    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
862    /* Compute inverse MDCTs */
863    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
864
865    for (c=0;c<C;c++)
866    {
867       int j;
868       for (j=0;j<N;j++)
869       {
870          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
871                                 preemph,st->preemph_memD[c]);
872          st->preemph_memD[c] = tmp;
873          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
874       }
875    }
876    RESTORE_STACK;
877 }
878
879 #ifdef FIXED_POINT
880 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
881 {
882 #else
883 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
884 {
885 #endif
886    int i, c, N, N4;
887    int has_pitch, has_fold;
888    int pitch_index;
889    int bits;
890    ec_dec dec;
891    ec_byte_buffer buf;
892    VARDECL(celt_sig_t, freq);
893    VARDECL(celt_norm_t, X);
894    VARDECL(celt_norm_t, P);
895    VARDECL(celt_ener_t, bandE);
896    VARDECL(celt_pgain_t, gains);
897    VARDECL(int, stereo_mode);
898    VARDECL(int, fine_quant);
899    VARDECL(int, pulses);
900    VARDECL(int, offsets);
901
902    int shortBlocks;
903    int transient_time;
904    int transient_shift;
905    const int C = CHANNELS(st->mode);
906    SAVE_STACK;
907
908    if (check_mode(st->mode) != CELT_OK)
909       return CELT_INVALID_MODE;
910
911    N = st->block_size;
912    N4 = (N-st->overlap)>>1;
913
914    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
915    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
916    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
917    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
918    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
919    
920    if (check_mode(st->mode) != CELT_OK)
921    {
922       RESTORE_STACK;
923       return CELT_INVALID_MODE;
924    }
925    if (data == NULL)
926    {
927       celt_decode_lost(st, pcm);
928       RESTORE_STACK;
929       return 0;
930    }
931    
932    ec_byte_readinit(&buf,data,len);
933    ec_dec_init(&dec,&buf);
934    
935    has_pitch = ec_dec_bits(&dec, 1);
936    if (has_pitch)
937    {
938       has_fold = ec_dec_bits(&dec, 1);
939       shortBlocks = 0;
940    } else if (st->mode->nbShortMdcts > 1){
941       shortBlocks = ec_dec_bits(&dec, 1);
942       has_fold = 1;
943    } else {
944       shortBlocks = 0;
945       has_fold = 1;
946    }
947    if (shortBlocks)
948    {
949       transient_shift = ec_dec_bits(&dec, 2);
950       if (transient_shift)
951          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
952       else
953          transient_time = 0;
954    } else {
955       transient_time = -1;
956       transient_shift = 0;
957    }
958    /* Get the pitch gains */
959    
960    /* Get the pitch index */
961    if (has_pitch)
962    {
963       has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
964       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
965       st->last_pitch_index = pitch_index;
966    } else {
967       /* FIXME: We could be more intelligent here and just not compute the MDCT */
968       pitch_index = 0;
969       for (i=0;i<st->mode->nbPBands;i++)
970          gains[i] = 0;
971    }
972
973    ALLOC(fine_quant, st->mode->nbEBands, int);
974    /* Get band energies */
975    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
976    
977    ALLOC(pulses, st->mode->nbEBands, int);
978    ALLOC(offsets, st->mode->nbEBands, int);
979    ALLOC(stereo_mode, st->mode->nbEBands, int);
980    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
981
982    for (i=0;i<st->mode->nbEBands;i++)
983       offsets[i] = 0;
984
985    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
986    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
987    /*bits = ec_dec_tell(&dec, 0);
988    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
989    
990    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
991
992
993    if (has_pitch) 
994    {
995       VARDECL(celt_ener_t, bandEp);
996       
997       /* Pitch MDCT */
998       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
999       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1000       compute_band_energies(st->mode, freq, bandEp);
1001       normalise_bands(st->mode, freq, P, bandEp);
1002    } else {
1003       for (i=0;i<C*N;i++)
1004          P[i] = 0;
1005    }
1006
1007    /* Apply pitch gains */
1008    pitch_quant_bands(st->mode, P, gains);
1009
1010    /* Decode fixed codebook and merge with pitch */
1011    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, len*8, &dec);
1012
1013    if (C==2)
1014    {
1015       renormalise_bands(st->mode, X);
1016    }
1017    /* Synthesis */
1018    denormalise_bands(st->mode, X, freq, bandE);
1019
1020
1021    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
1022    /* Compute inverse MDCTs */
1023    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1024
1025    for (c=0;c<C;c++)
1026    {
1027       int j;
1028       for (j=0;j<N;j++)
1029       {
1030          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1031                                 preemph,st->preemph_memD[c]);
1032          st->preemph_memD[c] = tmp;
1033          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1034       }
1035    }
1036
1037    {
1038       unsigned int val = 0;
1039       while (ec_dec_tell(&dec, 0) < len*8)
1040       {
1041          if (ec_dec_uint(&dec, 2) != val)
1042          {
1043             celt_warning("decode error");
1044             RESTORE_STACK;
1045             return CELT_CORRUPTED_DATA;
1046          }
1047          val = 1-val;
1048       }
1049    }
1050
1051    RESTORE_STACK;
1052    return 0;
1053    /*printf ("\n");*/
1054 }
1055
1056 #ifdef FIXED_POINT
1057 #ifndef DISABLE_FLOAT_API
1058 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1059 {
1060    int j, ret;
1061    const int C = CHANNELS(st->mode);
1062    const int N = st->block_size;
1063    VARDECL(celt_int16_t, out);
1064    SAVE_STACK;
1065    ALLOC(out, C*N, celt_int16_t);
1066
1067    ret=celt_decode(st, data, len, out);
1068
1069    for (j=0;j<C*N;j++)
1070      pcm[j]=out[j]*(1/32768.);
1071    RESTORE_STACK;
1072    return ret;
1073 }
1074 #endif /*DISABLE_FLOAT_API*/
1075 #else
1076 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1077 {
1078    int j, ret;
1079    VARDECL(celt_sig_t, out);
1080    const int C = CHANNELS(st->mode);
1081    const int N = st->block_size;
1082    SAVE_STACK;
1083    ALLOC(out, C*N, celt_sig_t);
1084
1085    ret=celt_decode_float(st, data, len, out);
1086
1087    for (j=0;j<C*N;j++)
1088      pcm[j] = FLOAT2INT16 (out[j]);
1089
1090    RESTORE_STACK;
1091    return ret;
1092 }
1093 #endif