Fixing stereo: Do not attempt to use more bits than are available.
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2    (C) 2008 Gregory Maxwell */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_bands.h"
48 #include "psy.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54
55 static const celt_word16_t preemph = QCONST16(0.8f,15);
56
57 #ifdef FIXED_POINT
58 static const celt_word16_t transientWindow[16] = {
59      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
60    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
61 #else
62 static const float transientWindow[16] = {
63    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
64    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
65 #endif
66
67    
68 /** Encoder state 
69  @brief Encoder state
70  */
71 struct CELTEncoder {
72    const CELTMode *mode;     /**< Mode used by the encoder */
73    int frame_size;
74    int block_size;
75    int overlap;
76    int channels;
77    
78    int pitch_enabled;
79    int pitch_available;
80    int delayedIntra;
81    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
82    celt_sig_t    * restrict preemph_memD;
83
84    celt_sig_t *in_mem;
85    celt_sig_t *out_mem;
86
87    celt_word16_t *oldBandE;
88 #ifdef EXP_PSY
89    celt_word16_t *psy_mem;
90    struct PsyDecay psy;
91 #endif
92 };
93
94 CELTEncoder *celt_encoder_create(const CELTMode *mode)
95 {
96    int N, C;
97    CELTEncoder *st;
98
99    if (check_mode(mode) != CELT_OK)
100       return NULL;
101
102    N = mode->mdctSize;
103    C = mode->nbChannels;
104    st = celt_alloc(sizeof(CELTEncoder));
105    
106    st->mode = mode;
107    st->frame_size = N;
108    st->block_size = N;
109    st->overlap = mode->overlap;
110
111    st->pitch_enabled = 1;
112    st->pitch_available = 1;
113    st->delayedIntra = 1;
114
115    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
116    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
117
118    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
119
120    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
121    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
122
123 #ifdef EXP_PSY
124    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
125    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
126 #endif
127
128    return st;
129 }
130
131 void celt_encoder_destroy(CELTEncoder *st)
132 {
133    if (st == NULL)
134    {
135       celt_warning("NULL passed to celt_encoder_destroy");
136       return;
137    }
138    if (check_mode(st->mode) != CELT_OK)
139       return;
140
141    celt_free(st->in_mem);
142    celt_free(st->out_mem);
143    
144    celt_free(st->oldBandE);
145    
146    celt_free(st->preemph_memE);
147    celt_free(st->preemph_memD);
148    
149 #ifdef EXP_PSY
150    celt_free (st->psy_mem);
151    psydecay_clear(&st->psy);
152 #endif
153    
154    celt_free(st);
155 }
156
157 static inline celt_int16_t FLOAT2INT16(float x)
158 {
159    x = x*32768.;
160    x = MAX32(x, -32768);
161    x = MIN32(x, 32767);
162    return (celt_int16_t)float2int(x);
163 }
164
165 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
166 {
167 #ifdef FIXED_POINT
168    x = PSHR32(x, SIG_SHIFT);
169    x = MAX32(x, -32768);
170    x = MIN32(x, 32767);
171    return EXTRACT16(x);
172 #else
173    return (celt_word16_t)x;
174 #endif
175 }
176
177 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
178 {
179    int c, i, n;
180    celt_word32_t ratio;
181    /* FIXME: Remove the floats here */
182    VARDECL(celt_word32_t, begin);
183    SAVE_STACK;
184    ALLOC(begin, len, celt_word32_t);
185    for (i=0;i<len;i++)
186       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
187    for (c=1;c<C;c++)
188    {
189       for (i=0;i<len;i++)
190          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
191    }
192    for (i=1;i<len;i++)
193       begin[i] = MAX32(begin[i-1],begin[i]);
194    n = -1;
195    for (i=8;i<len-8;i++)
196    {
197       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
198          n=i;
199    }
200    if (n<32)
201    {
202       n = -1;
203       ratio = 0;
204    } else {
205       ratio = DIV32(begin[len-1],1+begin[n-16]);
206    }
207    /*printf ("%d %f\n", n, ratio*ratio);*/
208    if (ratio < 0)
209       ratio = 0;
210    if (ratio > 1000)
211       ratio = 1000;
212    ratio *= ratio;
213    if (ratio < 50)
214       *transient_shift = 0;
215    else if (ratio < 256)
216       *transient_shift = 1;
217    else if (ratio < 4096)
218       *transient_shift = 2;
219    else
220       *transient_shift = 3;
221    *transient_time = n;
222    
223    RESTORE_STACK;
224    return ratio > 20;
225 }
226
227 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
228 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
229 {
230    const int C = CHANNELS(mode);
231    if (C==1 && !shortBlocks)
232    {
233       const mdct_lookup *lookup = MDCT(mode);
234       const int overlap = OVERLAP(mode);
235       mdct_forward(lookup, in, out, mode->window, overlap);
236    } else if (!shortBlocks) {
237       const mdct_lookup *lookup = MDCT(mode);
238       const int overlap = OVERLAP(mode);
239       const int N = FRAMESIZE(mode);
240       int c;
241       VARDECL(celt_word32_t, x);
242       VARDECL(celt_word32_t, tmp);
243       SAVE_STACK;
244       ALLOC(x, N+overlap, celt_word32_t);
245       ALLOC(tmp, N, celt_word32_t);
246       for (c=0;c<C;c++)
247       {
248          int j;
249          for (j=0;j<N+overlap;j++)
250             x[j] = in[C*j+c];
251          mdct_forward(lookup, x, tmp, mode->window, overlap);
252          /* Interleaving the sub-frames */
253          for (j=0;j<N;j++)
254             out[C*j+c] = tmp[j];
255       }
256       RESTORE_STACK;
257    } else {
258       const mdct_lookup *lookup = &mode->shortMdct;
259       const int overlap = mode->overlap;
260       const int N = mode->shortMdctSize;
261       int b, c;
262       VARDECL(celt_word32_t, x);
263       VARDECL(celt_word32_t, tmp);
264       SAVE_STACK;
265       ALLOC(x, N+overlap, celt_word32_t);
266       ALLOC(tmp, N, celt_word32_t);
267       for (c=0;c<C;c++)
268       {
269          int B = mode->nbShortMdcts;
270          for (b=0;b<B;b++)
271          {
272             int j;
273             for (j=0;j<N+overlap;j++)
274                x[j] = in[C*(b*N+j)+c];
275             mdct_forward(lookup, x, tmp, mode->window, overlap);
276             /* Interleaving the sub-frames */
277             for (j=0;j<N;j++)
278                out[C*(j*B+b)+c] = tmp[j];
279          }
280       }
281       RESTORE_STACK;
282    }
283 }
284
285 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
286 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
287 {
288    int c, N4;
289    const int C = CHANNELS(mode);
290    const int N = FRAMESIZE(mode);
291    const int overlap = OVERLAP(mode);
292    N4 = (N-overlap)>>1;
293    for (c=0;c<C;c++)
294    {
295       int j;
296       if (transient_shift==0 && C==1 && !shortBlocks) {
297          const mdct_lookup *lookup = MDCT(mode);
298          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
299       } else if (!shortBlocks) {
300          const mdct_lookup *lookup = MDCT(mode);
301          VARDECL(celt_word32_t, x);
302          VARDECL(celt_word32_t, tmp);
303          SAVE_STACK;
304          ALLOC(x, 2*N, celt_word32_t);
305          ALLOC(tmp, N, celt_word32_t);
306          /* De-interleaving the sub-frames */
307          for (j=0;j<N;j++)
308             tmp[j] = X[C*j+c];
309          /* Prevents problems from the imdct doing the overlap-add */
310          CELT_MEMSET(x+N4, 0, N);
311          mdct_backward(lookup, tmp, x, mode->window, overlap);
312          celt_assert(transient_shift == 0);
313          /* The first and last part would need to be set to zero if we actually
314             wanted to use them. */
315          for (j=0;j<overlap;j++)
316             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
317          for (j=0;j<overlap;j++)
318             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
319          for (j=0;j<2*N4;j++)
320             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
321          RESTORE_STACK;
322       } else {
323          int b;
324          const int N2 = mode->shortMdctSize;
325          const int B = mode->nbShortMdcts;
326          const mdct_lookup *lookup = &mode->shortMdct;
327          VARDECL(celt_word32_t, x);
328          VARDECL(celt_word32_t, tmp);
329          SAVE_STACK;
330          ALLOC(x, 2*N, celt_word32_t);
331          ALLOC(tmp, N, celt_word32_t);
332          /* Prevents problems from the imdct doing the overlap-add */
333          CELT_MEMSET(x+N4, 0, N2);
334          for (b=0;b<B;b++)
335          {
336             /* De-interleaving the sub-frames */
337             for (j=0;j<N2;j++)
338                tmp[j] = X[C*(j*B+b)+c];
339             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
340          }
341          if (transient_shift > 0)
342          {
343 #ifdef FIXED_POINT
344             for (j=0;j<16;j++)
345                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
346             for (j=transient_time;j<N+overlap;j++)
347                x[N4+j] = SHL32(x[N4+j], transient_shift);
348 #else
349             for (j=0;j<16;j++)
350                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
351             for (j=transient_time;j<N+overlap;j++)
352                x[N4+j] *= 1<<transient_shift;
353 #endif
354          }
355          /* The first and last part would need to be set to zero if we actually
356          wanted to use them. */
357          for (j=0;j<overlap;j++)
358             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
359          for (j=0;j<overlap;j++)
360             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
361          for (j=0;j<2*N4;j++)
362             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
363          RESTORE_STACK;
364       }
365    }
366 }
367
368 #define FLAG_NONE        0
369 #define FLAG_INTRA       1U<<16
370 #define FLAG_PITCH       1U<<15
371 #define FLAG_SHORT       1U<<14
372 #define FLAG_FOLD        1U<<13
373 #define FLAG_MASK        (FLAG_INTRA|FLAG_PITCH|FLAG_SHORT|FLAG_FOLD)
374
375 celt_int32_t flaglist[8] = {
376       0 /*00  */ | FLAG_FOLD,
377       1 /*01  */ | FLAG_PITCH|FLAG_FOLD,
378       8 /*1000*/ | FLAG_NONE,
379       9 /*1001*/ | FLAG_SHORT|FLAG_FOLD,
380      10 /*1010*/ | FLAG_PITCH,
381      11 /*1011*/ | FLAG_INTRA,
382       6 /*110 */ | FLAG_INTRA|FLAG_FOLD,
383       7 /*111 */ | FLAG_INTRA|FLAG_SHORT|FLAG_FOLD
384 };
385
386 void encode_flags(ec_enc *enc, int intra_ener, int has_pitch, int shortBlocks, int has_fold)
387 {
388    int i;
389    int flags=FLAG_NONE;
390    int flag_bits;
391    flags |= intra_ener   ? FLAG_INTRA : 0;
392    flags |= has_pitch    ? FLAG_PITCH : 0;
393    flags |= shortBlocks  ? FLAG_SHORT : 0;
394    flags |= has_fold     ? FLAG_FOLD  : 0;
395    for (i=0;i<8;i++)
396       if (flags == (flaglist[i]&FLAG_MASK))
397          break;
398    celt_assert(i<8);
399    flag_bits = flaglist[i]&0xf;
400    /*printf ("enc %d: %d %d %d %d\n", flag_bits, intra_ener, has_pitch, shortBlocks, has_fold);*/
401    if (i<2)
402       ec_enc_bits(enc, flag_bits, 2);
403    else if (i<6)
404       ec_enc_bits(enc, flag_bits, 4);
405    else
406       ec_enc_bits(enc, flag_bits, 3);
407 }
408
409 void decode_flags(ec_dec *dec, int *intra_ener, int *has_pitch, int *shortBlocks, int *has_fold)
410 {
411    int i;
412    int flag_bits;
413    flag_bits = ec_dec_bits(dec, 2);
414    /*printf ("(%d) ", flag_bits);*/
415    if (flag_bits==2)
416       flag_bits = (flag_bits<<2) | ec_dec_bits(dec, 2);
417    else if (flag_bits==3)
418       flag_bits = (flag_bits<<1) | ec_dec_bits(dec, 1);
419    for (i=0;i<8;i++)
420       if (flag_bits == (flaglist[i]&0xf))
421          break;
422    celt_assert(i<8);
423    *intra_ener  = (flaglist[i]&FLAG_INTRA) != 0;
424    *has_pitch   = (flaglist[i]&FLAG_PITCH) != 0;
425    *shortBlocks = (flaglist[i]&FLAG_SHORT) != 0;
426    *has_fold    = (flaglist[i]&FLAG_FOLD ) != 0;
427    /*printf ("dec %d: %d %d %d %d\n", flag_bits, *intra_ener, *has_pitch, *shortBlocks, *has_fold);*/
428 }
429
430 #ifdef FIXED_POINT
431 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
432 {
433 #else
434 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
435 {
436 #endif
437    int i, c, N, N4;
438    int has_pitch;
439    int pitch_index;
440    int bits;
441    int has_fold=1;
442    ec_byte_buffer buf;
443    ec_enc         enc;
444    VARDECL(celt_sig_t, in);
445    VARDECL(celt_sig_t, freq);
446    VARDECL(celt_norm_t, X);
447    VARDECL(celt_norm_t, P);
448    VARDECL(celt_ener_t, bandE);
449    VARDECL(celt_pgain_t, gains);
450    VARDECL(int, stereo_mode);
451    VARDECL(int, fine_quant);
452    VARDECL(celt_word16_t, error);
453    VARDECL(int, pulses);
454    VARDECL(int, offsets);
455 #ifdef EXP_PSY
456    VARDECL(celt_word32_t, mask);
457    VARDECL(celt_word32_t, tonality);
458    VARDECL(celt_word32_t, bandM);
459    VARDECL(celt_ener_t, bandN);
460 #endif
461    int intra_ener = 0;
462    int shortBlocks=0;
463    int transient_time;
464    int transient_shift;
465    const int C = CHANNELS(st->mode);
466    SAVE_STACK;
467
468    if (check_mode(st->mode) != CELT_OK)
469       return CELT_INVALID_MODE;
470
471    if (nbCompressedBytes<0)
472      return CELT_BAD_ARG; 
473
474    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
475    CELT_MEMSET(compressed, 0, nbCompressedBytes);
476    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
477    ec_enc_init(&enc,&buf);
478
479    N = st->block_size;
480    N4 = (N-st->overlap)>>1;
481    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
482
483    CELT_COPY(in, st->in_mem, C*st->overlap);
484    for (c=0;c<C;c++)
485    {
486       const celt_word16_t * restrict pcmp = pcm+c;
487       celt_sig_t * restrict inp = in+C*st->overlap+c;
488       for (i=0;i<N;i++)
489       {
490          /* Apply pre-emphasis */
491          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
492          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
493          st->preemph_memE[c] = SCALEIN(*pcmp);
494          inp += C;
495          pcmp += C;
496       }
497    }
498    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
499    
500    /* Transient handling */
501    if (st->mode->nbShortMdcts > 1)
502    {
503       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
504       {
505 #ifndef FIXED_POINT
506          float gain_1;
507 #endif
508          /* Apply the inverse shaping window */
509          if (transient_shift)
510          {
511 #ifdef FIXED_POINT
512             for (c=0;c<C;c++)
513                for (i=0;i<16;i++)
514                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
515             for (c=0;c<C;c++)
516                for (i=transient_time;i<N+st->overlap;i++)
517                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
518 #else
519             for (c=0;c<C;c++)
520                for (i=0;i<16;i++)
521                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
522             gain_1 = 1./(1<<transient_shift);
523             for (c=0;c<C;c++)
524                for (i=transient_time;i<N+st->overlap;i++)
525                   in[C*i+c] *= gain_1;
526 #endif
527          }
528          shortBlocks = 1;
529       } else {
530          transient_time = -1;
531          transient_shift = 0;
532          shortBlocks = 0;
533       }
534    } else {
535       transient_time = -1;
536       transient_shift = 0;
537       shortBlocks = 0;
538    }
539
540    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
541    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
542    /* Compute MDCTs */
543    compute_mdcts(st->mode, shortBlocks, in, freq);
544    compute_band_energies(st->mode, freq, bandE);
545
546    intra_ener = st->delayedIntra;
547    if (intra_decision(bandE, st->oldBandE, st->mode->nbEBands) || shortBlocks)
548       st->delayedIntra = 1;
549    else
550       st->delayedIntra = 0;
551    /* Pitch analysis: we do it early to save on the peak stack space */
552    /* Don't use pitch if there isn't enough data available yet, or if we're using shortBlocks */
553    has_pitch = st->pitch_enabled && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks) && !intra_ener;
554 #ifdef EXP_PSY
555    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
556    {
557       VARDECL(celt_word16_t, X);
558       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
559       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
560       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
561    }
562 #else
563    if (has_pitch)
564    {
565       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
566    }
567 #endif
568
569 #ifdef EXP_PSY
570    ALLOC(mask, N, celt_sig_t);
571    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
572    /*for (i=0;i<256;i++)
573       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
574    printf ("\n");*/
575 #endif
576
577    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
578    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
579    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
580    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
581
582
583    /* Band normalisation */
584    normalise_bands(st->mode, freq, X, bandE);
585
586 #ifdef EXP_PSY
587    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
588    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
589    compute_noise_energies(st->mode, freq, tonality, bandN);
590
591    /*for (i=0;i<st->mode->nbEBands;i++)
592       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
593    printf ("\n");*/
594    has_fold = 0;
595    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
596       if (bandN[i] < .4*bandE[i])
597          has_fold++;
598    /*printf ("%d\n", has_fold);*/
599    if (has_fold>=2)
600       has_fold = 0;
601    else
602       has_fold = 1;
603    for (i=0;i<N;i++)
604       mask[i] = sqrt(mask[i]);
605    compute_band_energies(st->mode, mask, bandM);
606    /*for (i=0;i<st->mode->nbEBands;i++)
607       printf ("%f %f ", bandE[i], bandM[i]);
608    printf ("\n");*/
609 #endif
610
611    /* Compute MDCTs of the pitch part */
612    if (has_pitch)
613    {
614       celt_word32_t curr_power, pitch_power=0;
615       /* Normalise the pitch vector as well (discard the energies) */
616       VARDECL(celt_ener_t, bandEp);
617       
618       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
619       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
620       compute_band_energies(st->mode, freq, bandEp);
621       normalise_bands(st->mode, freq, P, bandEp);
622       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
623       /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
624       curr_power = bandE[0]+bandE[1]+bandE[2];
625       if ((MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
626       {
627          /* Pitch prediction */
628          has_pitch = compute_pitch_gain(st->mode, X, P, gains);
629       } else {
630          has_pitch = 0;
631       }
632    }
633    
634    encode_flags(&enc, intra_ener, has_pitch, shortBlocks, has_fold);
635    if (has_pitch)
636    {
637       ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
638    } else {
639       for (i=0;i<st->mode->nbPBands;i++)
640          gains[i] = 0;
641       for (i=0;i<C*N;i++)
642          P[i] = 0;
643    }
644    if (shortBlocks)
645    {
646       ec_enc_bits(&enc, transient_shift, 2);
647       if (transient_shift)
648          ec_enc_uint(&enc, transient_time, N+st->overlap);
649    }
650
651 #ifdef STDIN_TUNING2
652    static int fine_quant[30];
653    static int pulses[30];
654    static int init=0;
655    if (!init)
656    {
657       for (i=0;i<st->mode->nbEBands;i++)
658          scanf("%d ", &fine_quant[i]);
659       for (i=0;i<st->mode->nbEBands;i++)
660          scanf("%d ", &pulses[i]);
661       init = 1;
662    }
663 #else
664    ALLOC(fine_quant, st->mode->nbEBands, int);
665    ALLOC(pulses, st->mode->nbEBands, int);
666 #endif
667
668    /* Bit allocation */
669    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
670    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, intra_ener, st->mode->prob, error, &enc);
671    
672    ALLOC(offsets, st->mode->nbEBands, int);
673    ALLOC(stereo_mode, st->mode->nbEBands, int);
674    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
675
676    for (i=0;i<st->mode->nbEBands;i++)
677       offsets[i] = 0;
678    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
679    if (has_pitch)
680       bits -= st->mode->nbPBands;
681 #ifndef STDIN_TUNING
682    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
683 #endif
684
685    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
686
687    /* Residual quantisation */
688    if (C==1)
689       quant_bands(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
690    else
691       quant_bands_stereo(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
692
693    /* Re-synthesis of the coded audio if required */
694    if (st->pitch_available>0 || optional_synthesis!=NULL)
695    {
696       if (st->pitch_available>0 && st->pitch_available<MAX_PERIOD)
697         st->pitch_available+=st->frame_size;
698
699       /* Synthesis */
700       denormalise_bands(st->mode, X, freq, bandE);
701       
702       
703       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
704       
705       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
706       /* De-emphasis and put everything back at the right place in the synthesis history */
707       if (optional_synthesis != NULL) {
708          for (c=0;c<C;c++)
709          {
710             int j;
711             for (j=0;j<N;j++)
712             {
713                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
714                                    preemph,st->preemph_memD[c]);
715                st->preemph_memD[c] = tmp;
716                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
717             }
718          }
719       }
720    }
721
722    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
723    /*if (ec_enc_tell(&enc, 0) < nbCompressedBytes*8 - 7)
724       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
725
726    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
727    {
728       int val = 0;
729       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
730       {
731          ec_enc_uint(&enc, val, 2);
732          val = 1-val;
733       }
734    }
735    ec_enc_done(&enc);
736    {
737       /*unsigned char *data;*/
738       int nbBytes = ec_byte_bytes(&buf);
739       if (nbBytes > nbCompressedBytes)
740       {
741          celt_warning_int ("got too many bytes:", nbBytes);
742          RESTORE_STACK;
743          return CELT_INTERNAL_ERROR;
744       }
745    }
746
747    RESTORE_STACK;
748    return nbCompressedBytes;
749 }
750
751 #ifdef FIXED_POINT
752 #ifndef DISABLE_FLOAT_API
753 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
754 {
755    int j, ret;
756    const int C = CHANNELS(st->mode);
757    const int N = st->block_size;
758    VARDECL(celt_int16_t, in);
759    SAVE_STACK;
760    ALLOC(in, C*N, celt_int16_t);
761
762    for (j=0;j<C*N;j++)
763      in[j] = FLOAT2INT16(pcm[j]);
764
765    if (optional_synthesis != NULL) {
766      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
767       for (j=0;j<C*N;j++)
768          optional_synthesis[j]=in[j]*(1/32768.);
769    } else {
770      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
771    }
772    RESTORE_STACK;
773    return ret;
774
775 }
776 #endif /*DISABLE_FLOAT_API*/
777 #else
778 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
779 {
780    int j, ret;
781    VARDECL(celt_sig_t, in);
782    const int C = CHANNELS(st->mode);
783    const int N = st->block_size;
784    SAVE_STACK;
785    ALLOC(in, C*N, celt_sig_t);
786    for (j=0;j<C*N;j++) {
787      in[j] = SCALEOUT(pcm[j]);
788    }
789
790    if (optional_synthesis != NULL) {
791       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
792       for (j=0;j<C*N;j++)
793          optional_synthesis[j] = FLOAT2INT16(in[j]);
794    } else {
795       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
796    }
797    RESTORE_STACK;
798    return ret;
799 }
800 #endif
801
802 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
803 {
804    va_list ap;
805    va_start(ap, request);
806    switch (request)
807    {
808       case CELT_SET_COMPLEXITY_REQUEST:
809       {
810          int value = va_arg(ap, int);
811          if (value<0 || value>10)
812             goto bad_arg;
813          if (value<=2) {
814             st->pitch_enabled = 0; 
815             st->pitch_available = 0;
816          } else {
817               st->pitch_enabled = 1;
818               if (st->pitch_available<1)
819                 st->pitch_available = 1;
820          }   
821       }
822       break;
823       case CELT_SET_LTP_REQUEST:
824       {
825          int value = va_arg(ap, int);
826          if (value<0 || value>1 || (value==1 && st->pitch_available==0))
827             goto bad_arg;
828          if (value==0)
829             st->pitch_enabled = 0;
830          else
831             st->pitch_enabled = 1;
832       }
833       break;
834       default:
835          goto bad_request;
836    }
837    va_end(ap);
838    return CELT_OK;
839 bad_arg:
840    va_end(ap);
841    return CELT_BAD_ARG;
842 bad_request:
843    va_end(ap);
844    return CELT_UNIMPLEMENTED;
845 }
846
847 /****************************************************************************/
848 /*                                                                          */
849 /*                                DECODER                                   */
850 /*                                                                          */
851 /****************************************************************************/
852 #ifdef NEW_PLC
853 #define DECODE_BUFFER_SIZE 2048
854 #else
855 #define DECODE_BUFFER_SIZE MAX_PERIOD
856 #endif
857
858 /** Decoder state 
859  @brief Decoder state
860  */
861 struct CELTDecoder {
862    const CELTMode *mode;
863    int frame_size;
864    int block_size;
865    int overlap;
866
867    ec_byte_buffer buf;
868    ec_enc         enc;
869
870    celt_sig_t * restrict preemph_memD;
871
872    celt_sig_t *out_mem;
873    celt_sig_t *decode_mem;
874
875    celt_word16_t *oldBandE;
876    
877    int last_pitch_index;
878 };
879
880 CELTDecoder *celt_decoder_create(const CELTMode *mode)
881 {
882    int N, C;
883    CELTDecoder *st;
884
885    if (check_mode(mode) != CELT_OK)
886       return NULL;
887
888    N = mode->mdctSize;
889    C = CHANNELS(mode);
890    st = celt_alloc(sizeof(CELTDecoder));
891    
892    st->mode = mode;
893    st->frame_size = N;
894    st->block_size = N;
895    st->overlap = mode->overlap;
896
897    st->decode_mem = celt_alloc((DECODE_BUFFER_SIZE+st->overlap)*C*sizeof(celt_sig_t));
898    st->out_mem = st->decode_mem+DECODE_BUFFER_SIZE-MAX_PERIOD;
899    
900    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
901
902    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
903
904    st->last_pitch_index = 0;
905    return st;
906 }
907
908 void celt_decoder_destroy(CELTDecoder *st)
909 {
910    if (st == NULL)
911    {
912       celt_warning("NULL passed to celt_encoder_destroy");
913       return;
914    }
915    if (check_mode(st->mode) != CELT_OK)
916       return;
917
918
919    celt_free(st->decode_mem);
920    
921    celt_free(st->oldBandE);
922    
923    celt_free(st->preemph_memD);
924
925    celt_free(st);
926 }
927
928 /** Handles lost packets by just copying past data with the same offset as the last
929     pitch period */
930 #ifdef NEW_PLC
931 #include "plc.c"
932 #else
933 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
934 {
935    int c, N;
936    int pitch_index;
937    int i, len;
938    VARDECL(celt_sig_t, freq);
939    const int C = CHANNELS(st->mode);
940    int offset;
941    SAVE_STACK;
942    N = st->block_size;
943    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
944    
945    len = N+st->mode->overlap;
946 #if 0
947    pitch_index = st->last_pitch_index;
948    
949    /* Use the pitch MDCT as the "guessed" signal */
950    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
951
952 #else
953    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
954    pitch_index = MAX_PERIOD-len-pitch_index;
955    offset = MAX_PERIOD-pitch_index;
956    while (offset+len >= MAX_PERIOD)
957       offset -= pitch_index;
958    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
959    for (i=0;i<N;i++)
960       freq[i] = ADD32(EPSILON, MULT16_32_Q15(QCONST16(.9f,15),freq[i]));
961 #endif
962    
963    
964    
965    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
966    /* Compute inverse MDCTs */
967    compute_inv_mdcts(st->mode, 0, freq, -1, 0, st->out_mem);
968
969    for (c=0;c<C;c++)
970    {
971       int j;
972       for (j=0;j<N;j++)
973       {
974          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
975                                 preemph,st->preemph_memD[c]);
976          st->preemph_memD[c] = tmp;
977          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
978       }
979    }
980    RESTORE_STACK;
981 }
982 #endif
983
984 #ifdef FIXED_POINT
985 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
986 {
987 #else
988 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig_t * restrict pcm)
989 {
990 #endif
991    int i, c, N, N4;
992    int has_pitch, has_fold;
993    int pitch_index;
994    int bits;
995    ec_dec dec;
996    ec_byte_buffer buf;
997    VARDECL(celt_sig_t, freq);
998    VARDECL(celt_norm_t, X);
999    VARDECL(celt_norm_t, P);
1000    VARDECL(celt_ener_t, bandE);
1001    VARDECL(celt_pgain_t, gains);
1002    VARDECL(int, stereo_mode);
1003    VARDECL(int, fine_quant);
1004    VARDECL(int, pulses);
1005    VARDECL(int, offsets);
1006
1007    int shortBlocks;
1008    int intra_ener;
1009    int transient_time;
1010    int transient_shift;
1011    const int C = CHANNELS(st->mode);
1012    SAVE_STACK;
1013
1014    if (check_mode(st->mode) != CELT_OK)
1015       return CELT_INVALID_MODE;
1016
1017    N = st->block_size;
1018    N4 = (N-st->overlap)>>1;
1019
1020    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
1021    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
1022    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
1023    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
1024    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
1025    
1026    if (check_mode(st->mode) != CELT_OK)
1027    {
1028       RESTORE_STACK;
1029       return CELT_INVALID_MODE;
1030    }
1031    if (data == NULL)
1032    {
1033       celt_decode_lost(st, pcm);
1034       RESTORE_STACK;
1035       return 0;
1036    }
1037    if (len<0) {
1038      RESTORE_STACK;
1039      return CELT_BAD_ARG;
1040    }
1041    
1042    ec_byte_readinit(&buf,(unsigned char*)data,len);
1043    ec_dec_init(&dec,&buf);
1044    
1045    decode_flags(&dec, &intra_ener, &has_pitch, &shortBlocks, &has_fold);
1046    if (shortBlocks)
1047    {
1048       transient_shift = ec_dec_bits(&dec, 2);
1049       if (transient_shift)
1050          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
1051       else
1052          transient_time = 0;
1053    } else {
1054       transient_time = -1;
1055       transient_shift = 0;
1056    }
1057    
1058    if (has_pitch)
1059    {
1060       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
1061       st->last_pitch_index = pitch_index;
1062    } else {
1063       pitch_index = 0;
1064       for (i=0;i<st->mode->nbPBands;i++)
1065          gains[i] = 0;
1066    }
1067
1068    ALLOC(fine_quant, st->mode->nbEBands, int);
1069    /* Get band energies */
1070    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, intra_ener, st->mode->prob, &dec);
1071    
1072    ALLOC(pulses, st->mode->nbEBands, int);
1073    ALLOC(offsets, st->mode->nbEBands, int);
1074    ALLOC(stereo_mode, st->mode->nbEBands, int);
1075    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
1076
1077    for (i=0;i<st->mode->nbEBands;i++)
1078       offsets[i] = 0;
1079
1080    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1081    if (has_pitch)
1082       bits -= st->mode->nbPBands;
1083    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1084    /*bits = ec_dec_tell(&dec, 0);
1085    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1086    
1087    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1088
1089
1090    if (has_pitch) 
1091    {
1092       VARDECL(celt_ener_t, bandEp);
1093       
1094       /* Pitch MDCT */
1095       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1096       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1097       compute_band_energies(st->mode, freq, bandEp);
1098       normalise_bands(st->mode, freq, P, bandEp);
1099       /* Apply pitch gains */
1100    } else {
1101       for (i=0;i<C*N;i++)
1102          P[i] = 0;
1103    }
1104
1105    /* Decode fixed codebook and merge with pitch */
1106    if (C==1)
1107       unquant_bands(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1108    else
1109       unquant_bands_stereo(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1110
1111    /* Synthesis */
1112    denormalise_bands(st->mode, X, freq, bandE);
1113
1114
1115    CELT_MOVE(st->decode_mem, st->decode_mem+C*N, C*(DECODE_BUFFER_SIZE+st->overlap-N));
1116    /* Compute inverse MDCTs */
1117    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1118
1119    for (c=0;c<C;c++)
1120    {
1121       int j;
1122       for (j=0;j<N;j++)
1123       {
1124          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1125                                 preemph,st->preemph_memD[c]);
1126          st->preemph_memD[c] = tmp;
1127          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1128       }
1129    }
1130
1131    {
1132       unsigned int val = 0;
1133       while (ec_dec_tell(&dec, 0) < len*8)
1134       {
1135          if (ec_dec_uint(&dec, 2) != val)
1136          {
1137             celt_warning("decode error");
1138             RESTORE_STACK;
1139             return CELT_CORRUPTED_DATA;
1140          }
1141          val = 1-val;
1142       }
1143    }
1144
1145    RESTORE_STACK;
1146    return 0;
1147    /*printf ("\n");*/
1148 }
1149
1150 #ifdef FIXED_POINT
1151 #ifndef DISABLE_FLOAT_API
1152 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm)
1153 {
1154    int j, ret;
1155    const int C = CHANNELS(st->mode);
1156    const int N = st->block_size;
1157    VARDECL(celt_int16_t, out);
1158    SAVE_STACK;
1159    ALLOC(out, C*N, celt_int16_t);
1160
1161    ret=celt_decode(st, data, len, out);
1162
1163    for (j=0;j<C*N;j++)
1164      pcm[j]=out[j]*(1/32768.);
1165    RESTORE_STACK;
1166    return ret;
1167 }
1168 #endif /*DISABLE_FLOAT_API*/
1169 #else
1170 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
1171 {
1172    int j, ret;
1173    VARDECL(celt_sig_t, out);
1174    const int C = CHANNELS(st->mode);
1175    const int N = st->block_size;
1176    SAVE_STACK;
1177    ALLOC(out, C*N, celt_sig_t);
1178
1179    ret=celt_decode_float(st, data, len, out);
1180
1181    for (j=0;j<C*N;j++)
1182      pcm[j] = FLOAT2INT16 (out[j]);
1183
1184    RESTORE_STACK;
1185    return ret;
1186 }
1187 #endif