New code for encoding the flags
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2    (C) 2008 Gregory Maxwell */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_bands.h"
48 #include "psy.h"
49 #include "rate.h"
50 #include "stack_alloc.h"
51 #include "mathops.h"
52 #include "float_cast.h"
53 #include <stdarg.h>
54
55 static const celt_word16_t preemph = QCONST16(0.8f,15);
56
57 #ifdef FIXED_POINT
58 static const celt_word16_t transientWindow[16] = {
59      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
60    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
61 #else
62 static const float transientWindow[16] = {
63    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
64    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
65 #endif
66
67    
68 /** Encoder state 
69  @brief Encoder state
70  */
71 struct CELTEncoder {
72    const CELTMode *mode;     /**< Mode used by the encoder */
73    int frame_size;
74    int block_size;
75    int overlap;
76    int channels;
77    
78    int pitch_enabled;
79    int pitch_available;
80
81    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
82    celt_sig_t    * restrict preemph_memD;
83
84    celt_sig_t *in_mem;
85    celt_sig_t *out_mem;
86
87    celt_word16_t *oldBandE;
88 #ifdef EXP_PSY
89    celt_word16_t *psy_mem;
90    struct PsyDecay psy;
91 #endif
92 };
93
94 CELTEncoder *celt_encoder_create(const CELTMode *mode)
95 {
96    int N, C;
97    CELTEncoder *st;
98
99    if (check_mode(mode) != CELT_OK)
100       return NULL;
101
102    N = mode->mdctSize;
103    C = mode->nbChannels;
104    st = celt_alloc(sizeof(CELTEncoder));
105    
106    st->mode = mode;
107    st->frame_size = N;
108    st->block_size = N;
109    st->overlap = mode->overlap;
110
111    st->pitch_enabled = 1;
112    st->pitch_available = 1;
113
114    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
115    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
116
117    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
118
119    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
120    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
121
122 #ifdef EXP_PSY
123    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
124    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
125 #endif
126
127    return st;
128 }
129
130 void celt_encoder_destroy(CELTEncoder *st)
131 {
132    if (st == NULL)
133    {
134       celt_warning("NULL passed to celt_encoder_destroy");
135       return;
136    }
137    if (check_mode(st->mode) != CELT_OK)
138       return;
139
140    celt_free(st->in_mem);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148 #ifdef EXP_PSY
149    celt_free (st->psy_mem);
150    psydecay_clear(&st->psy);
151 #endif
152    
153    celt_free(st);
154 }
155
156 static inline celt_int16_t FLOAT2INT16(float x)
157 {
158    x = x*32768.;
159    x = MAX32(x, -32768);
160    x = MIN32(x, 32767);
161    return (celt_int16_t)float2int(x);
162 }
163
164 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
165 {
166 #ifdef FIXED_POINT
167    x = PSHR32(x, SIG_SHIFT);
168    x = MAX32(x, -32768);
169    x = MIN32(x, 32767);
170    return EXTRACT16(x);
171 #else
172    return (celt_word16_t)x;
173 #endif
174 }
175
176 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
177 {
178    int c, i, n;
179    celt_word32_t ratio;
180    /* FIXME: Remove the floats here */
181    VARDECL(celt_word32_t, begin);
182    SAVE_STACK;
183    ALLOC(begin, len, celt_word32_t);
184    for (i=0;i<len;i++)
185       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
186    for (c=1;c<C;c++)
187    {
188       for (i=0;i<len;i++)
189          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
190    }
191    for (i=1;i<len;i++)
192       begin[i] = MAX32(begin[i-1],begin[i]);
193    n = -1;
194    for (i=8;i<len-8;i++)
195    {
196       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
197          n=i;
198    }
199    if (n<32)
200    {
201       n = -1;
202       ratio = 0;
203    } else {
204       ratio = DIV32(begin[len-1],1+begin[n-16]);
205    }
206    /*printf ("%d %f\n", n, ratio*ratio);*/
207    if (ratio < 0)
208       ratio = 0;
209    if (ratio > 1000)
210       ratio = 1000;
211    ratio *= ratio;
212    if (ratio < 50)
213       *transient_shift = 0;
214    else if (ratio < 256)
215       *transient_shift = 1;
216    else if (ratio < 4096)
217       *transient_shift = 2;
218    else
219       *transient_shift = 3;
220    *transient_time = n;
221    
222    RESTORE_STACK;
223    return ratio > 20;
224 }
225
226 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
227 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
228 {
229    const int C = CHANNELS(mode);
230    if (C==1 && !shortBlocks)
231    {
232       const mdct_lookup *lookup = MDCT(mode);
233       const int overlap = OVERLAP(mode);
234       mdct_forward(lookup, in, out, mode->window, overlap);
235    } else if (!shortBlocks) {
236       const mdct_lookup *lookup = MDCT(mode);
237       const int overlap = OVERLAP(mode);
238       const int N = FRAMESIZE(mode);
239       int c;
240       VARDECL(celt_word32_t, x);
241       VARDECL(celt_word32_t, tmp);
242       SAVE_STACK;
243       ALLOC(x, N+overlap, celt_word32_t);
244       ALLOC(tmp, N, celt_word32_t);
245       for (c=0;c<C;c++)
246       {
247          int j;
248          for (j=0;j<N+overlap;j++)
249             x[j] = in[C*j+c];
250          mdct_forward(lookup, x, tmp, mode->window, overlap);
251          /* Interleaving the sub-frames */
252          for (j=0;j<N;j++)
253             out[C*j+c] = tmp[j];
254       }
255       RESTORE_STACK;
256    } else {
257       const mdct_lookup *lookup = &mode->shortMdct;
258       const int overlap = mode->overlap;
259       const int N = mode->shortMdctSize;
260       int b, c;
261       VARDECL(celt_word32_t, x);
262       VARDECL(celt_word32_t, tmp);
263       SAVE_STACK;
264       ALLOC(x, N+overlap, celt_word32_t);
265       ALLOC(tmp, N, celt_word32_t);
266       for (c=0;c<C;c++)
267       {
268          int B = mode->nbShortMdcts;
269          for (b=0;b<B;b++)
270          {
271             int j;
272             for (j=0;j<N+overlap;j++)
273                x[j] = in[C*(b*N+j)+c];
274             mdct_forward(lookup, x, tmp, mode->window, overlap);
275             /* Interleaving the sub-frames */
276             for (j=0;j<N;j++)
277                out[C*(j*B+b)+c] = tmp[j];
278          }
279       }
280       RESTORE_STACK;
281    }
282 }
283
284 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
285 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
286 {
287    int c, N4;
288    const int C = CHANNELS(mode);
289    const int N = FRAMESIZE(mode);
290    const int overlap = OVERLAP(mode);
291    N4 = (N-overlap)>>1;
292    for (c=0;c<C;c++)
293    {
294       int j;
295       if (transient_shift==0 && C==1 && !shortBlocks) {
296          const mdct_lookup *lookup = MDCT(mode);
297          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
298       } else if (!shortBlocks) {
299          const mdct_lookup *lookup = MDCT(mode);
300          VARDECL(celt_word32_t, x);
301          VARDECL(celt_word32_t, tmp);
302          SAVE_STACK;
303          ALLOC(x, 2*N, celt_word32_t);
304          ALLOC(tmp, N, celt_word32_t);
305          /* De-interleaving the sub-frames */
306          for (j=0;j<N;j++)
307             tmp[j] = X[C*j+c];
308          /* Prevents problems from the imdct doing the overlap-add */
309          CELT_MEMSET(x+N4, 0, N);
310          mdct_backward(lookup, tmp, x, mode->window, overlap);
311          celt_assert(transient_shift == 0);
312          /* The first and last part would need to be set to zero if we actually
313             wanted to use them. */
314          for (j=0;j<overlap;j++)
315             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
316          for (j=0;j<overlap;j++)
317             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
318          for (j=0;j<2*N4;j++)
319             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
320          RESTORE_STACK;
321       } else {
322          int b;
323          const int N2 = mode->shortMdctSize;
324          const int B = mode->nbShortMdcts;
325          const mdct_lookup *lookup = &mode->shortMdct;
326          VARDECL(celt_word32_t, x);
327          VARDECL(celt_word32_t, tmp);
328          SAVE_STACK;
329          ALLOC(x, 2*N, celt_word32_t);
330          ALLOC(tmp, N, celt_word32_t);
331          /* Prevents problems from the imdct doing the overlap-add */
332          CELT_MEMSET(x+N4, 0, N2);
333          for (b=0;b<B;b++)
334          {
335             /* De-interleaving the sub-frames */
336             for (j=0;j<N2;j++)
337                tmp[j] = X[C*(j*B+b)+c];
338             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
339          }
340          if (transient_shift > 0)
341          {
342 #ifdef FIXED_POINT
343             for (j=0;j<16;j++)
344                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
345             for (j=transient_time;j<N+overlap;j++)
346                x[N4+j] = SHL32(x[N4+j], transient_shift);
347 #else
348             for (j=0;j<16;j++)
349                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
350             for (j=transient_time;j<N+overlap;j++)
351                x[N4+j] *= 1<<transient_shift;
352 #endif
353          }
354          /* The first and last part would need to be set to zero if we actually
355          wanted to use them. */
356          for (j=0;j<overlap;j++)
357             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
358          for (j=0;j<overlap;j++)
359             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
360          for (j=0;j<2*N4;j++)
361             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
362          RESTORE_STACK;
363       }
364    }
365 }
366
367 #define FLAG_NONE        0
368 #define FLAG_INTRA       1U<<16
369 #define FLAG_PITCH       1U<<15
370 #define FLAG_SHORT       1U<<14
371 #define FLAG_FOLD        1U<<13
372 #define FLAG_MASK        (FLAG_INTRA|FLAG_PITCH|FLAG_SHORT|FLAG_FOLD)
373
374 celt_int32_t flaglist[8] = {
375       0b00   | FLAG_FOLD,
376       0b01   | FLAG_PITCH|FLAG_FOLD,
377       0b1000 | FLAG_NONE,
378       0b1001 | FLAG_SHORT|FLAG_FOLD,
379       0b1010 | FLAG_PITCH,
380       0b1011 | FLAG_INTRA,
381       0b110  | FLAG_INTRA|FLAG_FOLD,
382       0b111  | FLAG_INTRA|FLAG_SHORT|FLAG_FOLD
383 };
384
385 void encode_flags(ec_enc *enc, int intra_ener, int has_pitch, int shortBlocks, int has_fold)
386 {
387    int i;
388    int flags=FLAG_NONE;
389    int flag_bits;
390    flags |= intra_ener   ? FLAG_INTRA : 0;
391    flags |= has_pitch    ? FLAG_PITCH : 0;
392    flags |= shortBlocks  ? FLAG_SHORT : 0;
393    flags |= has_fold     ? FLAG_FOLD  : 0;
394    for (i=0;i<8;i++)
395       if (flags == (flaglist[i]&FLAG_MASK))
396          break;
397    celt_assert(i<8);
398    flag_bits = flaglist[i]&0xf;
399    /*printf ("enc %d: %d %d %d %d\n", flag_bits, intra_ener, has_pitch, shortBlocks, has_fold);*/
400    if (i<2)
401       ec_enc_bits(enc, flag_bits, 2);
402    else if (i<6)
403       ec_enc_bits(enc, flag_bits, 4);
404    else
405       ec_enc_bits(enc, flag_bits, 3);
406 }
407
408 void decode_flags(ec_dec *dec, int *intra_ener, int *has_pitch, int *shortBlocks, int *has_fold)
409 {
410    int i;
411    int flag_bits;
412    flag_bits = ec_dec_bits(dec, 2);
413    /*printf ("(%d) ", flag_bits);*/
414    if (flag_bits==2)
415       flag_bits = (flag_bits<<2) | ec_dec_bits(dec, 2);
416    else if (flag_bits==3)
417       flag_bits = (flag_bits<<1) | ec_dec_bits(dec, 1);
418    for (i=0;i<8;i++)
419       if (flag_bits == (flaglist[i]&0xf))
420          break;
421    celt_assert(i<8);
422    *intra_ener  = (flaglist[i]&FLAG_INTRA) != 0;
423    *has_pitch   = (flaglist[i]&FLAG_PITCH) != 0;
424    *shortBlocks = (flaglist[i]&FLAG_SHORT) != 0;
425    *has_fold    = (flaglist[i]&FLAG_FOLD ) != 0;
426    /*printf ("dec %d: %d %d %d %d\n", flag_bits, *intra_ener, *has_pitch, *shortBlocks, *has_fold);*/
427 }
428
429 #ifdef FIXED_POINT
430 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
431 {
432 #else
433 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
434 {
435 #endif
436    int i, c, N, N4;
437    int has_pitch;
438    int pitch_index;
439    int bits;
440    int has_fold=1;
441    ec_byte_buffer buf;
442    ec_enc         enc;
443    VARDECL(celt_sig_t, in);
444    VARDECL(celt_sig_t, freq);
445    VARDECL(celt_norm_t, X);
446    VARDECL(celt_norm_t, P);
447    VARDECL(celt_ener_t, bandE);
448    VARDECL(celt_pgain_t, gains);
449    VARDECL(int, stereo_mode);
450    VARDECL(int, fine_quant);
451    VARDECL(celt_word16_t, error);
452    VARDECL(int, pulses);
453    VARDECL(int, offsets);
454 #ifdef EXP_PSY
455    VARDECL(celt_word32_t, mask);
456    VARDECL(celt_word32_t, tonality);
457    VARDECL(celt_word32_t, bandM);
458    VARDECL(celt_ener_t, bandN);
459 #endif
460    int shortBlocks=0;
461    int transient_time;
462    int transient_shift;
463    const int C = CHANNELS(st->mode);
464    SAVE_STACK;
465
466    if (check_mode(st->mode) != CELT_OK)
467       return CELT_INVALID_MODE;
468
469    if (nbCompressedBytes<0)
470      return CELT_BAD_ARG; 
471
472    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
473    CELT_MEMSET(compressed, 0, nbCompressedBytes);
474    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
475    ec_enc_init(&enc,&buf);
476
477    N = st->block_size;
478    N4 = (N-st->overlap)>>1;
479    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
480
481    CELT_COPY(in, st->in_mem, C*st->overlap);
482    for (c=0;c<C;c++)
483    {
484       const celt_word16_t * restrict pcmp = pcm+c;
485       celt_sig_t * restrict inp = in+C*st->overlap+c;
486       for (i=0;i<N;i++)
487       {
488          /* Apply pre-emphasis */
489          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
490          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
491          st->preemph_memE[c] = SCALEIN(*pcmp);
492          inp += C;
493          pcmp += C;
494       }
495    }
496    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
497    
498    /* Transient handling */
499    if (st->mode->nbShortMdcts > 1)
500    {
501       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
502       {
503 #ifndef FIXED_POINT
504          float gain_1;
505 #endif
506          /* Apply the inverse shaping window */
507          if (transient_shift)
508          {
509 #ifdef FIXED_POINT
510             for (c=0;c<C;c++)
511                for (i=0;i<16;i++)
512                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
513             for (c=0;c<C;c++)
514                for (i=transient_time;i<N+st->overlap;i++)
515                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
516 #else
517             for (c=0;c<C;c++)
518                for (i=0;i<16;i++)
519                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
520             gain_1 = 1./(1<<transient_shift);
521             for (c=0;c<C;c++)
522                for (i=transient_time;i<N+st->overlap;i++)
523                   in[C*i+c] *= gain_1;
524 #endif
525          }
526          shortBlocks = 1;
527       } else {
528          transient_time = -1;
529          transient_shift = 0;
530          shortBlocks = 0;
531       }
532    } else {
533       transient_time = -1;
534       transient_shift = 0;
535       shortBlocks = 0;
536    }
537
538    /* Pitch analysis: we do it early to save on the peak stack space */
539    /* Don't use pitch if there isn't enough data available yet, or if we're using shortBlocks */
540    has_pitch = st->pitch_enabled && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks);
541 #ifdef EXP_PSY
542    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
543    {
544       VARDECL(celt_word16_t, X);
545       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
546       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
547       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
548    }
549 #else
550    if (has_pitch)
551    {
552       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
553    }
554 #endif
555    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
556    
557    /* Compute MDCTs */
558    compute_mdcts(st->mode, shortBlocks, in, freq);
559
560 #ifdef EXP_PSY
561    ALLOC(mask, N, celt_sig_t);
562    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
563    /*for (i=0;i<256;i++)
564       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
565    printf ("\n");*/
566 #endif
567
568    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
569    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
570    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
571    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
572    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
573
574
575    /* Band normalisation */
576    compute_band_energies(st->mode, freq, bandE);
577    normalise_bands(st->mode, freq, X, bandE);
578
579 #ifdef EXP_PSY
580    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
581    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
582    compute_noise_energies(st->mode, freq, tonality, bandN);
583
584    /*for (i=0;i<st->mode->nbEBands;i++)
585       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
586    printf ("\n");*/
587    has_fold = 0;
588    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
589       if (bandN[i] < .4*bandE[i])
590          has_fold++;
591    /*printf ("%d\n", has_fold);*/
592    if (has_fold>=2)
593       has_fold = 0;
594    else
595       has_fold = 1;
596    for (i=0;i<N;i++)
597       mask[i] = sqrt(mask[i]);
598    compute_band_energies(st->mode, mask, bandM);
599    /*for (i=0;i<st->mode->nbEBands;i++)
600       printf ("%f %f ", bandE[i], bandM[i]);
601    printf ("\n");*/
602 #endif
603
604    /* Compute MDCTs of the pitch part */
605    if (has_pitch)
606    {
607       celt_word32_t curr_power, pitch_power=0;
608       /* Normalise the pitch vector as well (discard the energies) */
609       VARDECL(celt_ener_t, bandEp);
610       
611       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
612       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
613       compute_band_energies(st->mode, freq, bandEp);
614       normalise_bands(st->mode, freq, P, bandEp);
615       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
616       /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
617       curr_power = bandE[0]+bandE[1]+bandE[2];
618       if ((MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
619       {
620          /* Pitch prediction */
621          has_pitch = compute_pitch_gain(st->mode, X, P, gains);
622       } else {
623          has_pitch = 0;
624       }
625    }
626    
627    encode_flags(&enc, 0, has_pitch, shortBlocks, has_fold);
628    if (has_pitch)
629    {
630       ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
631    } else {
632       for (i=0;i<st->mode->nbPBands;i++)
633          gains[i] = 0;
634       for (i=0;i<C*N;i++)
635          P[i] = 0;
636    }
637    if (shortBlocks)
638    {
639       ec_enc_bits(&enc, transient_shift, 2);
640       if (transient_shift)
641          ec_enc_uint(&enc, transient_time, N+st->overlap);
642    }
643
644 #ifdef STDIN_TUNING2
645    static int fine_quant[30];
646    static int pulses[30];
647    static int init=0;
648    if (!init)
649    {
650       for (i=0;i<st->mode->nbEBands;i++)
651          scanf("%d ", &fine_quant[i]);
652       for (i=0;i<st->mode->nbEBands;i++)
653          scanf("%d ", &pulses[i]);
654       init = 1;
655    }
656 #else
657    ALLOC(fine_quant, st->mode->nbEBands, int);
658    ALLOC(pulses, st->mode->nbEBands, int);
659 #endif
660
661    /* Bit allocation */
662    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
663    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &enc);
664    
665    ALLOC(offsets, st->mode->nbEBands, int);
666    ALLOC(stereo_mode, st->mode->nbEBands, int);
667    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
668
669    for (i=0;i<st->mode->nbEBands;i++)
670       offsets[i] = 0;
671    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
672    if (has_pitch)
673       bits -= st->mode->nbPBands;
674 #ifndef STDIN_TUNING
675    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
676 #endif
677
678    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
679
680    /* Residual quantisation */
681    if (C==1)
682       quant_bands(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
683    else
684       quant_bands_stereo(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
685
686    /* Re-synthesis of the coded audio if required */
687    if (st->pitch_available>0 || optional_synthesis!=NULL)
688    {
689       if (st->pitch_available>0 && st->pitch_available<MAX_PERIOD)
690         st->pitch_available+=st->frame_size;
691
692       /* Synthesis */
693       denormalise_bands(st->mode, X, freq, bandE);
694       
695       
696       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
697       
698       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
699       /* De-emphasis and put everything back at the right place in the synthesis history */
700       if (optional_synthesis != NULL) {
701          for (c=0;c<C;c++)
702          {
703             int j;
704             for (j=0;j<N;j++)
705             {
706                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
707                                    preemph,st->preemph_memD[c]);
708                st->preemph_memD[c] = tmp;
709                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
710             }
711          }
712       }
713    }
714
715    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
716    /*if (ec_enc_tell(&enc, 0) < nbCompressedBytes*8 - 7)
717       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
718
719    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
720    {
721       int val = 0;
722       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
723       {
724          ec_enc_uint(&enc, val, 2);
725          val = 1-val;
726       }
727    }
728    ec_enc_done(&enc);
729    {
730       /*unsigned char *data;*/
731       int nbBytes = ec_byte_bytes(&buf);
732       if (nbBytes > nbCompressedBytes)
733       {
734          celt_warning_int ("got too many bytes:", nbBytes);
735          RESTORE_STACK;
736          return CELT_INTERNAL_ERROR;
737       }
738    }
739
740    RESTORE_STACK;
741    return nbCompressedBytes;
742 }
743
744 #ifdef FIXED_POINT
745 #ifndef DISABLE_FLOAT_API
746 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
747 {
748    int j, ret;
749    const int C = CHANNELS(st->mode);
750    const int N = st->block_size;
751    VARDECL(celt_int16_t, in);
752    SAVE_STACK;
753    ALLOC(in, C*N, celt_int16_t);
754
755    for (j=0;j<C*N;j++)
756      in[j] = FLOAT2INT16(pcm[j]);
757
758    if (optional_synthesis != NULL) {
759      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
760       for (j=0;j<C*N;j++)
761          optional_synthesis[j]=in[j]*(1/32768.);
762    } else {
763      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
764    }
765    RESTORE_STACK;
766    return ret;
767
768 }
769 #endif /*DISABLE_FLOAT_API*/
770 #else
771 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
772 {
773    int j, ret;
774    VARDECL(celt_sig_t, in);
775    const int C = CHANNELS(st->mode);
776    const int N = st->block_size;
777    SAVE_STACK;
778    ALLOC(in, C*N, celt_sig_t);
779    for (j=0;j<C*N;j++) {
780      in[j] = SCALEOUT(pcm[j]);
781    }
782
783    if (optional_synthesis != NULL) {
784       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
785       for (j=0;j<C*N;j++)
786          optional_synthesis[j] = FLOAT2INT16(in[j]);
787    } else {
788       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
789    }
790    RESTORE_STACK;
791    return ret;
792 }
793 #endif
794
795 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
796 {
797    va_list ap;
798    va_start(ap, request);
799    switch (request)
800    {
801       case CELT_SET_COMPLEXITY_REQUEST:
802       {
803          int value = va_arg(ap, int);
804          if (value<0 || value>10)
805             goto bad_arg;
806          if (value<=2) {
807             st->pitch_enabled = 0; 
808             st->pitch_available = 0;
809          } else {
810               st->pitch_enabled = 1;
811               if (st->pitch_available<1)
812                 st->pitch_available = 1;
813          }   
814       }
815       break;
816       case CELT_SET_LTP_REQUEST:
817       {
818          int value = va_arg(ap, int);
819          if (value<0 || value>1 || (value==1 && st->pitch_available==0))
820             goto bad_arg;
821          if (value==0)
822             st->pitch_enabled = 0;
823          else
824             st->pitch_enabled = 1;
825       }
826       break;
827       default:
828          goto bad_request;
829    }
830    va_end(ap);
831    return CELT_OK;
832 bad_arg:
833    va_end(ap);
834    return CELT_BAD_ARG;
835 bad_request:
836    va_end(ap);
837    return CELT_UNIMPLEMENTED;
838 }
839
840 /****************************************************************************/
841 /*                                                                          */
842 /*                                DECODER                                   */
843 /*                                                                          */
844 /****************************************************************************/
845 #ifdef NEW_PLC
846 #define DECODE_BUFFER_SIZE 2048
847 #else
848 #define DECODE_BUFFER_SIZE MAX_PERIOD
849 #endif
850
851 /** Decoder state 
852  @brief Decoder state
853  */
854 struct CELTDecoder {
855    const CELTMode *mode;
856    int frame_size;
857    int block_size;
858    int overlap;
859
860    ec_byte_buffer buf;
861    ec_enc         enc;
862
863    celt_sig_t * restrict preemph_memD;
864
865    celt_sig_t *out_mem;
866    celt_sig_t *decode_mem;
867
868    celt_word16_t *oldBandE;
869    
870    int last_pitch_index;
871 };
872
873 CELTDecoder *celt_decoder_create(const CELTMode *mode)
874 {
875    int N, C;
876    CELTDecoder *st;
877
878    if (check_mode(mode) != CELT_OK)
879       return NULL;
880
881    N = mode->mdctSize;
882    C = CHANNELS(mode);
883    st = celt_alloc(sizeof(CELTDecoder));
884    
885    st->mode = mode;
886    st->frame_size = N;
887    st->block_size = N;
888    st->overlap = mode->overlap;
889
890    st->decode_mem = celt_alloc((DECODE_BUFFER_SIZE+st->overlap)*C*sizeof(celt_sig_t));
891    st->out_mem = st->decode_mem+DECODE_BUFFER_SIZE-MAX_PERIOD;
892    
893    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
894
895    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
896
897    st->last_pitch_index = 0;
898    return st;
899 }
900
901 void celt_decoder_destroy(CELTDecoder *st)
902 {
903    if (st == NULL)
904    {
905       celt_warning("NULL passed to celt_encoder_destroy");
906       return;
907    }
908    if (check_mode(st->mode) != CELT_OK)
909       return;
910
911
912    celt_free(st->decode_mem);
913    
914    celt_free(st->oldBandE);
915    
916    celt_free(st->preemph_memD);
917
918    celt_free(st);
919 }
920
921 /** Handles lost packets by just copying past data with the same offset as the last
922     pitch period */
923 #ifdef NEW_PLC
924 #include "plc.c"
925 #else
926 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
927 {
928    int c, N;
929    int pitch_index;
930    int i, len;
931    VARDECL(celt_sig_t, freq);
932    const int C = CHANNELS(st->mode);
933    int offset;
934    SAVE_STACK;
935    N = st->block_size;
936    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
937    
938    len = N+st->mode->overlap;
939 #if 0
940    pitch_index = st->last_pitch_index;
941    
942    /* Use the pitch MDCT as the "guessed" signal */
943    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
944
945 #else
946    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
947    pitch_index = MAX_PERIOD-len-pitch_index;
948    offset = MAX_PERIOD-pitch_index;
949    while (offset+len >= MAX_PERIOD)
950       offset -= pitch_index;
951    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
952    for (i=0;i<N;i++)
953       freq[i] = ADD32(EPSILON, MULT16_32_Q15(QCONST16(.9f,15),freq[i]));
954 #endif
955    
956    
957    
958    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
959    /* Compute inverse MDCTs */
960    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
961
962    for (c=0;c<C;c++)
963    {
964       int j;
965       for (j=0;j<N;j++)
966       {
967          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
968                                 preemph,st->preemph_memD[c]);
969          st->preemph_memD[c] = tmp;
970          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
971       }
972    }
973    RESTORE_STACK;
974 }
975 #endif
976
977 #ifdef FIXED_POINT
978 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
979 {
980 #else
981 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig_t * restrict pcm)
982 {
983 #endif
984    int i, c, N, N4;
985    int has_pitch, has_fold;
986    int pitch_index;
987    int bits;
988    ec_dec dec;
989    ec_byte_buffer buf;
990    VARDECL(celt_sig_t, freq);
991    VARDECL(celt_norm_t, X);
992    VARDECL(celt_norm_t, P);
993    VARDECL(celt_ener_t, bandE);
994    VARDECL(celt_pgain_t, gains);
995    VARDECL(int, stereo_mode);
996    VARDECL(int, fine_quant);
997    VARDECL(int, pulses);
998    VARDECL(int, offsets);
999
1000    int shortBlocks;
1001    int intra_ener;
1002    int transient_time;
1003    int transient_shift;
1004    const int C = CHANNELS(st->mode);
1005    SAVE_STACK;
1006
1007    if (check_mode(st->mode) != CELT_OK)
1008       return CELT_INVALID_MODE;
1009
1010    N = st->block_size;
1011    N4 = (N-st->overlap)>>1;
1012
1013    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
1014    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
1015    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
1016    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
1017    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
1018    
1019    if (check_mode(st->mode) != CELT_OK)
1020    {
1021       RESTORE_STACK;
1022       return CELT_INVALID_MODE;
1023    }
1024    if (data == NULL)
1025    {
1026       celt_decode_lost(st, pcm);
1027       RESTORE_STACK;
1028       return 0;
1029    }
1030    if (len<0) {
1031      RESTORE_STACK;
1032      return CELT_BAD_ARG;
1033    }
1034    
1035    ec_byte_readinit(&buf,(unsigned char*)data,len);
1036    ec_dec_init(&dec,&buf);
1037    
1038    decode_flags(&dec, &intra_ener, &has_pitch, &shortBlocks, &has_fold);
1039    if (shortBlocks)
1040    {
1041       transient_shift = ec_dec_bits(&dec, 2);
1042       if (transient_shift)
1043          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
1044       else
1045          transient_time = 0;
1046    } else {
1047       transient_time = -1;
1048       transient_shift = 0;
1049    }
1050    
1051    if (has_pitch)
1052    {
1053       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
1054       st->last_pitch_index = pitch_index;
1055    } else {
1056       pitch_index = 0;
1057       for (i=0;i<st->mode->nbPBands;i++)
1058          gains[i] = 0;
1059    }
1060
1061    ALLOC(fine_quant, st->mode->nbEBands, int);
1062    /* Get band energies */
1063    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
1064    
1065    ALLOC(pulses, st->mode->nbEBands, int);
1066    ALLOC(offsets, st->mode->nbEBands, int);
1067    ALLOC(stereo_mode, st->mode->nbEBands, int);
1068    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
1069
1070    for (i=0;i<st->mode->nbEBands;i++)
1071       offsets[i] = 0;
1072
1073    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1074    if (has_pitch)
1075       bits -= st->mode->nbPBands;
1076    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1077    /*bits = ec_dec_tell(&dec, 0);
1078    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1079    
1080    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1081
1082
1083    if (has_pitch) 
1084    {
1085       VARDECL(celt_ener_t, bandEp);
1086       
1087       /* Pitch MDCT */
1088       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1089       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1090       compute_band_energies(st->mode, freq, bandEp);
1091       normalise_bands(st->mode, freq, P, bandEp);
1092       /* Apply pitch gains */
1093    } else {
1094       for (i=0;i<C*N;i++)
1095          P[i] = 0;
1096    }
1097
1098    /* Decode fixed codebook and merge with pitch */
1099    if (C==1)
1100       unquant_bands(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1101    else
1102       unquant_bands_stereo(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
1103
1104    /* Synthesis */
1105    denormalise_bands(st->mode, X, freq, bandE);
1106
1107
1108    CELT_MOVE(st->decode_mem, st->decode_mem+C*N, C*(DECODE_BUFFER_SIZE+st->overlap-N));
1109    /* Compute inverse MDCTs */
1110    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1111
1112    for (c=0;c<C;c++)
1113    {
1114       int j;
1115       for (j=0;j<N;j++)
1116       {
1117          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1118                                 preemph,st->preemph_memD[c]);
1119          st->preemph_memD[c] = tmp;
1120          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1121       }
1122    }
1123
1124    {
1125       unsigned int val = 0;
1126       while (ec_dec_tell(&dec, 0) < len*8)
1127       {
1128          if (ec_dec_uint(&dec, 2) != val)
1129          {
1130             celt_warning("decode error");
1131             RESTORE_STACK;
1132             return CELT_CORRUPTED_DATA;
1133          }
1134          val = 1-val;
1135       }
1136    }
1137
1138    RESTORE_STACK;
1139    return 0;
1140    /*printf ("\n");*/
1141 }
1142
1143 #ifdef FIXED_POINT
1144 #ifndef DISABLE_FLOAT_API
1145 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm)
1146 {
1147    int j, ret;
1148    const int C = CHANNELS(st->mode);
1149    const int N = st->block_size;
1150    VARDECL(celt_int16_t, out);
1151    SAVE_STACK;
1152    ALLOC(out, C*N, celt_int16_t);
1153
1154    ret=celt_decode(st, data, len, out);
1155
1156    for (j=0;j<C*N;j++)
1157      pcm[j]=out[j]*(1/32768.);
1158    RESTORE_STACK;
1159    return ret;
1160 }
1161 #endif /*DISABLE_FLOAT_API*/
1162 #else
1163 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
1164 {
1165    int j, ret;
1166    VARDECL(celt_sig_t, out);
1167    const int C = CHANNELS(st->mode);
1168    const int N = st->block_size;
1169    SAVE_STACK;
1170    ALLOC(out, C*N, celt_sig_t);
1171
1172    ret=celt_decode_float(st, data, len, out);
1173
1174    for (j=0;j<C*N;j++)
1175      pcm[j] = FLOAT2INT16 (out[j]);
1176
1177    RESTORE_STACK;
1178    return ret;
1179 }
1180 #endif