Add support for intra-coding of the coarse energy.
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2    (C) 2008 Gregory Maxwell */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_bands.h"
48 #include "psy.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54
55 static const celt_word16_t preemph = QCONST16(0.8f,15);
56
57 #ifdef FIXED_POINT
58 static const celt_word16_t transientWindow[16] = {
59      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
60    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
61 #else
62 static const float transientWindow[16] = {
63    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
64    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
65 #endif
66
67    
68 /** Encoder state 
69  @brief Encoder state
70  */
71 struct CELTEncoder {
72    const CELTMode *mode;     /**< Mode used by the encoder */
73    int frame_size;
74    int block_size;
75    int overlap;
76    int channels;
77    
78    int pitch_enabled;
79    int pitch_available;
80
81    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
82    celt_sig_t    * restrict preemph_memD;
83
84    celt_sig_t *in_mem;
85    celt_sig_t *out_mem;
86
87    celt_word16_t *oldBandE;
88 #ifdef EXP_PSY
89    celt_word16_t *psy_mem;
90    struct PsyDecay psy;
91 #endif
92 };
93
94 CELTEncoder *celt_encoder_create(const CELTMode *mode)
95 {
96    int N, C;
97    CELTEncoder *st;
98
99    if (check_mode(mode) != CELT_OK)
100       return NULL;
101
102    N = mode->mdctSize;
103    C = mode->nbChannels;
104    st = celt_alloc(sizeof(CELTEncoder));
105    
106    st->mode = mode;
107    st->frame_size = N;
108    st->block_size = N;
109    st->overlap = mode->overlap;
110
111    st->pitch_enabled = 1;
112    st->pitch_available = 1;
113
114    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
115    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
116
117    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
118
119    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
120    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
121
122 #ifdef EXP_PSY
123    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
124    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
125 #endif
126
127    return st;
128 }
129
130 void celt_encoder_destroy(CELTEncoder *st)
131 {
132    if (st == NULL)
133    {
134       celt_warning("NULL passed to celt_encoder_destroy");
135       return;
136    }
137    if (check_mode(st->mode) != CELT_OK)
138       return;
139
140    celt_free(st->in_mem);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148 #ifdef EXP_PSY
149    celt_free (st->psy_mem);
150    psydecay_clear(&st->psy);
151 #endif
152    
153    celt_free(st);
154 }
155
156 static inline celt_int16_t FLOAT2INT16(float x)
157 {
158    x = x*32768.;
159    x = MAX32(x, -32768);
160    x = MIN32(x, 32767);
161    return (celt_int16_t)float2int(x);
162 }
163
164 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
165 {
166 #ifdef FIXED_POINT
167    x = PSHR32(x, SIG_SHIFT);
168    x = MAX32(x, -32768);
169    x = MIN32(x, 32767);
170    return EXTRACT16(x);
171 #else
172    return (celt_word16_t)x;
173 #endif
174 }
175
176 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
177 {
178    int c, i, n;
179    celt_word32_t ratio;
180    /* FIXME: Remove the floats here */
181    VARDECL(celt_word32_t, begin);
182    SAVE_STACK;
183    ALLOC(begin, len, celt_word32_t);
184    for (i=0;i<len;i++)
185       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
186    for (c=1;c<C;c++)
187    {
188       for (i=0;i<len;i++)
189          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
190    }
191    for (i=1;i<len;i++)
192       begin[i] = MAX32(begin[i-1],begin[i]);
193    n = -1;
194    for (i=8;i<len-8;i++)
195    {
196       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
197          n=i;
198    }
199    if (n<32)
200    {
201       n = -1;
202       ratio = 0;
203    } else {
204       ratio = DIV32(begin[len-1],1+begin[n-16]);
205    }
206    /*printf ("%d %f\n", n, ratio*ratio);*/
207    if (ratio < 0)
208       ratio = 0;
209    if (ratio > 1000)
210       ratio = 1000;
211    ratio *= ratio;
212    if (ratio < 50)
213       *transient_shift = 0;
214    else if (ratio < 256)
215       *transient_shift = 1;
216    else if (ratio < 4096)
217       *transient_shift = 2;
218    else
219       *transient_shift = 3;
220    *transient_time = n;
221    
222    RESTORE_STACK;
223    return ratio > 20;
224 }
225
226 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
227 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
228 {
229    const int C = CHANNELS(mode);
230    if (C==1 && !shortBlocks)
231    {
232       const mdct_lookup *lookup = MDCT(mode);
233       const int overlap = OVERLAP(mode);
234       mdct_forward(lookup, in, out, mode->window, overlap);
235    } else if (!shortBlocks) {
236       const mdct_lookup *lookup = MDCT(mode);
237       const int overlap = OVERLAP(mode);
238       const int N = FRAMESIZE(mode);
239       int c;
240       VARDECL(celt_word32_t, x);
241       VARDECL(celt_word32_t, tmp);
242       SAVE_STACK;
243       ALLOC(x, N+overlap, celt_word32_t);
244       ALLOC(tmp, N, celt_word32_t);
245       for (c=0;c<C;c++)
246       {
247          int j;
248          for (j=0;j<N+overlap;j++)
249             x[j] = in[C*j+c];
250          mdct_forward(lookup, x, tmp, mode->window, overlap);
251          /* Interleaving the sub-frames */
252          for (j=0;j<N;j++)
253             out[C*j+c] = tmp[j];
254       }
255       RESTORE_STACK;
256    } else {
257       const mdct_lookup *lookup = &mode->shortMdct;
258       const int overlap = mode->overlap;
259       const int N = mode->shortMdctSize;
260       int b, c;
261       VARDECL(celt_word32_t, x);
262       VARDECL(celt_word32_t, tmp);
263       SAVE_STACK;
264       ALLOC(x, N+overlap, celt_word32_t);
265       ALLOC(tmp, N, celt_word32_t);
266       for (c=0;c<C;c++)
267       {
268          int B = mode->nbShortMdcts;
269          for (b=0;b<B;b++)
270          {
271             int j;
272             for (j=0;j<N+overlap;j++)
273                x[j] = in[C*(b*N+j)+c];
274             mdct_forward(lookup, x, tmp, mode->window, overlap);
275             /* Interleaving the sub-frames */
276             for (j=0;j<N;j++)
277                out[C*(j*B+b)+c] = tmp[j];
278          }
279       }
280       RESTORE_STACK;
281    }
282 }
283
284 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
285 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
286 {
287    int c, N4;
288    const int C = CHANNELS(mode);
289    const int N = FRAMESIZE(mode);
290    const int overlap = OVERLAP(mode);
291    N4 = (N-overlap)>>1;
292    for (c=0;c<C;c++)
293    {
294       int j;
295       if (transient_shift==0 && C==1 && !shortBlocks) {
296          const mdct_lookup *lookup = MDCT(mode);
297          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
298       } else if (!shortBlocks) {
299          const mdct_lookup *lookup = MDCT(mode);
300          VARDECL(celt_word32_t, x);
301          VARDECL(celt_word32_t, tmp);
302          SAVE_STACK;
303          ALLOC(x, 2*N, celt_word32_t);
304          ALLOC(tmp, N, celt_word32_t);
305          /* De-interleaving the sub-frames */
306          for (j=0;j<N;j++)
307             tmp[j] = X[C*j+c];
308          /* Prevents problems from the imdct doing the overlap-add */
309          CELT_MEMSET(x+N4, 0, N);
310          mdct_backward(lookup, tmp, x, mode->window, overlap);
311          celt_assert(transient_shift == 0);
312          /* The first and last part would need to be set to zero if we actually
313             wanted to use them. */
314          for (j=0;j<overlap;j++)
315             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
316          for (j=0;j<overlap;j++)
317             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
318          for (j=0;j<2*N4;j++)
319             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
320          RESTORE_STACK;
321       } else {
322          int b;
323          const int N2 = mode->shortMdctSize;
324          const int B = mode->nbShortMdcts;
325          const mdct_lookup *lookup = &mode->shortMdct;
326          VARDECL(celt_word32_t, x);
327          VARDECL(celt_word32_t, tmp);
328          SAVE_STACK;
329          ALLOC(x, 2*N, celt_word32_t);
330          ALLOC(tmp, N, celt_word32_t);
331          /* Prevents problems from the imdct doing the overlap-add */
332          CELT_MEMSET(x+N4, 0, N2);
333          for (b=0;b<B;b++)
334          {
335             /* De-interleaving the sub-frames */
336             for (j=0;j<N2;j++)
337                tmp[j] = X[C*(j*B+b)+c];
338             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
339          }
340          if (transient_shift > 0)
341          {
342 #ifdef FIXED_POINT
343             for (j=0;j<16;j++)
344                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
345             for (j=transient_time;j<N+overlap;j++)
346                x[N4+j] = SHL32(x[N4+j], transient_shift);
347 #else
348             for (j=0;j<16;j++)
349                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
350             for (j=transient_time;j<N+overlap;j++)
351                x[N4+j] *= 1<<transient_shift;
352 #endif
353          }
354          /* The first and last part would need to be set to zero if we actually
355          wanted to use them. */
356          for (j=0;j<overlap;j++)
357             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
358          for (j=0;j<overlap;j++)
359             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
360          for (j=0;j<2*N4;j++)
361             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
362          RESTORE_STACK;
363       }
364    }
365 }
366
367 #define FLAG_NONE        0
368 #define FLAG_INTRA       1U<<16
369 #define FLAG_PITCH       1U<<15
370 #define FLAG_SHORT       1U<<14
371 #define FLAG_FOLD        1U<<13
372 #define FLAG_MASK        (FLAG_INTRA|FLAG_PITCH|FLAG_SHORT|FLAG_FOLD)
373
374 celt_int32_t flaglist[8] = {
375       0 /*00  */ | FLAG_FOLD,
376       1 /*01  */ | FLAG_PITCH|FLAG_FOLD,
377       8 /*1000*/ | FLAG_NONE,
378       9 /*1001*/ | FLAG_SHORT|FLAG_FOLD,
379      10 /*1010*/ | FLAG_PITCH,
380      11 /*1011*/ | FLAG_INTRA,
381       6 /*110 */ | FLAG_INTRA|FLAG_FOLD,
382       7 /*111 */ | FLAG_INTRA|FLAG_SHORT|FLAG_FOLD
383 };
384
385 void encode_flags(ec_enc *enc, int intra_ener, int has_pitch, int shortBlocks, int has_fold)
386 {
387    int i;
388    int flags=FLAG_NONE;
389    int flag_bits;
390    flags |= intra_ener   ? FLAG_INTRA : 0;
391    flags |= has_pitch    ? FLAG_PITCH : 0;
392    flags |= shortBlocks  ? FLAG_SHORT : 0;
393    flags |= has_fold     ? FLAG_FOLD  : 0;
394    for (i=0;i<8;i++)
395       if (flags == (flaglist[i]&FLAG_MASK))
396          break;
397    celt_assert(i<8);
398    flag_bits = flaglist[i]&0xf;
399    /*printf ("enc %d: %d %d %d %d\n", flag_bits, intra_ener, has_pitch, shortBlocks, has_fold);*/
400    if (i<2)
401       ec_enc_bits(enc, flag_bits, 2);
402    else if (i<6)
403       ec_enc_bits(enc, flag_bits, 4);
404    else
405       ec_enc_bits(enc, flag_bits, 3);
406 }
407
408 void decode_flags(ec_dec *dec, int *intra_ener, int *has_pitch, int *shortBlocks, int *has_fold)
409 {
410    int i;
411    int flag_bits;
412    flag_bits = ec_dec_bits(dec, 2);
413    /*printf ("(%d) ", flag_bits);*/
414    if (flag_bits==2)
415       flag_bits = (flag_bits<<2) | ec_dec_bits(dec, 2);
416    else if (flag_bits==3)
417       flag_bits = (flag_bits<<1) | ec_dec_bits(dec, 1);
418    for (i=0;i<8;i++)
419       if (flag_bits == (flaglist[i]&0xf))
420          break;
421    celt_assert(i<8);
422    *intra_ener  = (flaglist[i]&FLAG_INTRA) != 0;
423    *has_pitch   = (flaglist[i]&FLAG_PITCH) != 0;
424    *shortBlocks = (flaglist[i]&FLAG_SHORT) != 0;
425    *has_fold    = (flaglist[i]&FLAG_FOLD ) != 0;
426    /*printf ("dec %d: %d %d %d %d\n", flag_bits, *intra_ener, *has_pitch, *shortBlocks, *has_fold);*/
427 }
428
429 #ifdef FIXED_POINT
430 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
431 {
432 #else
433 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
434 {
435 #endif
436    int i, c, N, N4;
437    int has_pitch;
438    int pitch_index;
439    int bits;
440    int has_fold=1;
441    ec_byte_buffer buf;
442    ec_enc         enc;
443    VARDECL(celt_sig_t, in);
444    VARDECL(celt_sig_t, freq);
445    VARDECL(celt_norm_t, X);
446    VARDECL(celt_norm_t, P);
447    VARDECL(celt_ener_t, bandE);
448    VARDECL(celt_pgain_t, gains);
449    VARDECL(int, stereo_mode);
450    VARDECL(int, fine_quant);
451    VARDECL(celt_word16_t, error);
452    VARDECL(int, pulses);
453    VARDECL(int, offsets);
454 #ifdef EXP_PSY
455    VARDECL(celt_word32_t, mask);
456    VARDECL(celt_word32_t, tonality);
457    VARDECL(celt_word32_t, bandM);
458    VARDECL(celt_ener_t, bandN);
459 #endif
460    int intra_ener = 0;
461    int shortBlocks=0;
462    int transient_time;
463    int transient_shift;
464    const int C = CHANNELS(st->mode);
465    SAVE_STACK;
466
467    if (check_mode(st->mode) != CELT_OK)
468       return CELT_INVALID_MODE;
469
470    if (nbCompressedBytes<0)
471      return CELT_BAD_ARG; 
472
473    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
474    CELT_MEMSET(compressed, 0, nbCompressedBytes);
475    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
476    ec_enc_init(&enc,&buf);
477
478    N = st->block_size;
479    N4 = (N-st->overlap)>>1;
480    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
481
482    CELT_COPY(in, st->in_mem, C*st->overlap);
483    for (c=0;c<C;c++)
484    {
485       const celt_word16_t * restrict pcmp = pcm+c;
486       celt_sig_t * restrict inp = in+C*st->overlap+c;
487       for (i=0;i<N;i++)
488       {
489          /* Apply pre-emphasis */
490          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
491          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
492          st->preemph_memE[c] = SCALEIN(*pcmp);
493          inp += C;
494          pcmp += C;
495       }
496    }
497    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
498    
499    /* Transient handling */
500    if (st->mode->nbShortMdcts > 1)
501    {
502       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
503       {
504 #ifndef FIXED_POINT
505          float gain_1;
506 #endif
507          /* Apply the inverse shaping window */
508          if (transient_shift)
509          {
510 #ifdef FIXED_POINT
511             for (c=0;c<C;c++)
512                for (i=0;i<16;i++)
513                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
514             for (c=0;c<C;c++)
515                for (i=transient_time;i<N+st->overlap;i++)
516                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
517 #else
518             for (c=0;c<C;c++)
519                for (i=0;i<16;i++)
520                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
521             gain_1 = 1./(1<<transient_shift);
522             for (c=0;c<C;c++)
523                for (i=transient_time;i<N+st->overlap;i++)
524                   in[C*i+c] *= gain_1;
525 #endif
526          }
527          shortBlocks = 1;
528       } else {
529          transient_time = -1;
530          transient_shift = 0;
531          shortBlocks = 0;
532       }
533    } else {
534       transient_time = -1;
535       transient_shift = 0;
536       shortBlocks = 0;
537    }
538
539    /* Pitch analysis: we do it early to save on the peak stack space */
540    /* Don't use pitch if there isn't enough data available yet, or if we're using shortBlocks */
541    has_pitch = st->pitch_enabled && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks) && !intra_ener;
542 #ifdef EXP_PSY
543    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
544    {
545       VARDECL(celt_word16_t, X);
546       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
547       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
548       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
549    }
550 #else
551    if (has_pitch)
552    {
553       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
554    }
555 #endif
556    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
557    
558    /* Compute MDCTs */
559    compute_mdcts(st->mode, shortBlocks, in, freq);
560
561 #ifdef EXP_PSY
562    ALLOC(mask, N, celt_sig_t);
563    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
564    /*for (i=0;i<256;i++)
565       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
566    printf ("\n");*/
567 #endif
568
569    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
570    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
571    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
572    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
573    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
574
575
576    /* Band normalisation */
577    compute_band_energies(st->mode, freq, bandE);
578    normalise_bands(st->mode, freq, X, bandE);
579
580 #ifdef EXP_PSY
581    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
582    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
583    compute_noise_energies(st->mode, freq, tonality, bandN);
584
585    /*for (i=0;i<st->mode->nbEBands;i++)
586       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
587    printf ("\n");*/
588    has_fold = 0;
589    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
590       if (bandN[i] < .4*bandE[i])
591          has_fold++;
592    /*printf ("%d\n", has_fold);*/
593    if (has_fold>=2)
594       has_fold = 0;
595    else
596       has_fold = 1;
597    for (i=0;i<N;i++)
598       mask[i] = sqrt(mask[i]);
599    compute_band_energies(st->mode, mask, bandM);
600    /*for (i=0;i<st->mode->nbEBands;i++)
601       printf ("%f %f ", bandE[i], bandM[i]);
602    printf ("\n");*/
603 #endif
604
605    /* Compute MDCTs of the pitch part */
606    if (has_pitch)
607    {
608       celt_word32_t curr_power, pitch_power=0;
609       /* Normalise the pitch vector as well (discard the energies) */
610       VARDECL(celt_ener_t, bandEp);
611       
612       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
613       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
614       compute_band_energies(st->mode, freq, bandEp);
615       normalise_bands(st->mode, freq, P, bandEp);
616       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
617       /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
618       curr_power = bandE[0]+bandE[1]+bandE[2];
619       if ((MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
620       {
621          /* Pitch prediction */
622          has_pitch = compute_pitch_gain(st->mode, X, P, gains);
623       } else {
624          has_pitch = 0;
625       }
626    }
627    
628    encode_flags(&enc, intra_ener, has_pitch, shortBlocks, has_fold);
629    if (has_pitch)
630    {
631       ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
632    } else {
633       for (i=0;i<st->mode->nbPBands;i++)
634          gains[i] = 0;
635       for (i=0;i<C*N;i++)
636          P[i] = 0;
637    }
638    if (shortBlocks)
639    {
640       ec_enc_bits(&enc, transient_shift, 2);
641       if (transient_shift)
642          ec_enc_uint(&enc, transient_time, N+st->overlap);
643    }
644
645 #ifdef STDIN_TUNING2
646    static int fine_quant[30];
647    static int pulses[30];
648    static int init=0;
649    if (!init)
650    {
651       for (i=0;i<st->mode->nbEBands;i++)
652          scanf("%d ", &fine_quant[i]);
653       for (i=0;i<st->mode->nbEBands;i++)
654          scanf("%d ", &pulses[i]);
655       init = 1;
656    }
657 #else
658    ALLOC(fine_quant, st->mode->nbEBands, int);
659    ALLOC(pulses, st->mode->nbEBands, int);
660 #endif
661
662    /* Bit allocation */
663    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
664    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, intra_ener, st->mode->prob, error, &enc);
665    
666    ALLOC(offsets, st->mode->nbEBands, int);
667    ALLOC(stereo_mode, st->mode->nbEBands, int);
668    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
669
670    for (i=0;i<st->mode->nbEBands;i++)
671       offsets[i] = 0;
672    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
673    if (has_pitch)
674       bits -= st->mode->nbPBands;
675 #ifndef STDIN_TUNING
676    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
677 #endif
678
679    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
680
681    /* Residual quantisation */
682    if (C==1)
683       quant_bands(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
684    else
685       quant_bands_stereo(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
686
687    /* Re-synthesis of the coded audio if required */
688    if (st->pitch_available>0 || optional_synthesis!=NULL)
689    {
690       if (st->pitch_available>0 && st->pitch_available<MAX_PERIOD)
691         st->pitch_available+=st->frame_size;
692
693       /* Synthesis */
694       denormalise_bands(st->mode, X, freq, bandE);
695       
696       
697       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
698       
699       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
700       /* De-emphasis and put everything back at the right place in the synthesis history */
701       if (optional_synthesis != NULL) {
702          for (c=0;c<C;c++)
703          {
704             int j;
705             for (j=0;j<N;j++)
706             {
707                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
708                                    preemph,st->preemph_memD[c]);
709                st->preemph_memD[c] = tmp;
710                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
711             }
712          }
713       }
714    }
715
716    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
717    /*if (ec_enc_tell(&enc, 0) < nbCompressedBytes*8 - 7)
718       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
719
720    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
721    {
722       int val = 0;
723       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
724       {
725          ec_enc_uint(&enc, val, 2);
726          val = 1-val;
727       }
728    }
729    ec_enc_done(&enc);
730    {
731       /*unsigned char *data;*/
732       int nbBytes = ec_byte_bytes(&buf);
733       if (nbBytes > nbCompressedBytes)
734       {
735          celt_warning_int ("got too many bytes:", nbBytes);
736          RESTORE_STACK;
737          return CELT_INTERNAL_ERROR;
738       }
739    }
740
741    RESTORE_STACK;
742    return nbCompressedBytes;
743 }
744
745 #ifdef FIXED_POINT
746 #ifndef DISABLE_FLOAT_API
747 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
748 {
749    int j, ret;
750    const int C = CHANNELS(st->mode);
751    const int N = st->block_size;
752    VARDECL(celt_int16_t, in);
753    SAVE_STACK;
754    ALLOC(in, C*N, celt_int16_t);
755
756    for (j=0;j<C*N;j++)
757      in[j] = FLOAT2INT16(pcm[j]);
758
759    if (optional_synthesis != NULL) {
760      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
761       for (j=0;j<C*N;j++)
762          optional_synthesis[j]=in[j]*(1/32768.);
763    } else {
764      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
765    }
766    RESTORE_STACK;
767    return ret;
768
769 }
770 #endif /*DISABLE_FLOAT_API*/
771 #else
772 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
773 {
774    int j, ret;
775    VARDECL(celt_sig_t, in);
776    const int C = CHANNELS(st->mode);
777    const int N = st->block_size;
778    SAVE_STACK;
779    ALLOC(in, C*N, celt_sig_t);
780    for (j=0;j<C*N;j++) {
781      in[j] = SCALEOUT(pcm[j]);
782    }
783
784    if (optional_synthesis != NULL) {
785       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
786       for (j=0;j<C*N;j++)
787          optional_synthesis[j] = FLOAT2INT16(in[j]);
788    } else {
789       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
790    }
791    RESTORE_STACK;
792    return ret;
793 }
794 #endif
795
796 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
797 {
798    va_list ap;
799    va_start(ap, request);
800    switch (request)
801    {
802       case CELT_SET_COMPLEXITY_REQUEST:
803       {
804          int value = va_arg(ap, int);
805          if (value<0 || value>10)
806             goto bad_arg;
807          if (value<=2) {
808             st->pitch_enabled = 0; 
809             st->pitch_available = 0;
810          } else {
811               st->pitch_enabled = 1;
812               if (st->pitch_available<1)
813                 st->pitch_available = 1;
814          }   
815       }
816       break;
817       case CELT_SET_LTP_REQUEST:
818       {
819          int value = va_arg(ap, int);
820          if (value<0 || value>1 || (value==1 && st->pitch_available==0))
821             goto bad_arg;
822          if (value==0)
823             st->pitch_enabled = 0;
824          else
825             st->pitch_enabled = 1;
826       }
827       break;
828       default:
829          goto bad_request;
830    }
831    va_end(ap);
832    return CELT_OK;
833 bad_arg:
834    va_end(ap);
835    return CELT_BAD_ARG;
836 bad_request:
837    va_end(ap);
838    return CELT_UNIMPLEMENTED;
839 }
840
841 /****************************************************************************/
842 /*                                                                          */
843 /*                                DECODER                                   */
844 /*                                                                          */
845 /****************************************************************************/
846 #ifdef NEW_PLC
847 #define DECODE_BUFFER_SIZE 2048
848 #else
849 #define DECODE_BUFFER_SIZE MAX_PERIOD
850 #endif
851
852 /** Decoder state 
853  @brief Decoder state
854  */
855 struct CELTDecoder {
856    const CELTMode *mode;
857    int frame_size;
858    int block_size;
859    int overlap;
860
861    ec_byte_buffer buf;
862    ec_enc         enc;
863
864    celt_sig_t * restrict preemph_memD;
865
866    celt_sig_t *out_mem;
867    celt_sig_t *decode_mem;
868
869    celt_word16_t *oldBandE;
870    
871    int last_pitch_index;
872 };
873
874 CELTDecoder *celt_decoder_create(const CELTMode *mode)
875 {
876    int N, C;
877    CELTDecoder *st;
878
879    if (check_mode(mode) != CELT_OK)
880       return NULL;
881
882    N = mode->mdctSize;
883    C = CHANNELS(mode);
884    st = celt_alloc(sizeof(CELTDecoder));
885    
886    st->mode = mode;
887    st->frame_size = N;
888    st->block_size = N;
889    st->overlap = mode->overlap;
890
891    st->decode_mem = celt_alloc((DECODE_BUFFER_SIZE+st->overlap)*C*sizeof(celt_sig_t));
892    st->out_mem = st->decode_mem+DECODE_BUFFER_SIZE-MAX_PERIOD;
893    
894    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
895
896    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
897
898    st->last_pitch_index = 0;
899    return st;
900 }
901
902 void celt_decoder_destroy(CELTDecoder *st)
903 {
904    if (st == NULL)
905    {
906       celt_warning("NULL passed to celt_encoder_destroy");
907       return;
908    }
909    if (check_mode(st->mode) != CELT_OK)
910       return;
911
912
913    celt_free(st->decode_mem);
914    
915    celt_free(st->oldBandE);
916    
917    celt_free(st->preemph_memD);
918
919    celt_free(st);
920 }
921
922 /** Handles lost packets by just copying past data with the same offset as the last
923     pitch period */
924 #ifdef NEW_PLC
925 #include "plc.c"
926 #else
927 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
928 {
929    int c, N;
930    int pitch_index;
931    int i, len;
932    VARDECL(celt_sig_t, freq);
933    const int C = CHANNELS(st->mode);
934    int offset;
935    SAVE_STACK;
936    N = st->block_size;
937    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
938    
939    len = N+st->mode->overlap;
940 #if 0
941    pitch_index = st->last_pitch_index;
942    
943    /* Use the pitch MDCT as the "guessed" signal */
944    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
945
946 #else
947    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
948    pitch_index = MAX_PERIOD-len-pitch_index;
949    offset = MAX_PERIOD-pitch_index;
950    while (offset+len >= MAX_PERIOD)
951       offset -= pitch_index;
952    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
953    for (i=0;i<N;i++)
954       freq[i] = ADD32(EPSILON, MULT16_32_Q15(QCONST16(.9f,15),freq[i]));
955 #endif
956    
957    
958    
959    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
960    /* Compute inverse MDCTs */
961    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
962
963    for (c=0;c<C;c++)
964    {
965       int j;
966       for (j=0;j<N;j++)
967       {
968          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
969                                 preemph,st->preemph_memD[c]);
970          st->preemph_memD[c] = tmp;
971          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
972       }
973    }
974    RESTORE_STACK;
975 }
976 #endif
977
978 #ifdef FIXED_POINT
979 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
980 {
981 #else
982 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig_t * restrict pcm)
983 {
984 #endif
985    int i, c, N, N4;
986    int has_pitch, has_fold;
987    int pitch_index;
988    int bits;
989    ec_dec dec;
990    ec_byte_buffer buf;
991    VARDECL(celt_sig_t, freq);
992    VARDECL(celt_norm_t, X);
993    VARDECL(celt_norm_t, P);
994    VARDECL(celt_ener_t, bandE);
995    VARDECL(celt_pgain_t, gains);
996    VARDECL(int, stereo_mode);
997    VARDECL(int, fine_quant);
998    VARDECL(int, pulses);
999    VARDECL(int, offsets);
1000
1001    int shortBlocks;
1002    int intra_ener;
1003    int transient_time;
1004    int transient_shift;
1005    const int C = CHANNELS(st->mode);
1006    SAVE_STACK;
1007
1008    if (check_mode(st->mode) != CELT_OK)
1009       return CELT_INVALID_MODE;
1010
1011    N = st->block_size;
1012    N4 = (N-st->overlap)>>1;
1013
1014    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
1015    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
1016    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
1017    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
1018    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
1019    
1020    if (check_mode(st->mode) != CELT_OK)
1021    {
1022       RESTORE_STACK;
1023       return CELT_INVALID_MODE;
1024    }
1025    if (data == NULL)
1026    {
1027       celt_decode_lost(st, pcm);
1028       RESTORE_STACK;
1029       return 0;
1030    }
1031    if (len<0) {
1032      RESTORE_STACK;
1033      return CELT_BAD_ARG;
1034    }
1035    
1036    ec_byte_readinit(&buf,(unsigned char*)data,len);
1037    ec_dec_init(&dec,&buf);
1038    
1039    decode_flags(&dec, &intra_ener, &has_pitch, &shortBlocks, &has_fold);
1040    if (shortBlocks)
1041    {
1042       transient_shift = ec_dec_bits(&dec, 2);
1043       if (transient_shift)
1044          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
1045       else
1046          transient_time = 0;
1047    } else {
1048       transient_time = -1;
1049       transient_shift = 0;
1050    }
1051    
1052    if (has_pitch)
1053    {
1054       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
1055       st->last_pitch_index = pitch_index;
1056    } else {
1057       pitch_index = 0;
1058       for (i=0;i<st->mode->nbPBands;i++)
1059          gains[i] = 0;
1060    }
1061
1062    ALLOC(fine_quant, st->mode->nbEBands, int);
1063    /* Get band energies */
1064    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, intra_ener, st->mode->prob, &dec);
1065    
1066    ALLOC(pulses, st->mode->nbEBands, int);
1067    ALLOC(offsets, st->mode->nbEBands, int);
1068    ALLOC(stereo_mode, st->mode->nbEBands, int);
1069    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
1070
1071    for (i=0;i<st->mode->nbEBands;i++)
1072       offsets[i] = 0;
1073
1074    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1075    if (has_pitch)
1076       bits -= st->mode->nbPBands;
1077    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1078    /*bits = ec_dec_tell(&dec, 0);
1079    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1080    
1081    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1082
1083
1084    if (has_pitch) 
1085    {
1086       VARDECL(celt_ener_t, bandEp);
1087       
1088       /* Pitch MDCT */
1089       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1090       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1091       compute_band_energies(st->mode, freq, bandEp);
1092       normalise_bands(st->mode, freq, P, bandEp);
1093       /* Apply pitch gains */
1094    } else {
1095       for (i=0;i<C*N;i++)
1096          P[i] = 0;
1097    }
1098
1099    /* Decode fixed codebook and merge with pitch */
1100    if (C==1)
1101       unquant_bands(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1102    else
1103       unquant_bands_stereo(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1104
1105    /* Synthesis */
1106    denormalise_bands(st->mode, X, freq, bandE);
1107
1108
1109    CELT_MOVE(st->decode_mem, st->decode_mem+C*N, C*(DECODE_BUFFER_SIZE+st->overlap-N));
1110    /* Compute inverse MDCTs */
1111    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1112
1113    for (c=0;c<C;c++)
1114    {
1115       int j;
1116       for (j=0;j<N;j++)
1117       {
1118          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1119                                 preemph,st->preemph_memD[c]);
1120          st->preemph_memD[c] = tmp;
1121          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1122       }
1123    }
1124
1125    {
1126       unsigned int val = 0;
1127       while (ec_dec_tell(&dec, 0) < len*8)
1128       {
1129          if (ec_dec_uint(&dec, 2) != val)
1130          {
1131             celt_warning("decode error");
1132             RESTORE_STACK;
1133             return CELT_CORRUPTED_DATA;
1134          }
1135          val = 1-val;
1136       }
1137    }
1138
1139    RESTORE_STACK;
1140    return 0;
1141    /*printf ("\n");*/
1142 }
1143
1144 #ifdef FIXED_POINT
1145 #ifndef DISABLE_FLOAT_API
1146 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm)
1147 {
1148    int j, ret;
1149    const int C = CHANNELS(st->mode);
1150    const int N = st->block_size;
1151    VARDECL(celt_int16_t, out);
1152    SAVE_STACK;
1153    ALLOC(out, C*N, celt_int16_t);
1154
1155    ret=celt_decode(st, data, len, out);
1156
1157    for (j=0;j<C*N;j++)
1158      pcm[j]=out[j]*(1/32768.);
1159    RESTORE_STACK;
1160    return ret;
1161 }
1162 #endif /*DISABLE_FLOAT_API*/
1163 #else
1164 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
1165 {
1166    int j, ret;
1167    VARDECL(celt_sig_t, out);
1168    const int C = CHANNELS(st->mode);
1169    const int N = st->block_size;
1170    SAVE_STACK;
1171    ALLOC(out, C*N, celt_sig_t);
1172
1173    ret=celt_decode_float(st, data, len, out);
1174
1175    for (j=0;j<C*N;j++)
1176      pcm[j] = FLOAT2INT16 (out[j]);
1177
1178    RESTORE_STACK;
1179    return ret;
1180 }
1181 #endif