Fix optional synthesis bug in fixed point mode.
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53 #include "float_cast.h"
54 #include <stdarg.h>
55
56 static const celt_word16_t preemph = QCONST16(0.8f,15);
57
58 #ifdef FIXED_POINT
59 static const celt_word16_t transientWindow[16] = {
60      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
61    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
62 #else
63 static const float transientWindow[16] = {
64    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
65    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
66 #endif
67
68    
69 /** Encoder state 
70  @brief Encoder state
71  */
72 struct CELTEncoder {
73    const CELTMode *mode;     /**< Mode used by the encoder */
74    int frame_size;
75    int block_size;
76    int overlap;
77    int channels;
78    
79    int pitch_enabled;
80
81    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
82    celt_sig_t    * restrict preemph_memD;
83
84    celt_sig_t *in_mem;
85    celt_sig_t *out_mem;
86
87    celt_word16_t *oldBandE;
88 #ifdef EXP_PSY
89    celt_word16_t *psy_mem;
90    struct PsyDecay psy;
91 #endif
92 };
93
94 CELTEncoder *celt_encoder_create(const CELTMode *mode)
95 {
96    int N, C;
97    CELTEncoder *st;
98
99    if (check_mode(mode) != CELT_OK)
100       return NULL;
101
102    N = mode->mdctSize;
103    C = mode->nbChannels;
104    st = celt_alloc(sizeof(CELTEncoder));
105    
106    st->mode = mode;
107    st->frame_size = N;
108    st->block_size = N;
109    st->overlap = mode->overlap;
110
111    st->pitch_enabled = 1;
112
113    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
114    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
115
116    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
117
118    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
119    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
120
121 #ifdef EXP_PSY
122    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
123    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
124 #endif
125
126    return st;
127 }
128
129 void celt_encoder_destroy(CELTEncoder *st)
130 {
131    if (st == NULL)
132    {
133       celt_warning("NULL passed to celt_encoder_destroy");
134       return;
135    }
136    if (check_mode(st->mode) != CELT_OK)
137       return;
138
139    celt_free(st->in_mem);
140    celt_free(st->out_mem);
141    
142    celt_free(st->oldBandE);
143    
144    celt_free(st->preemph_memE);
145    celt_free(st->preemph_memD);
146    
147 #ifdef EXP_PSY
148    celt_free (st->psy_mem);
149    psydecay_clear(&st->psy);
150 #endif
151    
152    celt_free(st);
153 }
154
155 static inline celt_int16_t FLOAT2INT16(float x)
156 {
157    x = x*32768.;
158    x = MAX32(x, -32768);
159    x = MIN32(x, 32767);
160    return (celt_int16_t)float2int(x);
161 }
162
163 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
164 {
165 #ifdef FIXED_POINT
166    x = PSHR32(x, SIG_SHIFT);
167    x = MAX32(x, -32768);
168    x = MIN32(x, 32767);
169    return EXTRACT16(x);
170 #else
171    return (celt_word16_t)x;
172 #endif
173 }
174
175 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
176 {
177    int c, i, n;
178    celt_word32_t ratio;
179    /* FIXME: Remove the floats here */
180    VARDECL(celt_word32_t, begin);
181    SAVE_STACK;
182    ALLOC(begin, len, celt_word32_t);
183    for (i=0;i<len;i++)
184       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
185    for (c=1;c<C;c++)
186    {
187       for (i=0;i<len;i++)
188          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
189    }
190    for (i=1;i<len;i++)
191       begin[i] = MAX32(begin[i-1],begin[i]);
192    n = -1;
193    for (i=8;i<len-8;i++)
194    {
195       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
196          n=i;
197    }
198    if (n<32)
199    {
200       n = -1;
201       ratio = 0;
202    } else {
203       ratio = DIV32(begin[len-1],1+begin[n-16]);
204    }
205    /*printf ("%d %f\n", n, ratio*ratio);*/
206    if (ratio < 0)
207       ratio = 0;
208    if (ratio > 1000)
209       ratio = 1000;
210    ratio *= ratio;
211    if (ratio < 50)
212       *transient_shift = 0;
213    else if (ratio < 256)
214       *transient_shift = 1;
215    else if (ratio < 4096)
216       *transient_shift = 2;
217    else
218       *transient_shift = 3;
219    *transient_time = n;
220    
221    RESTORE_STACK;
222    return ratio > 20;
223 }
224
225 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
226 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
227 {
228    const int C = CHANNELS(mode);
229    if (C==1 && !shortBlocks)
230    {
231       const mdct_lookup *lookup = MDCT(mode);
232       const int overlap = OVERLAP(mode);
233       mdct_forward(lookup, in, out, mode->window, overlap);
234    } else if (!shortBlocks) {
235       const mdct_lookup *lookup = MDCT(mode);
236       const int overlap = OVERLAP(mode);
237       const int N = FRAMESIZE(mode);
238       int c;
239       VARDECL(celt_word32_t, x);
240       VARDECL(celt_word32_t, tmp);
241       SAVE_STACK;
242       ALLOC(x, N+overlap, celt_word32_t);
243       ALLOC(tmp, N, celt_word32_t);
244       for (c=0;c<C;c++)
245       {
246          int j;
247          for (j=0;j<N+overlap;j++)
248             x[j] = in[C*j+c];
249          mdct_forward(lookup, x, tmp, mode->window, overlap);
250          /* Interleaving the sub-frames */
251          for (j=0;j<N;j++)
252             out[C*j+c] = tmp[j];
253       }
254       RESTORE_STACK;
255    } else {
256       const mdct_lookup *lookup = &mode->shortMdct;
257       const int overlap = mode->overlap;
258       const int N = mode->shortMdctSize;
259       int b, c;
260       VARDECL(celt_word32_t, x);
261       VARDECL(celt_word32_t, tmp);
262       SAVE_STACK;
263       ALLOC(x, N+overlap, celt_word32_t);
264       ALLOC(tmp, N, celt_word32_t);
265       for (c=0;c<C;c++)
266       {
267          int B = mode->nbShortMdcts;
268          for (b=0;b<B;b++)
269          {
270             int j;
271             for (j=0;j<N+overlap;j++)
272                x[j] = in[C*(b*N+j)+c];
273             mdct_forward(lookup, x, tmp, mode->window, overlap);
274             /* Interleaving the sub-frames */
275             for (j=0;j<N;j++)
276                out[C*(j*B+b)+c] = tmp[j];
277          }
278       }
279       RESTORE_STACK;
280    }
281 }
282
283 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
284 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
285 {
286    int c, N4;
287    const int C = CHANNELS(mode);
288    const int N = FRAMESIZE(mode);
289    const int overlap = OVERLAP(mode);
290    N4 = (N-overlap)>>1;
291    for (c=0;c<C;c++)
292    {
293       int j;
294       if (transient_shift==0 && C==1 && !shortBlocks) {
295          const mdct_lookup *lookup = MDCT(mode);
296          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
297       } else if (!shortBlocks) {
298          const mdct_lookup *lookup = MDCT(mode);
299          VARDECL(celt_word32_t, x);
300          VARDECL(celt_word32_t, tmp);
301          SAVE_STACK;
302          ALLOC(x, 2*N, celt_word32_t);
303          ALLOC(tmp, N, celt_word32_t);
304          /* De-interleaving the sub-frames */
305          for (j=0;j<N;j++)
306             tmp[j] = X[C*j+c];
307          /* Prevents problems from the imdct doing the overlap-add */
308          CELT_MEMSET(x+N4, 0, N);
309          mdct_backward(lookup, tmp, x, mode->window, overlap);
310          celt_assert(transient_shift == 0);
311          /* The first and last part would need to be set to zero if we actually
312             wanted to use them. */
313          for (j=0;j<overlap;j++)
314             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
315          for (j=0;j<overlap;j++)
316             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
317          for (j=0;j<2*N4;j++)
318             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
319          RESTORE_STACK;
320       } else {
321          int b;
322          const int N2 = mode->shortMdctSize;
323          const int B = mode->nbShortMdcts;
324          const mdct_lookup *lookup = &mode->shortMdct;
325          VARDECL(celt_word32_t, x);
326          VARDECL(celt_word32_t, tmp);
327          SAVE_STACK;
328          ALLOC(x, 2*N, celt_word32_t);
329          ALLOC(tmp, N, celt_word32_t);
330          /* Prevents problems from the imdct doing the overlap-add */
331          CELT_MEMSET(x+N4, 0, N2);
332          for (b=0;b<B;b++)
333          {
334             /* De-interleaving the sub-frames */
335             for (j=0;j<N2;j++)
336                tmp[j] = X[C*(j*B+b)+c];
337             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
338          }
339          if (transient_shift > 0)
340          {
341 #ifdef FIXED_POINT
342             for (j=0;j<16;j++)
343                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
344             for (j=transient_time;j<N+overlap;j++)
345                x[N4+j] = SHL32(x[N4+j], transient_shift);
346 #else
347             for (j=0;j<16;j++)
348                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
349             for (j=transient_time;j<N+overlap;j++)
350                x[N4+j] *= 1<<transient_shift;
351 #endif
352          }
353          /* The first and last part would need to be set to zero if we actually
354          wanted to use them. */
355          for (j=0;j<overlap;j++)
356             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
357          for (j=0;j<overlap;j++)
358             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
359          for (j=0;j<2*N4;j++)
360             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
361          RESTORE_STACK;
362       }
363    }
364 }
365
366 #ifdef FIXED_POINT
367 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
368 {
369 #else
370 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
371 {
372 #endif
373    int i, c, N, N4;
374    int has_pitch;
375    int pitch_index;
376    int bits;
377    int has_fold=1;
378    ec_byte_buffer buf;
379    ec_enc         enc;
380    celt_word32_t curr_power, pitch_power=0;
381    VARDECL(celt_sig_t, in);
382    VARDECL(celt_sig_t, freq);
383    VARDECL(celt_norm_t, X);
384    VARDECL(celt_norm_t, P);
385    VARDECL(celt_ener_t, bandE);
386    VARDECL(celt_pgain_t, gains);
387    VARDECL(int, stereo_mode);
388    VARDECL(int, fine_quant);
389    VARDECL(celt_word16_t, error);
390    VARDECL(int, pulses);
391    VARDECL(int, offsets);
392 #ifdef EXP_PSY
393    VARDECL(celt_word32_t, mask);
394    VARDECL(celt_word32_t, tonality);
395    VARDECL(celt_word32_t, bandM);
396    VARDECL(celt_ener_t, bandN);
397 #endif
398    int shortBlocks=0;
399    int transient_time;
400    int transient_shift;
401    const int C = CHANNELS(st->mode);
402    SAVE_STACK;
403
404    if (check_mode(st->mode) != CELT_OK)
405       return CELT_INVALID_MODE;
406
407    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
408    CELT_MEMSET(compressed, 0, nbCompressedBytes);
409    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
410    ec_enc_init(&enc,&buf);
411
412    N = st->block_size;
413    N4 = (N-st->overlap)>>1;
414    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
415
416    CELT_COPY(in, st->in_mem, C*st->overlap);
417    for (c=0;c<C;c++)
418    {
419       const celt_word16_t * restrict pcmp = pcm+c;
420       celt_sig_t * restrict inp = in+C*st->overlap+c;
421       for (i=0;i<N;i++)
422       {
423          /* Apply pre-emphasis */
424          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
425          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
426          st->preemph_memE[c] = SCALEIN(*pcmp);
427          inp += C;
428          pcmp += C;
429       }
430    }
431    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
432    
433    /* Transient handling */
434    if (st->mode->nbShortMdcts > 1)
435    {
436       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
437       {
438 #ifndef FIXED_POINT
439          float gain_1;
440 #endif
441          ec_enc_bits(&enc, 0, 1); //Pitch off
442          ec_enc_bits(&enc, 1, 1); //Transient on
443          ec_enc_bits(&enc, transient_shift, 2);
444          if (transient_shift)
445             ec_enc_uint(&enc, transient_time, N+st->overlap);
446          /* Apply the inverse shaping window */
447          if (transient_shift)
448          {
449 #ifdef FIXED_POINT
450             for (c=0;c<C;c++)
451                for (i=0;i<16;i++)
452                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
453             for (c=0;c<C;c++)
454                for (i=transient_time;i<N+st->overlap;i++)
455                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
456 #else
457             for (c=0;c<C;c++)
458                for (i=0;i<16;i++)
459                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
460             gain_1 = 1./(1<<transient_shift);
461             for (c=0;c<C;c++)
462                for (i=transient_time;i<N+st->overlap;i++)
463                   in[C*i+c] *= gain_1;
464 #endif
465          }
466          shortBlocks = 1;
467       } else {
468          transient_time = -1;
469          transient_shift = 0;
470          shortBlocks = 0;
471       }
472    } else {
473       transient_time = -1;
474       transient_shift = 0;
475       shortBlocks = 0;
476    }
477
478    /* Pitch analysis: we do it early to save on the peak stack space */
479 #ifdef EXP_PSY
480    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
481    {
482       VARDECL(celt_word16_t, X);
483       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
484       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
485       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
486    }
487 #else
488    if (st->pitch_enabled && !shortBlocks)
489    {
490       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
491    }
492 #endif
493    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
494    
495    /* Compute MDCTs */
496    compute_mdcts(st->mode, shortBlocks, in, freq);
497
498 #ifdef EXP_PSY
499    ALLOC(mask, N, celt_sig_t);
500    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
501    /*for (i=0;i<256;i++)
502       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
503    printf ("\n");*/
504 #endif
505
506    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
507    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
508    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
509    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
510    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
511
512
513    /* Band normalisation */
514    compute_band_energies(st->mode, freq, bandE);
515    normalise_bands(st->mode, freq, X, bandE);
516
517 #ifdef EXP_PSY
518    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
519    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
520    compute_noise_energies(st->mode, freq, tonality, bandN);
521
522    /*for (i=0;i<st->mode->nbEBands;i++)
523       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
524    printf ("\n");*/
525    has_fold = 0;
526    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
527       if (bandN[i] < .4*bandE[i])
528          has_fold++;
529    /*printf ("%d\n", has_fold);*/
530    if (has_fold>=2)
531       has_fold = 0;
532    else
533       has_fold = 1;
534    for (i=0;i<N;i++)
535       mask[i] = sqrt(mask[i]);
536    compute_band_energies(st->mode, mask, bandM);
537    /*for (i=0;i<st->mode->nbEBands;i++)
538       printf ("%f %f ", bandE[i], bandM[i]);
539    printf ("\n");*/
540 #endif
541
542    /* Compute MDCTs of the pitch part */
543    if (st->pitch_enabled && !shortBlocks)
544    {
545       /* Normalise the pitch vector as well (discard the energies) */
546       VARDECL(celt_ener_t, bandEp);
547       
548       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
549       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
550       compute_band_energies(st->mode, freq, bandEp);
551       normalise_bands(st->mode, freq, P, bandEp);
552       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
553    }
554
555    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
556    curr_power = bandE[0]+bandE[1]+bandE[2];
557    if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
558    {
559       int id;
560
561       /* Pitch prediction */
562       compute_pitch_gain(st->mode, X, P, gains);
563       id = quant_pitch(gains, st->mode->nbPBands, &enc);
564       if (id != -1)
565          has_pitch = 1;
566       else
567          has_pitch = 0;
568       ec_enc_bits(&enc, has_pitch, 1); /* Pitch flag */
569       if (has_pitch)
570       {
571          ec_enc_bits(&enc, has_fold, 1); /* Folding flag */
572          ec_enc_bits(&enc, id, 7);
573          ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
574       } else if (st->mode->nbShortMdcts > 1) {
575          ec_enc_bits(&enc, 0, 1); /* Transient off */
576          has_fold = 1;
577       }
578    } else {
579       if (!shortBlocks)
580       {
581          ec_enc_bits(&enc, 0, 1); /* Pitch off */
582          if (st->mode->nbShortMdcts > 1)
583            ec_enc_bits(&enc, 0, 1); /* Transient off */
584       }
585       has_fold = 1;
586       /* No pitch, so we just pretend we found a gain of zero */
587       for (i=0;i<st->mode->nbPBands;i++)
588          gains[i] = 0;
589       for (i=0;i<C*N;i++)
590          P[i] = 0;
591    }
592
593 #ifdef STDIN_TUNING2
594    static int fine_quant[30];
595    static int pulses[30];
596    static int init=0;
597    if (!init)
598    {
599       for (i=0;i<st->mode->nbEBands;i++)
600          scanf("%d ", &fine_quant[i]);
601       for (i=0;i<st->mode->nbEBands;i++)
602          scanf("%d ", &pulses[i]);
603       init = 1;
604    }
605 #else
606    ALLOC(fine_quant, st->mode->nbEBands, int);
607    ALLOC(pulses, st->mode->nbEBands, int);
608 #endif
609
610    /* Bit allocation */
611    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
612    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &enc);
613    
614    ALLOC(offsets, st->mode->nbEBands, int);
615    ALLOC(stereo_mode, st->mode->nbEBands, int);
616    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
617
618    for (i=0;i<st->mode->nbEBands;i++)
619       offsets[i] = 0;
620    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
621 #ifndef STDIN_TUNING
622    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
623 #endif
624
625    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
626
627    pitch_quant_bands(st->mode, P, gains);
628
629    /* Residual quantisation */
630    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
631
632    /* Re-synthesis of the coded audio if required */
633    if (st->pitch_enabled || optional_synthesis!=NULL)
634    {
635       if (C==2)
636          renormalise_bands(st->mode, X);
637       /* Synthesis */
638       denormalise_bands(st->mode, X, freq, bandE);
639       
640       
641       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
642       
643       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
644       /* De-emphasis and put everything back at the right place in the synthesis history */
645       if (optional_synthesis != NULL) {
646          for (c=0;c<C;c++)
647          {
648             int j;
649             for (j=0;j<N;j++)
650             {
651                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
652                                    preemph,st->preemph_memD[c]);
653                st->preemph_memD[c] = tmp;
654                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
655             }
656          }
657       }
658    }
659    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
660    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
661       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
662    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
663    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
664    {
665       int val = 0;
666       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
667       {
668          ec_enc_uint(&enc, val, 2);
669          val = 1-val;
670       }
671    }
672    ec_enc_done(&enc);
673    {
674       /*unsigned char *data;*/
675       int nbBytes = ec_byte_bytes(&buf);
676       if (nbBytes > nbCompressedBytes)
677       {
678          celt_warning_int ("got too many bytes:", nbBytes);
679          RESTORE_STACK;
680          return CELT_INTERNAL_ERROR;
681       }
682    }
683
684    RESTORE_STACK;
685    return nbCompressedBytes;
686 }
687
688 #ifdef FIXED_POINT
689 #ifndef DISABLE_FLOAT_API
690 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
691 {
692    int j, ret;
693    const int C = CHANNELS(st->mode);
694    const int N = st->block_size;
695    VARDECL(celt_int16_t, in);
696    SAVE_STACK;
697    ALLOC(in, C*N, celt_int16_t);
698
699    for (j=0;j<C*N;j++)
700      in[j] = FLOAT2INT16(pcm[j]);
701
702    if (optional_synthesis != NULL) {
703      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
704       for (j=0;j<C*N;j++)
705          optional_synthesis[j]=in[j]*(1/32768.);
706    } else {
707      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
708    }
709    RESTORE_STACK;
710    return ret;
711
712 }
713 #endif /*DISABLE_FLOAT_API*/
714 #else
715 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
716 {
717    int j, ret;
718    VARDECL(celt_sig_t, in);
719    const int C = CHANNELS(st->mode);
720    const int N = st->block_size;
721    SAVE_STACK;
722    ALLOC(in, C*N, celt_sig_t);
723    for (j=0;j<C*N;j++) {
724      in[j] = SCALEOUT(pcm[j]);
725    }
726
727    if (optional_synthesis != NULL) {
728       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
729       for (j=0;j<C*N;j++)
730          optional_synthesis[j] = FLOAT2INT16(in[j]);
731    } else {
732       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
733    }
734    RESTORE_STACK;
735    return ret;
736 }
737 #endif
738
739 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
740 {
741    va_list ap;
742    va_start(ap, request);
743    switch (request)
744    {
745       case CELT_SET_COMPLEXITY_REQUEST:
746       {
747          int value = va_arg(ap, int);
748          if (value<0 || value>10)
749             goto bad_arg;
750          if (value<=2)
751             st->pitch_enabled = 0;
752          else
753             st->pitch_enabled = 1;
754       }
755       break;
756       default:
757          goto bad_request;
758    }
759    va_end(ap);
760    return CELT_OK;
761 bad_arg:
762    va_end(ap);
763    return CELT_BAD_ARG;
764 bad_request:
765    va_end(ap);
766    return CELT_UNIMPLEMENTED;
767 }
768
769 /****************************************************************************/
770 /*                                                                          */
771 /*                                DECODER                                   */
772 /*                                                                          */
773 /****************************************************************************/
774
775
776 /** Decoder state 
777  @brief Decoder state
778  */
779 struct CELTDecoder {
780    const CELTMode *mode;
781    int frame_size;
782    int block_size;
783    int overlap;
784
785    ec_byte_buffer buf;
786    ec_enc         enc;
787
788    celt_sig_t * restrict preemph_memD;
789
790    celt_sig_t *out_mem;
791
792    celt_word16_t *oldBandE;
793    
794    int last_pitch_index;
795 };
796
797 CELTDecoder *celt_decoder_create(const CELTMode *mode)
798 {
799    int N, C;
800    CELTDecoder *st;
801
802    if (check_mode(mode) != CELT_OK)
803       return NULL;
804
805    N = mode->mdctSize;
806    C = CHANNELS(mode);
807    st = celt_alloc(sizeof(CELTDecoder));
808    
809    st->mode = mode;
810    st->frame_size = N;
811    st->block_size = N;
812    st->overlap = mode->overlap;
813
814    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
815    
816    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
817
818    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
819
820    st->last_pitch_index = 0;
821    return st;
822 }
823
824 void celt_decoder_destroy(CELTDecoder *st)
825 {
826    if (st == NULL)
827    {
828       celt_warning("NULL passed to celt_encoder_destroy");
829       return;
830    }
831    if (check_mode(st->mode) != CELT_OK)
832       return;
833
834
835    celt_free(st->out_mem);
836    
837    celt_free(st->oldBandE);
838    
839    celt_free(st->preemph_memD);
840
841    celt_free(st);
842 }
843
844 /** Handles lost packets by just copying past data with the same offset as the last
845     pitch period */
846 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
847 {
848    int c, N;
849    int pitch_index;
850    int i, len;
851    VARDECL(celt_sig_t, freq);
852    const int C = CHANNELS(st->mode);
853    int offset;
854    SAVE_STACK;
855    N = st->block_size;
856    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
857    
858    len = N+st->mode->overlap;
859 #if 0
860    pitch_index = st->last_pitch_index;
861    
862    /* Use the pitch MDCT as the "guessed" signal */
863    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
864
865 #else
866    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
867    pitch_index = MAX_PERIOD-len-pitch_index;
868    offset = MAX_PERIOD-pitch_index;
869    while (offset+len >= MAX_PERIOD)
870       offset -= pitch_index;
871    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
872    for (i=0;i<N;i++)
873       freq[i] = ADD32(EPSILON, MULT16_32_Q15(QCONST16(.9f,15),freq[i]));
874 #endif
875    
876    
877    
878    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
879    /* Compute inverse MDCTs */
880    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
881
882    for (c=0;c<C;c++)
883    {
884       int j;
885       for (j=0;j<N;j++)
886       {
887          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
888                                 preemph,st->preemph_memD[c]);
889          st->preemph_memD[c] = tmp;
890          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
891       }
892    }
893    RESTORE_STACK;
894 }
895
896 #ifdef FIXED_POINT
897 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
898 {
899 #else
900 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
901 {
902 #endif
903    int i, c, N, N4;
904    int has_pitch, has_fold;
905    int pitch_index;
906    int bits;
907    ec_dec dec;
908    ec_byte_buffer buf;
909    VARDECL(celt_sig_t, freq);
910    VARDECL(celt_norm_t, X);
911    VARDECL(celt_norm_t, P);
912    VARDECL(celt_ener_t, bandE);
913    VARDECL(celt_pgain_t, gains);
914    VARDECL(int, stereo_mode);
915    VARDECL(int, fine_quant);
916    VARDECL(int, pulses);
917    VARDECL(int, offsets);
918
919    int shortBlocks;
920    int transient_time;
921    int transient_shift;
922    const int C = CHANNELS(st->mode);
923    SAVE_STACK;
924
925    if (check_mode(st->mode) != CELT_OK)
926       return CELT_INVALID_MODE;
927
928    N = st->block_size;
929    N4 = (N-st->overlap)>>1;
930
931    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
932    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
933    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
934    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
935    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
936    
937    if (check_mode(st->mode) != CELT_OK)
938    {
939       RESTORE_STACK;
940       return CELT_INVALID_MODE;
941    }
942    if (data == NULL)
943    {
944       celt_decode_lost(st, pcm);
945       RESTORE_STACK;
946       return 0;
947    }
948    
949    ec_byte_readinit(&buf,data,len);
950    ec_dec_init(&dec,&buf);
951    
952    has_pitch = ec_dec_bits(&dec, 1);
953    if (has_pitch)
954    {
955       has_fold = ec_dec_bits(&dec, 1);
956       shortBlocks = 0;
957    } else if (st->mode->nbShortMdcts > 1){
958       shortBlocks = ec_dec_bits(&dec, 1);
959       has_fold = 1;
960    } else {
961       shortBlocks = 0;
962       has_fold = 1;
963    }
964    if (shortBlocks)
965    {
966       transient_shift = ec_dec_bits(&dec, 2);
967       if (transient_shift)
968          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
969       else
970          transient_time = 0;
971    } else {
972       transient_time = -1;
973       transient_shift = 0;
974    }
975    /* Get the pitch gains */
976    
977    /* Get the pitch index */
978    if (has_pitch)
979    {
980       has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
981       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
982       st->last_pitch_index = pitch_index;
983    } else {
984       /* FIXME: We could be more intelligent here and just not compute the MDCT */
985       pitch_index = 0;
986       for (i=0;i<st->mode->nbPBands;i++)
987          gains[i] = 0;
988    }
989
990    ALLOC(fine_quant, st->mode->nbEBands, int);
991    /* Get band energies */
992    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
993    
994    ALLOC(pulses, st->mode->nbEBands, int);
995    ALLOC(offsets, st->mode->nbEBands, int);
996    ALLOC(stereo_mode, st->mode->nbEBands, int);
997    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
998
999    for (i=0;i<st->mode->nbEBands;i++)
1000       offsets[i] = 0;
1001
1002    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1003    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1004    /*bits = ec_dec_tell(&dec, 0);
1005    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1006    
1007    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1008
1009
1010    if (has_pitch) 
1011    {
1012       VARDECL(celt_ener_t, bandEp);
1013       
1014       /* Pitch MDCT */
1015       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1016       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1017       compute_band_energies(st->mode, freq, bandEp);
1018       normalise_bands(st->mode, freq, P, bandEp);
1019    } else {
1020       for (i=0;i<C*N;i++)
1021          P[i] = 0;
1022    }
1023
1024    /* Apply pitch gains */
1025    pitch_quant_bands(st->mode, P, gains);
1026
1027    /* Decode fixed codebook and merge with pitch */
1028    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, has_fold, len*8, &dec);
1029
1030    if (C==2)
1031    {
1032       renormalise_bands(st->mode, X);
1033    }
1034    /* Synthesis */
1035    denormalise_bands(st->mode, X, freq, bandE);
1036
1037
1038    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
1039    /* Compute inverse MDCTs */
1040    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1041
1042    for (c=0;c<C;c++)
1043    {
1044       int j;
1045       for (j=0;j<N;j++)
1046       {
1047          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1048                                 preemph,st->preemph_memD[c]);
1049          st->preemph_memD[c] = tmp;
1050          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1051       }
1052    }
1053
1054    {
1055       unsigned int val = 0;
1056       while (ec_dec_tell(&dec, 0) < len*8)
1057       {
1058          if (ec_dec_uint(&dec, 2) != val)
1059          {
1060             celt_warning("decode error");
1061             RESTORE_STACK;
1062             return CELT_CORRUPTED_DATA;
1063          }
1064          val = 1-val;
1065       }
1066    }
1067
1068    RESTORE_STACK;
1069    return 0;
1070    /*printf ("\n");*/
1071 }
1072
1073 #ifdef FIXED_POINT
1074 #ifndef DISABLE_FLOAT_API
1075 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1076 {
1077    int j, ret;
1078    const int C = CHANNELS(st->mode);
1079    const int N = st->block_size;
1080    VARDECL(celt_int16_t, out);
1081    SAVE_STACK;
1082    ALLOC(out, C*N, celt_int16_t);
1083
1084    ret=celt_decode(st, data, len, out);
1085
1086    for (j=0;j<C*N;j++)
1087      pcm[j]=out[j]*(1/32768.);
1088    RESTORE_STACK;
1089    return ret;
1090 }
1091 #endif /*DISABLE_FLOAT_API*/
1092 #else
1093 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1094 {
1095    int j, ret;
1096    VARDECL(celt_sig_t, out);
1097    const int C = CHANNELS(st->mode);
1098    const int N = st->block_size;
1099    SAVE_STACK;
1100    ALLOC(out, C*N, celt_sig_t);
1101
1102    ret=celt_decode_float(st, data, len, out);
1103
1104    for (j=0;j<C*N;j++)
1105      pcm[j] = FLOAT2INT16 (out[j]);
1106
1107    RESTORE_STACK;
1108    return ret;
1109 }
1110 #endif