b27fee93038edf7a2311ac2b9311db64529ad585
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53 #include "float_cast.h"
54 #include <stdarg.h>
55
56 static const celt_word16_t preemph = QCONST16(0.8f,15);
57
58 #ifdef FIXED_POINT
59 static const celt_word16_t transientWindow[16] = {
60      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
61    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
62 #else
63 static const float transientWindow[16] = {
64    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
65    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
66 #endif
67
68    
69 /** Encoder state 
70  @brief Encoder state
71  */
72 struct CELTEncoder {
73    const CELTMode *mode;     /**< Mode used by the encoder */
74    int frame_size;
75    int block_size;
76    int overlap;
77    int channels;
78    
79    int pitch_enabled;
80
81    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
82    celt_sig_t    * restrict preemph_memD;
83
84    celt_sig_t *in_mem;
85    celt_sig_t *out_mem;
86
87    celt_word16_t *oldBandE;
88 #ifdef EXP_PSY
89    celt_word16_t *psy_mem;
90    struct PsyDecay psy;
91 #endif
92 };
93
94 CELTEncoder *celt_encoder_create(const CELTMode *mode)
95 {
96    int N, C;
97    CELTEncoder *st;
98
99    if (check_mode(mode) != CELT_OK)
100       return NULL;
101
102    N = mode->mdctSize;
103    C = mode->nbChannels;
104    st = celt_alloc(sizeof(CELTEncoder));
105    
106    st->mode = mode;
107    st->frame_size = N;
108    st->block_size = N;
109    st->overlap = mode->overlap;
110
111    st->pitch_enabled = 1;
112
113    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
114    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
115
116    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
117
118    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
119    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
120
121 #ifdef EXP_PSY
122    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
123    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
124 #endif
125
126    return st;
127 }
128
129 void celt_encoder_destroy(CELTEncoder *st)
130 {
131    if (st == NULL)
132    {
133       celt_warning("NULL passed to celt_encoder_destroy");
134       return;
135    }
136    if (check_mode(st->mode) != CELT_OK)
137       return;
138
139    celt_free(st->in_mem);
140    celt_free(st->out_mem);
141    
142    celt_free(st->oldBandE);
143    
144    celt_free(st->preemph_memE);
145    celt_free(st->preemph_memD);
146    
147 #ifdef EXP_PSY
148    celt_free (st->psy_mem);
149    psydecay_clear(&st->psy);
150 #endif
151    
152    celt_free(st);
153 }
154
155 static inline celt_int16_t FLOAT2INT16(float x)
156 {
157    x = x*32768.;
158    x = MAX32(x, -32768);
159    x = MIN32(x, 32767);
160    return (celt_int16_t)float2int(x);
161 }
162
163 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
164 {
165 #ifdef FIXED_POINT
166    x = PSHR32(x, SIG_SHIFT);
167    x = MAX32(x, -32768);
168    x = MIN32(x, 32767);
169    return EXTRACT16(x);
170 #else
171    return (celt_word16_t)x;
172 #endif
173 }
174
175 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
176 {
177    int c, i, n;
178    celt_word32_t ratio;
179    /* FIXME: Remove the floats here */
180    VARDECL(celt_word32_t, begin);
181    SAVE_STACK;
182    ALLOC(begin, len, celt_word32_t);
183    for (i=0;i<len;i++)
184       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
185    for (c=1;c<C;c++)
186    {
187       for (i=0;i<len;i++)
188          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
189    }
190    for (i=1;i<len;i++)
191       begin[i] = MAX32(begin[i-1],begin[i]);
192    n = -1;
193    for (i=8;i<len-8;i++)
194    {
195       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
196          n=i;
197    }
198    if (n<32)
199    {
200       n = -1;
201       ratio = 0;
202    } else {
203       ratio = DIV32(begin[len-1],1+begin[n-16]);
204    }
205    /*printf ("%d %f\n", n, ratio*ratio);*/
206    if (ratio < 0)
207       ratio = 0;
208    if (ratio > 1000)
209       ratio = 1000;
210    ratio *= ratio;
211    if (ratio < 50)
212       *transient_shift = 0;
213    else if (ratio < 256)
214       *transient_shift = 1;
215    else if (ratio < 4096)
216       *transient_shift = 2;
217    else
218       *transient_shift = 3;
219    *transient_time = n;
220    
221    RESTORE_STACK;
222    return ratio > 20;
223 }
224
225 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
226 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
227 {
228    const int C = CHANNELS(mode);
229    if (C==1 && !shortBlocks)
230    {
231       const mdct_lookup *lookup = MDCT(mode);
232       const int overlap = OVERLAP(mode);
233       mdct_forward(lookup, in, out, mode->window, overlap);
234    } else if (!shortBlocks) {
235       const mdct_lookup *lookup = MDCT(mode);
236       const int overlap = OVERLAP(mode);
237       const int N = FRAMESIZE(mode);
238       int c;
239       VARDECL(celt_word32_t, x);
240       VARDECL(celt_word32_t, tmp);
241       SAVE_STACK;
242       ALLOC(x, N+overlap, celt_word32_t);
243       ALLOC(tmp, N, celt_word32_t);
244       for (c=0;c<C;c++)
245       {
246          int j;
247          for (j=0;j<N+overlap;j++)
248             x[j] = in[C*j+c];
249          mdct_forward(lookup, x, tmp, mode->window, overlap);
250          /* Interleaving the sub-frames */
251          for (j=0;j<N;j++)
252             out[C*j+c] = tmp[j];
253       }
254       RESTORE_STACK;
255    } else {
256       const mdct_lookup *lookup = &mode->shortMdct;
257       const int overlap = mode->overlap;
258       const int N = mode->shortMdctSize;
259       int b, c;
260       VARDECL(celt_word32_t, x);
261       VARDECL(celt_word32_t, tmp);
262       SAVE_STACK;
263       ALLOC(x, N+overlap, celt_word32_t);
264       ALLOC(tmp, N, celt_word32_t);
265       for (c=0;c<C;c++)
266       {
267          int B = mode->nbShortMdcts;
268          for (b=0;b<B;b++)
269          {
270             int j;
271             for (j=0;j<N+overlap;j++)
272                x[j] = in[C*(b*N+j)+c];
273             mdct_forward(lookup, x, tmp, mode->window, overlap);
274             /* Interleaving the sub-frames */
275             for (j=0;j<N;j++)
276                out[C*(j*B+b)+c] = tmp[j];
277          }
278       }
279       RESTORE_STACK;
280    }
281 }
282
283 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
284 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
285 {
286    int c, N4;
287    const int C = CHANNELS(mode);
288    const int N = FRAMESIZE(mode);
289    const int overlap = OVERLAP(mode);
290    N4 = (N-overlap)>>1;
291    for (c=0;c<C;c++)
292    {
293       int j;
294       if (transient_shift==0 && C==1 && !shortBlocks) {
295          const mdct_lookup *lookup = MDCT(mode);
296          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
297       } else if (!shortBlocks) {
298          const mdct_lookup *lookup = MDCT(mode);
299          VARDECL(celt_word32_t, x);
300          VARDECL(celt_word32_t, tmp);
301          SAVE_STACK;
302          ALLOC(x, 2*N, celt_word32_t);
303          ALLOC(tmp, N, celt_word32_t);
304          /* De-interleaving the sub-frames */
305          for (j=0;j<N;j++)
306             tmp[j] = X[C*j+c];
307          /* Prevents problems from the imdct doing the overlap-add */
308          CELT_MEMSET(x+N4, 0, N);
309          mdct_backward(lookup, tmp, x, mode->window, overlap);
310          celt_assert(transient_shift == 0);
311          /* The first and last part would need to be set to zero if we actually
312             wanted to use them. */
313          for (j=0;j<overlap;j++)
314             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
315          for (j=0;j<overlap;j++)
316             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
317          for (j=0;j<2*N4;j++)
318             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
319          RESTORE_STACK;
320       } else {
321          int b;
322          const int N2 = mode->shortMdctSize;
323          const int B = mode->nbShortMdcts;
324          const mdct_lookup *lookup = &mode->shortMdct;
325          VARDECL(celt_word32_t, x);
326          VARDECL(celt_word32_t, tmp);
327          SAVE_STACK;
328          ALLOC(x, 2*N, celt_word32_t);
329          ALLOC(tmp, N, celt_word32_t);
330          /* Prevents problems from the imdct doing the overlap-add */
331          CELT_MEMSET(x+N4, 0, N2);
332          for (b=0;b<B;b++)
333          {
334             /* De-interleaving the sub-frames */
335             for (j=0;j<N2;j++)
336                tmp[j] = X[C*(j*B+b)+c];
337             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
338          }
339          if (transient_shift > 0)
340          {
341 #ifdef FIXED_POINT
342             for (j=0;j<16;j++)
343                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
344             for (j=transient_time;j<N+overlap;j++)
345                x[N4+j] = SHL32(x[N4+j], transient_shift);
346 #else
347             for (j=0;j<16;j++)
348                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
349             for (j=transient_time;j<N+overlap;j++)
350                x[N4+j] *= 1<<transient_shift;
351 #endif
352          }
353          /* The first and last part would need to be set to zero if we actually
354          wanted to use them. */
355          for (j=0;j<overlap;j++)
356             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
357          for (j=0;j<overlap;j++)
358             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
359          for (j=0;j<2*N4;j++)
360             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
361          RESTORE_STACK;
362       }
363    }
364 }
365
366 #ifdef FIXED_POINT
367 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
368 {
369 #else
370 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
371 {
372 #endif
373    int i, c, N, N4;
374    int has_pitch;
375    int pitch_index;
376    int bits;
377    int has_fold=1;
378    ec_byte_buffer buf;
379    ec_enc         enc;
380    celt_word32_t curr_power, pitch_power=0;
381    VARDECL(celt_sig_t, in);
382    VARDECL(celt_sig_t, freq);
383    VARDECL(celt_norm_t, X);
384    VARDECL(celt_norm_t, P);
385    VARDECL(celt_ener_t, bandE);
386    VARDECL(celt_pgain_t, gains);
387    VARDECL(int, stereo_mode);
388    VARDECL(int, fine_quant);
389    VARDECL(celt_word16_t, error);
390    VARDECL(int, pulses);
391    VARDECL(int, offsets);
392 #ifdef EXP_PSY
393    VARDECL(celt_word32_t, mask);
394    VARDECL(celt_word32_t, tonality);
395    VARDECL(celt_word32_t, bandM);
396    VARDECL(celt_ener_t, bandN);
397 #endif
398    int shortBlocks=0;
399    int transient_time;
400    int transient_shift;
401    const int C = CHANNELS(st->mode);
402    SAVE_STACK;
403
404    if (check_mode(st->mode) != CELT_OK)
405       return CELT_INVALID_MODE;
406
407    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
408    CELT_MEMSET(compressed, 0, nbCompressedBytes);
409    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
410    ec_enc_init(&enc,&buf);
411
412    N = st->block_size;
413    N4 = (N-st->overlap)>>1;
414    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
415
416    CELT_COPY(in, st->in_mem, C*st->overlap);
417    for (c=0;c<C;c++)
418    {
419       const celt_word16_t * restrict pcmp = pcm+c;
420       celt_sig_t * restrict inp = in+C*st->overlap+c;
421       for (i=0;i<N;i++)
422       {
423          /* Apply pre-emphasis */
424          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
425          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
426          st->preemph_memE[c] = SCALEIN(*pcmp);
427          inp += C;
428          pcmp += C;
429       }
430    }
431    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
432    
433    /* Transient handling */
434    if (st->mode->nbShortMdcts > 1)
435    {
436       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
437       {
438 #ifndef FIXED_POINT
439          float gain_1;
440 #endif
441          ec_enc_bits(&enc, 0, 1); //Pitch off
442          ec_enc_bits(&enc, 1, 1); //Transient on
443          ec_enc_bits(&enc, transient_shift, 2);
444          if (transient_shift)
445             ec_enc_uint(&enc, transient_time, N+st->overlap);
446          /* Apply the inverse shaping window */
447          if (transient_shift)
448          {
449 #ifdef FIXED_POINT
450             for (c=0;c<C;c++)
451                for (i=0;i<16;i++)
452                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
453             for (c=0;c<C;c++)
454                for (i=transient_time;i<N+st->overlap;i++)
455                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
456 #else
457             for (c=0;c<C;c++)
458                for (i=0;i<16;i++)
459                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
460             gain_1 = 1./(1<<transient_shift);
461             for (c=0;c<C;c++)
462                for (i=transient_time;i<N+st->overlap;i++)
463                   in[C*i+c] *= gain_1;
464 #endif
465          }
466          shortBlocks = 1;
467       } else {
468          transient_time = -1;
469          transient_shift = 0;
470          shortBlocks = 0;
471       }
472    } else {
473       transient_time = -1;
474       transient_shift = 0;
475       shortBlocks = 0;
476    }
477
478    /* Pitch analysis: we do it early to save on the peak stack space */
479 #ifdef EXP_PSY
480    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
481    {
482       VARDECL(celt_word16_t, X);
483       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
484       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
485       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
486    }
487 #else
488    if (st->pitch_enabled && !shortBlocks)
489    {
490       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
491    }
492 #endif
493    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
494    
495    /* Compute MDCTs */
496    compute_mdcts(st->mode, shortBlocks, in, freq);
497
498 #ifdef EXP_PSY
499    ALLOC(mask, N, celt_sig_t);
500    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
501    /*for (i=0;i<256;i++)
502       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
503    printf ("\n");*/
504 #endif
505
506    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
507    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
508    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
509    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
510    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
511
512
513    /* Band normalisation */
514    compute_band_energies(st->mode, freq, bandE);
515    normalise_bands(st->mode, freq, X, bandE);
516
517 #ifdef EXP_PSY
518    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
519    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
520    compute_noise_energies(st->mode, freq, tonality, bandN);
521
522    /*for (i=0;i<st->mode->nbEBands;i++)
523       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
524    printf ("\n");*/
525    has_fold = 0;
526    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
527       if (bandN[i] < .4*bandE[i])
528          has_fold++;
529    /*printf ("%d\n", has_fold);*/
530    if (has_fold>=2)
531       has_fold = 0;
532    else
533       has_fold = 1;
534    for (i=0;i<N;i++)
535       mask[i] = sqrt(mask[i]);
536    compute_band_energies(st->mode, mask, bandM);
537    /*for (i=0;i<st->mode->nbEBands;i++)
538       printf ("%f %f ", bandE[i], bandM[i]);
539    printf ("\n");*/
540 #endif
541
542    /* Compute MDCTs of the pitch part */
543    if (st->pitch_enabled && !shortBlocks)
544    {
545       /* Normalise the pitch vector as well (discard the energies) */
546       VARDECL(celt_ener_t, bandEp);
547       
548       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
549       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
550       compute_band_energies(st->mode, freq, bandEp);
551       normalise_bands(st->mode, freq, P, bandEp);
552       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
553    }
554
555    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
556    curr_power = bandE[0]+bandE[1]+bandE[2];
557    if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
558    {
559       int id;
560
561       /* Pitch prediction */
562       compute_pitch_gain(st->mode, X, P, gains);
563       id = quant_pitch(gains, st->mode->nbPBands, &enc);
564       if (id != -1)
565          has_pitch = 1;
566       else
567          has_pitch = 0;
568       ec_enc_bits(&enc, has_pitch, 1); /* Pitch flag */
569       if (has_pitch)
570       {
571          ec_enc_bits(&enc, has_fold, 1); /* Folding flag */
572          ec_enc_bits(&enc, id, 7);
573          ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
574       } else if (st->mode->nbShortMdcts > 1) {
575          ec_enc_bits(&enc, 0, 1); /* Transient off */
576          has_fold = 1;
577       }
578    } else {
579       if (!shortBlocks)
580       {
581          ec_enc_bits(&enc, 0, 1); /* Pitch off */
582          if (st->mode->nbShortMdcts > 1)
583            ec_enc_bits(&enc, 0, 1); /* Transient off */
584       }
585       has_fold = 1;
586       /* No pitch, so we just pretend we found a gain of zero */
587       for (i=0;i<st->mode->nbPBands;i++)
588          gains[i] = 0;
589       for (i=0;i<C*N;i++)
590          P[i] = 0;
591    }
592
593 #ifdef STDIN_TUNING2
594    static int fine_quant[30];
595    static int pulses[30];
596    static int init=0;
597    if (!init)
598    {
599       for (i=0;i<st->mode->nbEBands;i++)
600          scanf("%d ", &fine_quant[i]);
601       for (i=0;i<st->mode->nbEBands;i++)
602          scanf("%d ", &pulses[i]);
603       init = 1;
604    }
605 #else
606    ALLOC(fine_quant, st->mode->nbEBands, int);
607    ALLOC(pulses, st->mode->nbEBands, int);
608 #endif
609
610    /* Bit allocation */
611    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
612    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &enc);
613    
614    ALLOC(offsets, st->mode->nbEBands, int);
615    ALLOC(stereo_mode, st->mode->nbEBands, int);
616    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
617
618    for (i=0;i<st->mode->nbEBands;i++)
619       offsets[i] = 0;
620    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
621 #ifndef STDIN_TUNING
622    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
623 #endif
624
625    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
626
627    pitch_quant_bands(st->mode, P, gains);
628
629    /* Residual quantisation */
630    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
631
632    /* Re-synthesis of the coded audio if required */
633    if (st->pitch_enabled || optional_synthesis!=NULL)
634    {
635       if (C==2)
636          renormalise_bands(st->mode, X);
637       /* Synthesis */
638       denormalise_bands(st->mode, X, freq, bandE);
639       
640       
641       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
642       
643       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
644       /* De-emphasis and put everything back at the right place in the synthesis history */
645       if (optional_synthesis != NULL) {
646          for (c=0;c<C;c++)
647          {
648             int j;
649             for (j=0;j<N;j++)
650             {
651                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
652                                    preemph,st->preemph_memD[c]);
653                st->preemph_memD[c] = tmp;
654                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
655             }
656          }
657       }
658    }
659    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
660    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
661       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
662    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
663    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
664    {
665       int val = 0;
666       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
667       {
668          ec_enc_uint(&enc, val, 2);
669          val = 1-val;
670       }
671    }
672    ec_enc_done(&enc);
673    {
674       /*unsigned char *data;*/
675       int nbBytes = ec_byte_bytes(&buf);
676       if (nbBytes > nbCompressedBytes)
677       {
678          celt_warning_int ("got too many bytes:", nbBytes);
679          RESTORE_STACK;
680          return CELT_INTERNAL_ERROR;
681       }
682    }
683
684    RESTORE_STACK;
685    return nbCompressedBytes;
686 }
687
688 #ifdef FIXED_POINT
689 #ifndef DISABLE_FLOAT_API
690 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
691 {
692    int j, ret;
693    const int C = CHANNELS(st->mode);
694    const int N = st->block_size;
695    VARDECL(celt_int16_t, in);
696    SAVE_STACK;
697    ALLOC(in, C*N, celt_int16_t);
698
699    for (j=0;j<C*N;j++)
700      in[j] = FLOAT2INT16(pcm[j]);
701
702    if (optional_synthesis != NULL) {
703      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
704    /*Converts backwards for inplace operation*/
705       for (j=0;j=C*N;j++)
706          optional_synthesis[j]=in[j]*(1/32768.);
707    } else {
708      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
709    }
710    RESTORE_STACK;
711    return ret;
712
713 }
714 #endif /*DISABLE_FLOAT_API*/
715 #else
716 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
717 {
718    int j, ret;
719    VARDECL(celt_sig_t, in);
720    const int C = CHANNELS(st->mode);
721    const int N = st->block_size;
722    SAVE_STACK;
723    ALLOC(in, C*N, celt_sig_t);
724    for (j=0;j<C*N;j++) {
725      in[j] = SCALEOUT(pcm[j]);
726    }
727
728    if (optional_synthesis != NULL) {
729       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
730       for (j=0;j<C*N;j++)
731          optional_synthesis[j] = FLOAT2INT16(in[j]);
732    } else {
733       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
734    }
735    RESTORE_STACK;
736    return ret;
737 }
738 #endif
739
740 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
741 {
742    va_list ap;
743    va_start(ap, request);
744    switch (request)
745    {
746       case CELT_SET_COMPLEXITY_REQUEST:
747       {
748          int value = va_arg(ap, int);
749          if (value<0 || value>10)
750             goto bad_arg;
751          if (value<=2)
752             st->pitch_enabled = 0;
753          else
754             st->pitch_enabled = 1;
755       }
756       break;
757       default:
758          goto bad_request;
759    }
760    va_end(ap);
761    return CELT_OK;
762 bad_arg:
763    va_end(ap);
764    return CELT_BAD_ARG;
765 bad_request:
766    va_end(ap);
767    return CELT_UNIMPLEMENTED;
768 }
769
770 /****************************************************************************/
771 /*                                                                          */
772 /*                                DECODER                                   */
773 /*                                                                          */
774 /****************************************************************************/
775
776
777 /** Decoder state 
778  @brief Decoder state
779  */
780 struct CELTDecoder {
781    const CELTMode *mode;
782    int frame_size;
783    int block_size;
784    int overlap;
785
786    ec_byte_buffer buf;
787    ec_enc         enc;
788
789    celt_sig_t * restrict preemph_memD;
790
791    celt_sig_t *out_mem;
792
793    celt_word16_t *oldBandE;
794    
795    int last_pitch_index;
796 };
797
798 CELTDecoder *celt_decoder_create(const CELTMode *mode)
799 {
800    int N, C;
801    CELTDecoder *st;
802
803    if (check_mode(mode) != CELT_OK)
804       return NULL;
805
806    N = mode->mdctSize;
807    C = CHANNELS(mode);
808    st = celt_alloc(sizeof(CELTDecoder));
809    
810    st->mode = mode;
811    st->frame_size = N;
812    st->block_size = N;
813    st->overlap = mode->overlap;
814
815    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
816    
817    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
818
819    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
820
821    st->last_pitch_index = 0;
822    return st;
823 }
824
825 void celt_decoder_destroy(CELTDecoder *st)
826 {
827    if (st == NULL)
828    {
829       celt_warning("NULL passed to celt_encoder_destroy");
830       return;
831    }
832    if (check_mode(st->mode) != CELT_OK)
833       return;
834
835
836    celt_free(st->out_mem);
837    
838    celt_free(st->oldBandE);
839    
840    celt_free(st->preemph_memD);
841
842    celt_free(st);
843 }
844
845 /** Handles lost packets by just copying past data with the same offset as the last
846     pitch period */
847 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
848 {
849    int c, N;
850    int pitch_index;
851    int i, len;
852    VARDECL(celt_sig_t, freq);
853    const int C = CHANNELS(st->mode);
854    int offset;
855    SAVE_STACK;
856    N = st->block_size;
857    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
858    
859    len = N+st->mode->overlap;
860 #if 0
861    pitch_index = st->last_pitch_index;
862    
863    /* Use the pitch MDCT as the "guessed" signal */
864    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
865
866 #else
867    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
868    pitch_index = MAX_PERIOD-len-pitch_index;
869    offset = MAX_PERIOD-pitch_index;
870    while (offset+len >= MAX_PERIOD)
871       offset -= pitch_index;
872    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
873    for (i=0;i<N;i++)
874       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
875 #endif
876    
877    
878    
879    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
880    /* Compute inverse MDCTs */
881    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
882
883    for (c=0;c<C;c++)
884    {
885       int j;
886       for (j=0;j<N;j++)
887       {
888          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
889                                 preemph,st->preemph_memD[c]);
890          st->preemph_memD[c] = tmp;
891          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
892       }
893    }
894    RESTORE_STACK;
895 }
896
897 #ifdef FIXED_POINT
898 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
899 {
900 #else
901 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
902 {
903 #endif
904    int i, c, N, N4;
905    int has_pitch, has_fold;
906    int pitch_index;
907    int bits;
908    ec_dec dec;
909    ec_byte_buffer buf;
910    VARDECL(celt_sig_t, freq);
911    VARDECL(celt_norm_t, X);
912    VARDECL(celt_norm_t, P);
913    VARDECL(celt_ener_t, bandE);
914    VARDECL(celt_pgain_t, gains);
915    VARDECL(int, stereo_mode);
916    VARDECL(int, fine_quant);
917    VARDECL(int, pulses);
918    VARDECL(int, offsets);
919
920    int shortBlocks;
921    int transient_time;
922    int transient_shift;
923    const int C = CHANNELS(st->mode);
924    SAVE_STACK;
925
926    if (check_mode(st->mode) != CELT_OK)
927       return CELT_INVALID_MODE;
928
929    N = st->block_size;
930    N4 = (N-st->overlap)>>1;
931
932    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
933    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
934    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
935    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
936    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
937    
938    if (check_mode(st->mode) != CELT_OK)
939    {
940       RESTORE_STACK;
941       return CELT_INVALID_MODE;
942    }
943    if (data == NULL)
944    {
945       celt_decode_lost(st, pcm);
946       RESTORE_STACK;
947       return 0;
948    }
949    
950    ec_byte_readinit(&buf,data,len);
951    ec_dec_init(&dec,&buf);
952    
953    has_pitch = ec_dec_bits(&dec, 1);
954    if (has_pitch)
955    {
956       has_fold = ec_dec_bits(&dec, 1);
957       shortBlocks = 0;
958    } else if (st->mode->nbShortMdcts > 1){
959       shortBlocks = ec_dec_bits(&dec, 1);
960       has_fold = 1;
961    } else {
962       shortBlocks = 0;
963       has_fold = 1;
964    }
965    if (shortBlocks)
966    {
967       transient_shift = ec_dec_bits(&dec, 2);
968       if (transient_shift)
969          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
970       else
971          transient_time = 0;
972    } else {
973       transient_time = -1;
974       transient_shift = 0;
975    }
976    /* Get the pitch gains */
977    
978    /* Get the pitch index */
979    if (has_pitch)
980    {
981       has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
982       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
983       st->last_pitch_index = pitch_index;
984    } else {
985       /* FIXME: We could be more intelligent here and just not compute the MDCT */
986       pitch_index = 0;
987       for (i=0;i<st->mode->nbPBands;i++)
988          gains[i] = 0;
989    }
990
991    ALLOC(fine_quant, st->mode->nbEBands, int);
992    /* Get band energies */
993    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
994    
995    ALLOC(pulses, st->mode->nbEBands, int);
996    ALLOC(offsets, st->mode->nbEBands, int);
997    ALLOC(stereo_mode, st->mode->nbEBands, int);
998    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
999
1000    for (i=0;i<st->mode->nbEBands;i++)
1001       offsets[i] = 0;
1002
1003    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1004    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1005    /*bits = ec_dec_tell(&dec, 0);
1006    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1007    
1008    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1009
1010
1011    if (has_pitch) 
1012    {
1013       VARDECL(celt_ener_t, bandEp);
1014       
1015       /* Pitch MDCT */
1016       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1017       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1018       compute_band_energies(st->mode, freq, bandEp);
1019       normalise_bands(st->mode, freq, P, bandEp);
1020    } else {
1021       for (i=0;i<C*N;i++)
1022          P[i] = 0;
1023    }
1024
1025    /* Apply pitch gains */
1026    pitch_quant_bands(st->mode, P, gains);
1027
1028    /* Decode fixed codebook and merge with pitch */
1029    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, has_fold, len*8, &dec);
1030
1031    if (C==2)
1032    {
1033       renormalise_bands(st->mode, X);
1034    }
1035    /* Synthesis */
1036    denormalise_bands(st->mode, X, freq, bandE);
1037
1038
1039    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
1040    /* Compute inverse MDCTs */
1041    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1042
1043    for (c=0;c<C;c++)
1044    {
1045       int j;
1046       for (j=0;j<N;j++)
1047       {
1048          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1049                                 preemph,st->preemph_memD[c]);
1050          st->preemph_memD[c] = tmp;
1051          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1052       }
1053    }
1054
1055    {
1056       unsigned int val = 0;
1057       while (ec_dec_tell(&dec, 0) < len*8)
1058       {
1059          if (ec_dec_uint(&dec, 2) != val)
1060          {
1061             celt_warning("decode error");
1062             RESTORE_STACK;
1063             return CELT_CORRUPTED_DATA;
1064          }
1065          val = 1-val;
1066       }
1067    }
1068
1069    RESTORE_STACK;
1070    return 0;
1071    /*printf ("\n");*/
1072 }
1073
1074 #ifdef FIXED_POINT
1075 #ifndef DISABLE_FLOAT_API
1076 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1077 {
1078    int j, ret;
1079    const int C = CHANNELS(st->mode);
1080    const int N = st->block_size;
1081    VARDECL(celt_int16_t, out);
1082    SAVE_STACK;
1083    ALLOC(out, C*N, celt_int16_t);
1084
1085    ret=celt_decode(st, data, len, out);
1086
1087    for (j=0;j<C*N;j++)
1088      pcm[j]=out[j]*(1/32768.);
1089    RESTORE_STACK;
1090    return ret;
1091 }
1092 #endif /*DISABLE_FLOAT_API*/
1093 #else
1094 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1095 {
1096    int j, ret;
1097    VARDECL(celt_sig_t, out);
1098    const int C = CHANNELS(st->mode);
1099    const int N = st->block_size;
1100    SAVE_STACK;
1101    ALLOC(out, C*N, celt_sig_t);
1102
1103    ret=celt_decode_float(st, data, len, out);
1104
1105    for (j=0;j<C*N;j++)
1106      pcm[j] = FLOAT2INT16 (out[j]);
1107
1108    RESTORE_STACK;
1109    return ret;
1110 }
1111 #endif