Dynamically selecting intra energy based on energy variations from the previous
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2    (C) 2008 Gregory Maxwell */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_bands.h"
48 #include "psy.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54
55 static const celt_word16_t preemph = QCONST16(0.8f,15);
56
57 #ifdef FIXED_POINT
58 static const celt_word16_t transientWindow[16] = {
59      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
60    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
61 #else
62 static const float transientWindow[16] = {
63    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
64    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
65 #endif
66
67    
68 /** Encoder state 
69  @brief Encoder state
70  */
71 struct CELTEncoder {
72    const CELTMode *mode;     /**< Mode used by the encoder */
73    int frame_size;
74    int block_size;
75    int overlap;
76    int channels;
77    
78    int pitch_enabled;
79    int pitch_available;
80
81    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
82    celt_sig_t    * restrict preemph_memD;
83
84    celt_sig_t *in_mem;
85    celt_sig_t *out_mem;
86
87    celt_word16_t *oldBandE;
88 #ifdef EXP_PSY
89    celt_word16_t *psy_mem;
90    struct PsyDecay psy;
91 #endif
92 };
93
94 CELTEncoder *celt_encoder_create(const CELTMode *mode)
95 {
96    int N, C;
97    CELTEncoder *st;
98
99    if (check_mode(mode) != CELT_OK)
100       return NULL;
101
102    N = mode->mdctSize;
103    C = mode->nbChannels;
104    st = celt_alloc(sizeof(CELTEncoder));
105    
106    st->mode = mode;
107    st->frame_size = N;
108    st->block_size = N;
109    st->overlap = mode->overlap;
110
111    st->pitch_enabled = 1;
112    st->pitch_available = 1;
113
114    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
115    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
116
117    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
118
119    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
120    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
121
122 #ifdef EXP_PSY
123    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
124    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
125 #endif
126
127    return st;
128 }
129
130 void celt_encoder_destroy(CELTEncoder *st)
131 {
132    if (st == NULL)
133    {
134       celt_warning("NULL passed to celt_encoder_destroy");
135       return;
136    }
137    if (check_mode(st->mode) != CELT_OK)
138       return;
139
140    celt_free(st->in_mem);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148 #ifdef EXP_PSY
149    celt_free (st->psy_mem);
150    psydecay_clear(&st->psy);
151 #endif
152    
153    celt_free(st);
154 }
155
156 static inline celt_int16_t FLOAT2INT16(float x)
157 {
158    x = x*32768.;
159    x = MAX32(x, -32768);
160    x = MIN32(x, 32767);
161    return (celt_int16_t)float2int(x);
162 }
163
164 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
165 {
166 #ifdef FIXED_POINT
167    x = PSHR32(x, SIG_SHIFT);
168    x = MAX32(x, -32768);
169    x = MIN32(x, 32767);
170    return EXTRACT16(x);
171 #else
172    return (celt_word16_t)x;
173 #endif
174 }
175
176 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
177 {
178    int c, i, n;
179    celt_word32_t ratio;
180    /* FIXME: Remove the floats here */
181    VARDECL(celt_word32_t, begin);
182    SAVE_STACK;
183    ALLOC(begin, len, celt_word32_t);
184    for (i=0;i<len;i++)
185       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
186    for (c=1;c<C;c++)
187    {
188       for (i=0;i<len;i++)
189          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
190    }
191    for (i=1;i<len;i++)
192       begin[i] = MAX32(begin[i-1],begin[i]);
193    n = -1;
194    for (i=8;i<len-8;i++)
195    {
196       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
197          n=i;
198    }
199    if (n<32)
200    {
201       n = -1;
202       ratio = 0;
203    } else {
204       ratio = DIV32(begin[len-1],1+begin[n-16]);
205    }
206    /*printf ("%d %f\n", n, ratio*ratio);*/
207    if (ratio < 0)
208       ratio = 0;
209    if (ratio > 1000)
210       ratio = 1000;
211    ratio *= ratio;
212    if (ratio < 50)
213       *transient_shift = 0;
214    else if (ratio < 256)
215       *transient_shift = 1;
216    else if (ratio < 4096)
217       *transient_shift = 2;
218    else
219       *transient_shift = 3;
220    *transient_time = n;
221    
222    RESTORE_STACK;
223    return ratio > 20;
224 }
225
226 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
227 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
228 {
229    const int C = CHANNELS(mode);
230    if (C==1 && !shortBlocks)
231    {
232       const mdct_lookup *lookup = MDCT(mode);
233       const int overlap = OVERLAP(mode);
234       mdct_forward(lookup, in, out, mode->window, overlap);
235    } else if (!shortBlocks) {
236       const mdct_lookup *lookup = MDCT(mode);
237       const int overlap = OVERLAP(mode);
238       const int N = FRAMESIZE(mode);
239       int c;
240       VARDECL(celt_word32_t, x);
241       VARDECL(celt_word32_t, tmp);
242       SAVE_STACK;
243       ALLOC(x, N+overlap, celt_word32_t);
244       ALLOC(tmp, N, celt_word32_t);
245       for (c=0;c<C;c++)
246       {
247          int j;
248          for (j=0;j<N+overlap;j++)
249             x[j] = in[C*j+c];
250          mdct_forward(lookup, x, tmp, mode->window, overlap);
251          /* Interleaving the sub-frames */
252          for (j=0;j<N;j++)
253             out[C*j+c] = tmp[j];
254       }
255       RESTORE_STACK;
256    } else {
257       const mdct_lookup *lookup = &mode->shortMdct;
258       const int overlap = mode->overlap;
259       const int N = mode->shortMdctSize;
260       int b, c;
261       VARDECL(celt_word32_t, x);
262       VARDECL(celt_word32_t, tmp);
263       SAVE_STACK;
264       ALLOC(x, N+overlap, celt_word32_t);
265       ALLOC(tmp, N, celt_word32_t);
266       for (c=0;c<C;c++)
267       {
268          int B = mode->nbShortMdcts;
269          for (b=0;b<B;b++)
270          {
271             int j;
272             for (j=0;j<N+overlap;j++)
273                x[j] = in[C*(b*N+j)+c];
274             mdct_forward(lookup, x, tmp, mode->window, overlap);
275             /* Interleaving the sub-frames */
276             for (j=0;j<N;j++)
277                out[C*(j*B+b)+c] = tmp[j];
278          }
279       }
280       RESTORE_STACK;
281    }
282 }
283
284 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
285 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
286 {
287    int c, N4;
288    const int C = CHANNELS(mode);
289    const int N = FRAMESIZE(mode);
290    const int overlap = OVERLAP(mode);
291    N4 = (N-overlap)>>1;
292    for (c=0;c<C;c++)
293    {
294       int j;
295       if (transient_shift==0 && C==1 && !shortBlocks) {
296          const mdct_lookup *lookup = MDCT(mode);
297          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
298       } else if (!shortBlocks) {
299          const mdct_lookup *lookup = MDCT(mode);
300          VARDECL(celt_word32_t, x);
301          VARDECL(celt_word32_t, tmp);
302          SAVE_STACK;
303          ALLOC(x, 2*N, celt_word32_t);
304          ALLOC(tmp, N, celt_word32_t);
305          /* De-interleaving the sub-frames */
306          for (j=0;j<N;j++)
307             tmp[j] = X[C*j+c];
308          /* Prevents problems from the imdct doing the overlap-add */
309          CELT_MEMSET(x+N4, 0, N);
310          mdct_backward(lookup, tmp, x, mode->window, overlap);
311          celt_assert(transient_shift == 0);
312          /* The first and last part would need to be set to zero if we actually
313             wanted to use them. */
314          for (j=0;j<overlap;j++)
315             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
316          for (j=0;j<overlap;j++)
317             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
318          for (j=0;j<2*N4;j++)
319             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
320          RESTORE_STACK;
321       } else {
322          int b;
323          const int N2 = mode->shortMdctSize;
324          const int B = mode->nbShortMdcts;
325          const mdct_lookup *lookup = &mode->shortMdct;
326          VARDECL(celt_word32_t, x);
327          VARDECL(celt_word32_t, tmp);
328          SAVE_STACK;
329          ALLOC(x, 2*N, celt_word32_t);
330          ALLOC(tmp, N, celt_word32_t);
331          /* Prevents problems from the imdct doing the overlap-add */
332          CELT_MEMSET(x+N4, 0, N2);
333          for (b=0;b<B;b++)
334          {
335             /* De-interleaving the sub-frames */
336             for (j=0;j<N2;j++)
337                tmp[j] = X[C*(j*B+b)+c];
338             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
339          }
340          if (transient_shift > 0)
341          {
342 #ifdef FIXED_POINT
343             for (j=0;j<16;j++)
344                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
345             for (j=transient_time;j<N+overlap;j++)
346                x[N4+j] = SHL32(x[N4+j], transient_shift);
347 #else
348             for (j=0;j<16;j++)
349                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
350             for (j=transient_time;j<N+overlap;j++)
351                x[N4+j] *= 1<<transient_shift;
352 #endif
353          }
354          /* The first and last part would need to be set to zero if we actually
355          wanted to use them. */
356          for (j=0;j<overlap;j++)
357             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
358          for (j=0;j<overlap;j++)
359             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
360          for (j=0;j<2*N4;j++)
361             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
362          RESTORE_STACK;
363       }
364    }
365 }
366
367 #define FLAG_NONE        0
368 #define FLAG_INTRA       1U<<16
369 #define FLAG_PITCH       1U<<15
370 #define FLAG_SHORT       1U<<14
371 #define FLAG_FOLD        1U<<13
372 #define FLAG_MASK        (FLAG_INTRA|FLAG_PITCH|FLAG_SHORT|FLAG_FOLD)
373
374 celt_int32_t flaglist[8] = {
375       0 /*00  */ | FLAG_FOLD,
376       1 /*01  */ | FLAG_PITCH|FLAG_FOLD,
377       8 /*1000*/ | FLAG_NONE,
378       9 /*1001*/ | FLAG_SHORT|FLAG_FOLD,
379      10 /*1010*/ | FLAG_PITCH,
380      11 /*1011*/ | FLAG_INTRA,
381       6 /*110 */ | FLAG_INTRA|FLAG_FOLD,
382       7 /*111 */ | FLAG_INTRA|FLAG_SHORT|FLAG_FOLD
383 };
384
385 void encode_flags(ec_enc *enc, int intra_ener, int has_pitch, int shortBlocks, int has_fold)
386 {
387    int i;
388    int flags=FLAG_NONE;
389    int flag_bits;
390    flags |= intra_ener   ? FLAG_INTRA : 0;
391    flags |= has_pitch    ? FLAG_PITCH : 0;
392    flags |= shortBlocks  ? FLAG_SHORT : 0;
393    flags |= has_fold     ? FLAG_FOLD  : 0;
394    for (i=0;i<8;i++)
395       if (flags == (flaglist[i]&FLAG_MASK))
396          break;
397    celt_assert(i<8);
398    flag_bits = flaglist[i]&0xf;
399    /*printf ("enc %d: %d %d %d %d\n", flag_bits, intra_ener, has_pitch, shortBlocks, has_fold);*/
400    if (i<2)
401       ec_enc_bits(enc, flag_bits, 2);
402    else if (i<6)
403       ec_enc_bits(enc, flag_bits, 4);
404    else
405       ec_enc_bits(enc, flag_bits, 3);
406 }
407
408 void decode_flags(ec_dec *dec, int *intra_ener, int *has_pitch, int *shortBlocks, int *has_fold)
409 {
410    int i;
411    int flag_bits;
412    flag_bits = ec_dec_bits(dec, 2);
413    /*printf ("(%d) ", flag_bits);*/
414    if (flag_bits==2)
415       flag_bits = (flag_bits<<2) | ec_dec_bits(dec, 2);
416    else if (flag_bits==3)
417       flag_bits = (flag_bits<<1) | ec_dec_bits(dec, 1);
418    for (i=0;i<8;i++)
419       if (flag_bits == (flaglist[i]&0xf))
420          break;
421    celt_assert(i<8);
422    *intra_ener  = (flaglist[i]&FLAG_INTRA) != 0;
423    *has_pitch   = (flaglist[i]&FLAG_PITCH) != 0;
424    *shortBlocks = (flaglist[i]&FLAG_SHORT) != 0;
425    *has_fold    = (flaglist[i]&FLAG_FOLD ) != 0;
426    /*printf ("dec %d: %d %d %d %d\n", flag_bits, *intra_ener, *has_pitch, *shortBlocks, *has_fold);*/
427 }
428
429 #ifdef FIXED_POINT
430 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
431 {
432 #else
433 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
434 {
435 #endif
436    int i, c, N, N4;
437    int has_pitch;
438    int pitch_index;
439    int bits;
440    int has_fold=1;
441    ec_byte_buffer buf;
442    ec_enc         enc;
443    VARDECL(celt_sig_t, in);
444    VARDECL(celt_sig_t, freq);
445    VARDECL(celt_norm_t, X);
446    VARDECL(celt_norm_t, P);
447    VARDECL(celt_ener_t, bandE);
448    VARDECL(celt_pgain_t, gains);
449    VARDECL(int, stereo_mode);
450    VARDECL(int, fine_quant);
451    VARDECL(celt_word16_t, error);
452    VARDECL(int, pulses);
453    VARDECL(int, offsets);
454 #ifdef EXP_PSY
455    VARDECL(celt_word32_t, mask);
456    VARDECL(celt_word32_t, tonality);
457    VARDECL(celt_word32_t, bandM);
458    VARDECL(celt_ener_t, bandN);
459 #endif
460    int intra_ener = 0;
461    int shortBlocks=0;
462    int transient_time;
463    int transient_shift;
464    const int C = CHANNELS(st->mode);
465    SAVE_STACK;
466
467    if (check_mode(st->mode) != CELT_OK)
468       return CELT_INVALID_MODE;
469
470    if (nbCompressedBytes<0)
471      return CELT_BAD_ARG; 
472
473    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
474    CELT_MEMSET(compressed, 0, nbCompressedBytes);
475    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
476    ec_enc_init(&enc,&buf);
477
478    N = st->block_size;
479    N4 = (N-st->overlap)>>1;
480    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
481
482    CELT_COPY(in, st->in_mem, C*st->overlap);
483    for (c=0;c<C;c++)
484    {
485       const celt_word16_t * restrict pcmp = pcm+c;
486       celt_sig_t * restrict inp = in+C*st->overlap+c;
487       for (i=0;i<N;i++)
488       {
489          /* Apply pre-emphasis */
490          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
491          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
492          st->preemph_memE[c] = SCALEIN(*pcmp);
493          inp += C;
494          pcmp += C;
495       }
496    }
497    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
498    
499    /* Transient handling */
500    if (st->mode->nbShortMdcts > 1)
501    {
502       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
503       {
504 #ifndef FIXED_POINT
505          float gain_1;
506 #endif
507          /* Apply the inverse shaping window */
508          if (transient_shift)
509          {
510 #ifdef FIXED_POINT
511             for (c=0;c<C;c++)
512                for (i=0;i<16;i++)
513                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
514             for (c=0;c<C;c++)
515                for (i=transient_time;i<N+st->overlap;i++)
516                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
517 #else
518             for (c=0;c<C;c++)
519                for (i=0;i<16;i++)
520                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
521             gain_1 = 1./(1<<transient_shift);
522             for (c=0;c<C;c++)
523                for (i=transient_time;i<N+st->overlap;i++)
524                   in[C*i+c] *= gain_1;
525 #endif
526          }
527          shortBlocks = 1;
528       } else {
529          transient_time = -1;
530          transient_shift = 0;
531          shortBlocks = 0;
532       }
533    } else {
534       transient_time = -1;
535       transient_shift = 0;
536       shortBlocks = 0;
537    }
538
539    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
540    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
541    /* Compute MDCTs */
542    compute_mdcts(st->mode, shortBlocks, in, freq);
543    compute_band_energies(st->mode, freq, bandE);
544
545    if (intra_decision(bandE, st->oldBandE, st->mode->nbEBands) || shortBlocks)
546       intra_ener = 1;
547    else
548       intra_ener = 0;
549
550    /* Pitch analysis: we do it early to save on the peak stack space */
551    /* Don't use pitch if there isn't enough data available yet, or if we're using shortBlocks */
552    has_pitch = st->pitch_enabled && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks) && !intra_ener;
553 #ifdef EXP_PSY
554    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
555    {
556       VARDECL(celt_word16_t, X);
557       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
558       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
559       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
560    }
561 #else
562    if (has_pitch)
563    {
564       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
565    }
566 #endif
567
568 #ifdef EXP_PSY
569    ALLOC(mask, N, celt_sig_t);
570    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
571    /*for (i=0;i<256;i++)
572       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
573    printf ("\n");*/
574 #endif
575
576    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
577    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
578    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
579    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
580
581
582    /* Band normalisation */
583    normalise_bands(st->mode, freq, X, bandE);
584
585 #ifdef EXP_PSY
586    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
587    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
588    compute_noise_energies(st->mode, freq, tonality, bandN);
589
590    /*for (i=0;i<st->mode->nbEBands;i++)
591       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
592    printf ("\n");*/
593    has_fold = 0;
594    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
595       if (bandN[i] < .4*bandE[i])
596          has_fold++;
597    /*printf ("%d\n", has_fold);*/
598    if (has_fold>=2)
599       has_fold = 0;
600    else
601       has_fold = 1;
602    for (i=0;i<N;i++)
603       mask[i] = sqrt(mask[i]);
604    compute_band_energies(st->mode, mask, bandM);
605    /*for (i=0;i<st->mode->nbEBands;i++)
606       printf ("%f %f ", bandE[i], bandM[i]);
607    printf ("\n");*/
608 #endif
609
610    /* Compute MDCTs of the pitch part */
611    if (has_pitch)
612    {
613       celt_word32_t curr_power, pitch_power=0;
614       /* Normalise the pitch vector as well (discard the energies) */
615       VARDECL(celt_ener_t, bandEp);
616       
617       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
618       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
619       compute_band_energies(st->mode, freq, bandEp);
620       normalise_bands(st->mode, freq, P, bandEp);
621       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
622       /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
623       curr_power = bandE[0]+bandE[1]+bandE[2];
624       if ((MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
625       {
626          /* Pitch prediction */
627          has_pitch = compute_pitch_gain(st->mode, X, P, gains);
628       } else {
629          has_pitch = 0;
630       }
631    }
632    
633    encode_flags(&enc, intra_ener, has_pitch, shortBlocks, has_fold);
634    if (has_pitch)
635    {
636       ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
637    } else {
638       for (i=0;i<st->mode->nbPBands;i++)
639          gains[i] = 0;
640       for (i=0;i<C*N;i++)
641          P[i] = 0;
642    }
643    if (shortBlocks)
644    {
645       ec_enc_bits(&enc, transient_shift, 2);
646       if (transient_shift)
647          ec_enc_uint(&enc, transient_time, N+st->overlap);
648    }
649
650 #ifdef STDIN_TUNING2
651    static int fine_quant[30];
652    static int pulses[30];
653    static int init=0;
654    if (!init)
655    {
656       for (i=0;i<st->mode->nbEBands;i++)
657          scanf("%d ", &fine_quant[i]);
658       for (i=0;i<st->mode->nbEBands;i++)
659          scanf("%d ", &pulses[i]);
660       init = 1;
661    }
662 #else
663    ALLOC(fine_quant, st->mode->nbEBands, int);
664    ALLOC(pulses, st->mode->nbEBands, int);
665 #endif
666
667    /* Bit allocation */
668    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
669    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, intra_ener, st->mode->prob, error, &enc);
670    
671    ALLOC(offsets, st->mode->nbEBands, int);
672    ALLOC(stereo_mode, st->mode->nbEBands, int);
673    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
674
675    for (i=0;i<st->mode->nbEBands;i++)
676       offsets[i] = 0;
677    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
678    if (has_pitch)
679       bits -= st->mode->nbPBands;
680 #ifndef STDIN_TUNING
681    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
682 #endif
683
684    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
685
686    /* Residual quantisation */
687    if (C==1)
688       quant_bands(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
689    else
690       quant_bands_stereo(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
691
692    /* Re-synthesis of the coded audio if required */
693    if (st->pitch_available>0 || optional_synthesis!=NULL)
694    {
695       if (st->pitch_available>0 && st->pitch_available<MAX_PERIOD)
696         st->pitch_available+=st->frame_size;
697
698       /* Synthesis */
699       denormalise_bands(st->mode, X, freq, bandE);
700       
701       
702       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
703       
704       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
705       /* De-emphasis and put everything back at the right place in the synthesis history */
706       if (optional_synthesis != NULL) {
707          for (c=0;c<C;c++)
708          {
709             int j;
710             for (j=0;j<N;j++)
711             {
712                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
713                                    preemph,st->preemph_memD[c]);
714                st->preemph_memD[c] = tmp;
715                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
716             }
717          }
718       }
719    }
720
721    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
722    /*if (ec_enc_tell(&enc, 0) < nbCompressedBytes*8 - 7)
723       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
724
725    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
726    {
727       int val = 0;
728       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
729       {
730          ec_enc_uint(&enc, val, 2);
731          val = 1-val;
732       }
733    }
734    ec_enc_done(&enc);
735    {
736       /*unsigned char *data;*/
737       int nbBytes = ec_byte_bytes(&buf);
738       if (nbBytes > nbCompressedBytes)
739       {
740          celt_warning_int ("got too many bytes:", nbBytes);
741          RESTORE_STACK;
742          return CELT_INTERNAL_ERROR;
743       }
744    }
745
746    RESTORE_STACK;
747    return nbCompressedBytes;
748 }
749
750 #ifdef FIXED_POINT
751 #ifndef DISABLE_FLOAT_API
752 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
753 {
754    int j, ret;
755    const int C = CHANNELS(st->mode);
756    const int N = st->block_size;
757    VARDECL(celt_int16_t, in);
758    SAVE_STACK;
759    ALLOC(in, C*N, celt_int16_t);
760
761    for (j=0;j<C*N;j++)
762      in[j] = FLOAT2INT16(pcm[j]);
763
764    if (optional_synthesis != NULL) {
765      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
766       for (j=0;j<C*N;j++)
767          optional_synthesis[j]=in[j]*(1/32768.);
768    } else {
769      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
770    }
771    RESTORE_STACK;
772    return ret;
773
774 }
775 #endif /*DISABLE_FLOAT_API*/
776 #else
777 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
778 {
779    int j, ret;
780    VARDECL(celt_sig_t, in);
781    const int C = CHANNELS(st->mode);
782    const int N = st->block_size;
783    SAVE_STACK;
784    ALLOC(in, C*N, celt_sig_t);
785    for (j=0;j<C*N;j++) {
786      in[j] = SCALEOUT(pcm[j]);
787    }
788
789    if (optional_synthesis != NULL) {
790       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
791       for (j=0;j<C*N;j++)
792          optional_synthesis[j] = FLOAT2INT16(in[j]);
793    } else {
794       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
795    }
796    RESTORE_STACK;
797    return ret;
798 }
799 #endif
800
801 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
802 {
803    va_list ap;
804    va_start(ap, request);
805    switch (request)
806    {
807       case CELT_SET_COMPLEXITY_REQUEST:
808       {
809          int value = va_arg(ap, int);
810          if (value<0 || value>10)
811             goto bad_arg;
812          if (value<=2) {
813             st->pitch_enabled = 0; 
814             st->pitch_available = 0;
815          } else {
816               st->pitch_enabled = 1;
817               if (st->pitch_available<1)
818                 st->pitch_available = 1;
819          }   
820       }
821       break;
822       case CELT_SET_LTP_REQUEST:
823       {
824          int value = va_arg(ap, int);
825          if (value<0 || value>1 || (value==1 && st->pitch_available==0))
826             goto bad_arg;
827          if (value==0)
828             st->pitch_enabled = 0;
829          else
830             st->pitch_enabled = 1;
831       }
832       break;
833       default:
834          goto bad_request;
835    }
836    va_end(ap);
837    return CELT_OK;
838 bad_arg:
839    va_end(ap);
840    return CELT_BAD_ARG;
841 bad_request:
842    va_end(ap);
843    return CELT_UNIMPLEMENTED;
844 }
845
846 /****************************************************************************/
847 /*                                                                          */
848 /*                                DECODER                                   */
849 /*                                                                          */
850 /****************************************************************************/
851 #ifdef NEW_PLC
852 #define DECODE_BUFFER_SIZE 2048
853 #else
854 #define DECODE_BUFFER_SIZE MAX_PERIOD
855 #endif
856
857 /** Decoder state 
858  @brief Decoder state
859  */
860 struct CELTDecoder {
861    const CELTMode *mode;
862    int frame_size;
863    int block_size;
864    int overlap;
865
866    ec_byte_buffer buf;
867    ec_enc         enc;
868
869    celt_sig_t * restrict preemph_memD;
870
871    celt_sig_t *out_mem;
872    celt_sig_t *decode_mem;
873
874    celt_word16_t *oldBandE;
875    
876    int last_pitch_index;
877 };
878
879 CELTDecoder *celt_decoder_create(const CELTMode *mode)
880 {
881    int N, C;
882    CELTDecoder *st;
883
884    if (check_mode(mode) != CELT_OK)
885       return NULL;
886
887    N = mode->mdctSize;
888    C = CHANNELS(mode);
889    st = celt_alloc(sizeof(CELTDecoder));
890    
891    st->mode = mode;
892    st->frame_size = N;
893    st->block_size = N;
894    st->overlap = mode->overlap;
895
896    st->decode_mem = celt_alloc((DECODE_BUFFER_SIZE+st->overlap)*C*sizeof(celt_sig_t));
897    st->out_mem = st->decode_mem+DECODE_BUFFER_SIZE-MAX_PERIOD;
898    
899    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
900
901    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
902
903    st->last_pitch_index = 0;
904    return st;
905 }
906
907 void celt_decoder_destroy(CELTDecoder *st)
908 {
909    if (st == NULL)
910    {
911       celt_warning("NULL passed to celt_encoder_destroy");
912       return;
913    }
914    if (check_mode(st->mode) != CELT_OK)
915       return;
916
917
918    celt_free(st->decode_mem);
919    
920    celt_free(st->oldBandE);
921    
922    celt_free(st->preemph_memD);
923
924    celt_free(st);
925 }
926
927 /** Handles lost packets by just copying past data with the same offset as the last
928     pitch period */
929 #ifdef NEW_PLC
930 #include "plc.c"
931 #else
932 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
933 {
934    int c, N;
935    int pitch_index;
936    int i, len;
937    VARDECL(celt_sig_t, freq);
938    const int C = CHANNELS(st->mode);
939    int offset;
940    SAVE_STACK;
941    N = st->block_size;
942    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
943    
944    len = N+st->mode->overlap;
945 #if 0
946    pitch_index = st->last_pitch_index;
947    
948    /* Use the pitch MDCT as the "guessed" signal */
949    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
950
951 #else
952    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
953    pitch_index = MAX_PERIOD-len-pitch_index;
954    offset = MAX_PERIOD-pitch_index;
955    while (offset+len >= MAX_PERIOD)
956       offset -= pitch_index;
957    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
958    for (i=0;i<N;i++)
959       freq[i] = ADD32(EPSILON, MULT16_32_Q15(QCONST16(.9f,15),freq[i]));
960 #endif
961    
962    
963    
964    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
965    /* Compute inverse MDCTs */
966    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
967
968    for (c=0;c<C;c++)
969    {
970       int j;
971       for (j=0;j<N;j++)
972       {
973          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
974                                 preemph,st->preemph_memD[c]);
975          st->preemph_memD[c] = tmp;
976          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
977       }
978    }
979    RESTORE_STACK;
980 }
981 #endif
982
983 #ifdef FIXED_POINT
984 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
985 {
986 #else
987 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig_t * restrict pcm)
988 {
989 #endif
990    int i, c, N, N4;
991    int has_pitch, has_fold;
992    int pitch_index;
993    int bits;
994    ec_dec dec;
995    ec_byte_buffer buf;
996    VARDECL(celt_sig_t, freq);
997    VARDECL(celt_norm_t, X);
998    VARDECL(celt_norm_t, P);
999    VARDECL(celt_ener_t, bandE);
1000    VARDECL(celt_pgain_t, gains);
1001    VARDECL(int, stereo_mode);
1002    VARDECL(int, fine_quant);
1003    VARDECL(int, pulses);
1004    VARDECL(int, offsets);
1005
1006    int shortBlocks;
1007    int intra_ener;
1008    int transient_time;
1009    int transient_shift;
1010    const int C = CHANNELS(st->mode);
1011    SAVE_STACK;
1012
1013    if (check_mode(st->mode) != CELT_OK)
1014       return CELT_INVALID_MODE;
1015
1016    N = st->block_size;
1017    N4 = (N-st->overlap)>>1;
1018
1019    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
1020    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
1021    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
1022    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
1023    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
1024    
1025    if (check_mode(st->mode) != CELT_OK)
1026    {
1027       RESTORE_STACK;
1028       return CELT_INVALID_MODE;
1029    }
1030    if (data == NULL)
1031    {
1032       celt_decode_lost(st, pcm);
1033       RESTORE_STACK;
1034       return 0;
1035    }
1036    if (len<0) {
1037      RESTORE_STACK;
1038      return CELT_BAD_ARG;
1039    }
1040    
1041    ec_byte_readinit(&buf,(unsigned char*)data,len);
1042    ec_dec_init(&dec,&buf);
1043    
1044    decode_flags(&dec, &intra_ener, &has_pitch, &shortBlocks, &has_fold);
1045    if (shortBlocks)
1046    {
1047       transient_shift = ec_dec_bits(&dec, 2);
1048       if (transient_shift)
1049          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
1050       else
1051          transient_time = 0;
1052    } else {
1053       transient_time = -1;
1054       transient_shift = 0;
1055    }
1056    
1057    if (has_pitch)
1058    {
1059       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
1060       st->last_pitch_index = pitch_index;
1061    } else {
1062       pitch_index = 0;
1063       for (i=0;i<st->mode->nbPBands;i++)
1064          gains[i] = 0;
1065    }
1066
1067    ALLOC(fine_quant, st->mode->nbEBands, int);
1068    /* Get band energies */
1069    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, intra_ener, st->mode->prob, &dec);
1070    
1071    ALLOC(pulses, st->mode->nbEBands, int);
1072    ALLOC(offsets, st->mode->nbEBands, int);
1073    ALLOC(stereo_mode, st->mode->nbEBands, int);
1074    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
1075
1076    for (i=0;i<st->mode->nbEBands;i++)
1077       offsets[i] = 0;
1078
1079    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1080    if (has_pitch)
1081       bits -= st->mode->nbPBands;
1082    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1083    /*bits = ec_dec_tell(&dec, 0);
1084    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1085    
1086    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1087
1088
1089    if (has_pitch) 
1090    {
1091       VARDECL(celt_ener_t, bandEp);
1092       
1093       /* Pitch MDCT */
1094       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1095       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1096       compute_band_energies(st->mode, freq, bandEp);
1097       normalise_bands(st->mode, freq, P, bandEp);
1098       /* Apply pitch gains */
1099    } else {
1100       for (i=0;i<C*N;i++)
1101          P[i] = 0;
1102    }
1103
1104    /* Decode fixed codebook and merge with pitch */
1105    if (C==1)
1106       unquant_bands(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1107    else
1108       unquant_bands_stereo(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1109
1110    /* Synthesis */
1111    denormalise_bands(st->mode, X, freq, bandE);
1112
1113
1114    CELT_MOVE(st->decode_mem, st->decode_mem+C*N, C*(DECODE_BUFFER_SIZE+st->overlap-N));
1115    /* Compute inverse MDCTs */
1116    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1117
1118    for (c=0;c<C;c++)
1119    {
1120       int j;
1121       for (j=0;j<N;j++)
1122       {
1123          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1124                                 preemph,st->preemph_memD[c]);
1125          st->preemph_memD[c] = tmp;
1126          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1127       }
1128    }
1129
1130    {
1131       unsigned int val = 0;
1132       while (ec_dec_tell(&dec, 0) < len*8)
1133       {
1134          if (ec_dec_uint(&dec, 2) != val)
1135          {
1136             celt_warning("decode error");
1137             RESTORE_STACK;
1138             return CELT_CORRUPTED_DATA;
1139          }
1140          val = 1-val;
1141       }
1142    }
1143
1144    RESTORE_STACK;
1145    return 0;
1146    /*printf ("\n");*/
1147 }
1148
1149 #ifdef FIXED_POINT
1150 #ifndef DISABLE_FLOAT_API
1151 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm)
1152 {
1153    int j, ret;
1154    const int C = CHANNELS(st->mode);
1155    const int N = st->block_size;
1156    VARDECL(celt_int16_t, out);
1157    SAVE_STACK;
1158    ALLOC(out, C*N, celt_int16_t);
1159
1160    ret=celt_decode(st, data, len, out);
1161
1162    for (j=0;j<C*N;j++)
1163      pcm[j]=out[j]*(1/32768.);
1164    RESTORE_STACK;
1165    return ret;
1166 }
1167 #endif /*DISABLE_FLOAT_API*/
1168 #else
1169 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
1170 {
1171    int j, ret;
1172    VARDECL(celt_sig_t, out);
1173    const int C = CHANNELS(st->mode);
1174    const int N = st->block_size;
1175    SAVE_STACK;
1176    ALLOC(out, C*N, celt_sig_t);
1177
1178    ret=celt_decode_float(st, data, len, out);
1179
1180    for (j=0;j<C*N;j++)
1181      pcm[j] = FLOAT2INT16 (out[j]);
1182
1183    RESTORE_STACK;
1184    return ret;
1185 }
1186 #endif