Latest psychoacoustics work -- still highly experimental
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53 #include "float_cast.h"
54 #include <stdarg.h>
55
56 static const celt_word16_t preemph = QCONST16(0.8f,15);
57
58 #ifdef FIXED_POINT
59 static const celt_word16_t transientWindow[16] = {
60      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
61    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
62 #else
63 static const float transientWindow[16] = {
64    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
65    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
66 #endif
67
68    
69 /** Encoder state 
70  @brief Encoder state
71  */
72 struct CELTEncoder {
73    const CELTMode *mode;     /**< Mode used by the encoder */
74    int frame_size;
75    int block_size;
76    int overlap;
77    int channels;
78    
79    int pitch_enabled;
80
81    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
82    celt_sig_t    * restrict preemph_memD;
83
84    celt_sig_t *in_mem;
85    celt_sig_t *out_mem;
86
87    celt_word16_t *oldBandE;
88 #ifdef EXP_PSY
89    celt_word16_t *psy_mem;
90    struct PsyDecay psy;
91 #endif
92 };
93
94 CELTEncoder *celt_encoder_create(const CELTMode *mode)
95 {
96    int N, C;
97    CELTEncoder *st;
98
99    if (check_mode(mode) != CELT_OK)
100       return NULL;
101
102    N = mode->mdctSize;
103    C = mode->nbChannels;
104    st = celt_alloc(sizeof(CELTEncoder));
105    
106    st->mode = mode;
107    st->frame_size = N;
108    st->block_size = N;
109    st->overlap = mode->overlap;
110
111    st->pitch_enabled = 1;
112
113    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
114    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
115
116    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
117
118    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
119    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
120
121 #ifdef EXP_PSY
122    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
123    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
124 #endif
125
126    return st;
127 }
128
129 void celt_encoder_destroy(CELTEncoder *st)
130 {
131    if (st == NULL)
132    {
133       celt_warning("NULL passed to celt_encoder_destroy");
134       return;
135    }
136    if (check_mode(st->mode) != CELT_OK)
137       return;
138
139    celt_free(st->in_mem);
140    celt_free(st->out_mem);
141    
142    celt_free(st->oldBandE);
143    
144    celt_free(st->preemph_memE);
145    celt_free(st->preemph_memD);
146    
147 #ifdef EXP_PSY
148    celt_free (st->psy_mem);
149    psydecay_clear(&st->psy);
150 #endif
151    
152    celt_free(st);
153 }
154
155 static inline celt_int16_t FLOAT2INT16(float x)
156 {
157    x = x*32768.;
158    x = MAX32(x, -32768);
159    x = MIN32(x, 32767);
160    return (celt_int16_t)float2int(x);
161 }
162
163 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
164 {
165 #ifdef FIXED_POINT
166    x = PSHR32(x, SIG_SHIFT);
167    x = MAX32(x, -32768);
168    x = MIN32(x, 32767);
169    return EXTRACT16(x);
170 #else
171    return (celt_word16_t)x;
172 #endif
173 }
174
175 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
176 {
177    int c, i, n;
178    celt_word32_t ratio;
179    /* FIXME: Remove the floats here */
180    VARDECL(celt_word32_t, begin);
181    SAVE_STACK;
182    ALLOC(begin, len, celt_word32_t);
183    for (i=0;i<len;i++)
184       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
185    for (c=1;c<C;c++)
186    {
187       for (i=0;i<len;i++)
188          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
189    }
190    for (i=1;i<len;i++)
191       begin[i] = MAX32(begin[i-1],begin[i]);
192    n = -1;
193    for (i=8;i<len-8;i++)
194    {
195       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
196          n=i;
197    }
198    if (n<32)
199    {
200       n = -1;
201       ratio = 0;
202    } else {
203       ratio = DIV32(begin[len-1],1+begin[n-16]);
204    }
205    /*printf ("%d %f\n", n, ratio*ratio);*/
206    if (ratio < 0)
207       ratio = 0;
208    if (ratio > 1000)
209       ratio = 1000;
210    ratio *= ratio;
211    if (ratio < 50)
212       *transient_shift = 0;
213    else if (ratio < 256)
214       *transient_shift = 1;
215    else if (ratio < 4096)
216       *transient_shift = 2;
217    else
218       *transient_shift = 3;
219    *transient_time = n;
220    
221    RESTORE_STACK;
222    return ratio > 20;
223 }
224
225 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
226 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
227 {
228    const int C = CHANNELS(mode);
229    if (C==1 && !shortBlocks)
230    {
231       const mdct_lookup *lookup = MDCT(mode);
232       const int overlap = OVERLAP(mode);
233       mdct_forward(lookup, in, out, mode->window, overlap);
234    } else if (!shortBlocks) {
235       const mdct_lookup *lookup = MDCT(mode);
236       const int overlap = OVERLAP(mode);
237       const int N = FRAMESIZE(mode);
238       int c;
239       VARDECL(celt_word32_t, x);
240       VARDECL(celt_word32_t, tmp);
241       SAVE_STACK;
242       ALLOC(x, N+overlap, celt_word32_t);
243       ALLOC(tmp, N, celt_word32_t);
244       for (c=0;c<C;c++)
245       {
246          int j;
247          for (j=0;j<N+overlap;j++)
248             x[j] = in[C*j+c];
249          mdct_forward(lookup, x, tmp, mode->window, overlap);
250          /* Interleaving the sub-frames */
251          for (j=0;j<N;j++)
252             out[C*j+c] = tmp[j];
253       }
254       RESTORE_STACK;
255    } else {
256       const mdct_lookup *lookup = &mode->shortMdct;
257       const int overlap = mode->overlap;
258       const int N = mode->shortMdctSize;
259       int b, c;
260       VARDECL(celt_word32_t, x);
261       VARDECL(celt_word32_t, tmp);
262       SAVE_STACK;
263       ALLOC(x, N+overlap, celt_word32_t);
264       ALLOC(tmp, N, celt_word32_t);
265       for (c=0;c<C;c++)
266       {
267          int B = mode->nbShortMdcts;
268          for (b=0;b<B;b++)
269          {
270             int j;
271             for (j=0;j<N+overlap;j++)
272                x[j] = in[C*(b*N+j)+c];
273             mdct_forward(lookup, x, tmp, mode->window, overlap);
274             /* Interleaving the sub-frames */
275             for (j=0;j<N;j++)
276                out[C*(j*B+b)+c] = tmp[j];
277          }
278       }
279       RESTORE_STACK;
280    }
281 }
282
283 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
284 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
285 {
286    int c, N4;
287    const int C = CHANNELS(mode);
288    const int N = FRAMESIZE(mode);
289    const int overlap = OVERLAP(mode);
290    N4 = (N-overlap)>>1;
291    for (c=0;c<C;c++)
292    {
293       int j;
294       if (transient_shift==0 && C==1 && !shortBlocks) {
295          const mdct_lookup *lookup = MDCT(mode);
296          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
297       } else if (!shortBlocks) {
298          const mdct_lookup *lookup = MDCT(mode);
299          VARDECL(celt_word32_t, x);
300          VARDECL(celt_word32_t, tmp);
301          SAVE_STACK;
302          ALLOC(x, 2*N, celt_word32_t);
303          ALLOC(tmp, N, celt_word32_t);
304          /* De-interleaving the sub-frames */
305          for (j=0;j<N;j++)
306             tmp[j] = X[C*j+c];
307          /* Prevents problems from the imdct doing the overlap-add */
308          CELT_MEMSET(x+N4, 0, N);
309          mdct_backward(lookup, tmp, x, mode->window, overlap);
310          celt_assert(transient_shift == 0);
311          /* The first and last part would need to be set to zero if we actually
312             wanted to use them. */
313          for (j=0;j<overlap;j++)
314             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
315          for (j=0;j<overlap;j++)
316             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
317          for (j=0;j<2*N4;j++)
318             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
319          RESTORE_STACK;
320       } else {
321          int b;
322          const int N2 = mode->shortMdctSize;
323          const int B = mode->nbShortMdcts;
324          const mdct_lookup *lookup = &mode->shortMdct;
325          VARDECL(celt_word32_t, x);
326          VARDECL(celt_word32_t, tmp);
327          SAVE_STACK;
328          ALLOC(x, 2*N, celt_word32_t);
329          ALLOC(tmp, N, celt_word32_t);
330          /* Prevents problems from the imdct doing the overlap-add */
331          CELT_MEMSET(x+N4, 0, N2);
332          for (b=0;b<B;b++)
333          {
334             /* De-interleaving the sub-frames */
335             for (j=0;j<N2;j++)
336                tmp[j] = X[C*(j*B+b)+c];
337             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
338          }
339          if (transient_shift > 0)
340          {
341 #ifdef FIXED_POINT
342             for (j=0;j<16;j++)
343                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
344             for (j=transient_time;j<N+overlap;j++)
345                x[N4+j] = SHL32(x[N4+j], transient_shift);
346 #else
347             for (j=0;j<16;j++)
348                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
349             for (j=transient_time;j<N+overlap;j++)
350                x[N4+j] *= 1<<transient_shift;
351 #endif
352          }
353          /* The first and last part would need to be set to zero if we actually
354          wanted to use them. */
355          for (j=0;j<overlap;j++)
356             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
357          for (j=0;j<overlap;j++)
358             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
359          for (j=0;j<2*N4;j++)
360             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
361          RESTORE_STACK;
362       }
363    }
364 }
365
366 #ifdef FIXED_POINT
367 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
368 {
369 #else
370 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
371 {
372 #endif
373    int i, c, N, N4;
374    int has_pitch;
375    int pitch_index;
376    int bits;
377    ec_byte_buffer buf;
378    ec_enc         enc;
379    celt_word32_t curr_power, pitch_power=0;
380    VARDECL(celt_sig_t, in);
381    VARDECL(celt_sig_t, freq);
382    VARDECL(celt_norm_t, X);
383    VARDECL(celt_norm_t, P);
384    VARDECL(celt_ener_t, bandE);
385    VARDECL(celt_pgain_t, gains);
386    VARDECL(int, stereo_mode);
387    VARDECL(int, fine_quant);
388    VARDECL(celt_word16_t, error);
389    VARDECL(int, pulses);
390    VARDECL(int, offsets);
391 #ifdef EXP_PSY
392    VARDECL(celt_word32_t, mask);
393    VARDECL(celt_word32_t, tonality);
394 #endif
395    int shortBlocks=0;
396    int transient_time;
397    int transient_shift;
398    const int C = CHANNELS(st->mode);
399    SAVE_STACK;
400
401    if (check_mode(st->mode) != CELT_OK)
402       return CELT_INVALID_MODE;
403
404    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
405    CELT_MEMSET(compressed, 0, nbCompressedBytes);
406    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
407    ec_enc_init(&enc,&buf);
408
409    N = st->block_size;
410    N4 = (N-st->overlap)>>1;
411    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
412
413    CELT_COPY(in, st->in_mem, C*st->overlap);
414    for (c=0;c<C;c++)
415    {
416       const celt_word16_t * restrict pcmp = pcm+c;
417       celt_sig_t * restrict inp = in+C*st->overlap+c;
418       for (i=0;i<N;i++)
419       {
420          /* Apply pre-emphasis */
421          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
422          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
423          st->preemph_memE[c] = SCALEIN(*pcmp);
424          inp += C;
425          pcmp += C;
426       }
427    }
428    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
429    
430    if (st->mode->nbShortMdcts > 1)
431    {
432       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
433       {
434 #ifndef FIXED_POINT
435          float gain_1;
436 #endif
437          ec_enc_bits(&enc, 0, 1); //Pitch off
438          ec_enc_bits(&enc, 1, 1); //Transient on
439          ec_enc_bits(&enc, transient_shift, 2);
440          if (transient_shift)
441             ec_enc_uint(&enc, transient_time, N+st->overlap);
442          if (transient_shift)
443          {
444 #ifdef FIXED_POINT
445             for (c=0;c<C;c++)
446                for (i=0;i<16;i++)
447                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
448             for (c=0;c<C;c++)
449                for (i=transient_time;i<N+st->overlap;i++)
450                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
451 #else
452             for (c=0;c<C;c++)
453                for (i=0;i<16;i++)
454                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
455             gain_1 = 1./(1<<transient_shift);
456             for (c=0;c<C;c++)
457                for (i=transient_time;i<N+st->overlap;i++)
458                   in[C*i+c] *= gain_1;
459 #endif
460          }
461          shortBlocks = 1;
462       } else {
463          transient_time = -1;
464          transient_shift = 0;
465          shortBlocks = 0;
466       }
467    } else {
468       transient_time = -1;
469       transient_shift = 0;
470       shortBlocks = 0;
471    }
472
473    /* Pitch analysis: we do it early to save on the peak stack space */
474 #ifdef EXP_PSY
475    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
476    {
477       VARDECL(celt_word16_t, X);
478       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
479       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
480       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
481    }
482 #else
483    if (st->pitch_enabled && !shortBlocks)
484    {
485       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
486    }
487 #endif
488    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
489    
490    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
491    /* Compute MDCTs */
492    compute_mdcts(st->mode, shortBlocks, in, freq);
493
494 #ifdef EXP_PSY
495    /*CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
496    for (i=0;i<N;i++)
497       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
498    for (c=1;c<C;c++)
499       for (i=0;i<N;i++)
500          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
501    */
502    ALLOC(mask, N, celt_sig_t);
503    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
504    /*for (i=0;i<256;i++)
505       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
506    printf ("\n");*/
507
508    /* Invert and stretch the mask to length of X 
509       For some reason, I get better results by using the sqrt instead,
510       although there's no valid reason to. Must investigate further */
511    /*for (i=0;i<C*N;i++)
512       mask[i] = 1/(.1+mask[i]);*/
513 #endif
514    
515    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
516    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
517    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
518    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
519    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
520
521    /*printf ("%f %f\n", curr_power, pitch_power);*/
522    /*int j;
523    for (j=0;j<B*N;j++)
524       printf ("%f ", X[j]);
525    for (j=0;j<B*N;j++)
526       printf ("%f ", P[j]);
527    printf ("\n");*/
528
529    /* Band normalisation */
530    compute_band_energies(st->mode, freq, bandE);
531 #ifdef EXP_PSY
532    VARDECL(celt_word32_t, bandM);
533    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
534    for (i=0;i<N;i++)
535       mask[i] = sqrt(mask[i]);
536    compute_band_energies(st->mode, mask, bandM);
537    /*for (i=0;i<st->mode->nbEBands;i++)
538       printf ("%f %f ", bandE[i], bandM[i]);
539    printf ("\n");*/
540 #endif
541    normalise_bands(st->mode, freq, X, bandE);
542    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
543    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
544
545    /* Compute MDCTs of the pitch part */
546    if (st->pitch_enabled && !shortBlocks)
547    {
548       /* Normalise the pitch vector as well (discard the energies) */
549       VARDECL(celt_ener_t, bandEp);
550       
551       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
552       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
553       compute_band_energies(st->mode, freq, bandEp);
554       normalise_bands(st->mode, freq, P, bandEp);
555       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
556    }
557    curr_power = bandE[0]+bandE[1]+bandE[2];
558    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
559    if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
560    {
561       /* Simulates intensity stereo */
562       /*for (i=30;i<N*B;i++)
563          X[i*C+1] = P[i*C+1] = 0;*/
564
565       /* Pitch prediction */
566       compute_pitch_gain(st->mode, X, P, gains);
567       has_pitch = quant_pitch(gains, st->mode->nbPBands, &enc);
568       if (has_pitch)
569          ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
570       else if (st->mode->nbShortMdcts > 1)
571          ec_enc_bits(&enc, 0, 1); //Transient off
572    } else {
573       if (!shortBlocks)
574       {
575          ec_enc_bits(&enc, 0, 1); //Pitch off
576          if (st->mode->nbShortMdcts > 1)
577            ec_enc_bits(&enc, 0, 1); //Transient off
578       }
579       /* No pitch, so we just pretend we found a gain of zero */
580       for (i=0;i<st->mode->nbPBands;i++)
581          gains[i] = 0;
582       for (i=0;i<C*N;i++)
583          P[i] = 0;
584    }
585
586 #ifdef STDIN_TUNING2
587    static int fine_quant[30];
588    static int pulses[30];
589    static int init=0;
590    if (!init)
591    {
592       for (i=0;i<st->mode->nbEBands;i++)
593          scanf("%d ", &fine_quant[i]);
594       for (i=0;i<st->mode->nbEBands;i++)
595          scanf("%d ", &pulses[i]);
596       init = 1;
597    }
598 #else
599    ALLOC(fine_quant, st->mode->nbEBands, int);
600    ALLOC(pulses, st->mode->nbEBands, int);
601 #endif
602    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
603    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &enc);
604    
605    ALLOC(offsets, st->mode->nbEBands, int);
606    ALLOC(stereo_mode, st->mode->nbEBands, int);
607    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
608
609    for (i=0;i<st->mode->nbEBands;i++)
610       offsets[i] = 0;
611    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
612 #ifndef STDIN_TUNING
613    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
614 #endif
615    /*for (i=0;i<st->mode->nbEBands;i++)
616       printf("%d ", fine_quant[i]);
617    for (i=0;i<st->mode->nbEBands;i++)
618       printf("%d ", pulses[i]);
619    printf ("\n");*/
620    /*bits = ec_enc_tell(&st->enc, 0);
621    compute_fine_allocation(st->mode, fine_quant, (20*C+nbCompressedBytes*8/5-(ec_enc_tell(&st->enc, 0)-bits))/C);*/
622    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
623
624    pitch_quant_bands(st->mode, P, gains);
625
626    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
627
628    /* Residual quantisation */
629    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, nbCompressedBytes*8, &enc);
630    
631    if (st->pitch_enabled || optional_synthesis!=NULL)
632    {
633       if (C==2)
634          renormalise_bands(st->mode, X);
635       /* Synthesis */
636       denormalise_bands(st->mode, X, freq, bandE);
637       
638       
639       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
640       
641       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
642       /* De-emphasis and put everything back at the right place in the synthesis history */
643       if (optional_synthesis != NULL) {
644          for (c=0;c<C;c++)
645          {
646             int j;
647             for (j=0;j<N;j++)
648             {
649                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
650                                    preemph,st->preemph_memD[c]);
651                st->preemph_memD[c] = tmp;
652                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
653             }
654          }
655       }
656    }
657    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
658    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
659       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
660    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
661    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
662    {
663       int val = 0;
664       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
665       {
666          ec_enc_uint(&enc, val, 2);
667          val = 1-val;
668       }
669    }
670    ec_enc_done(&enc);
671    {
672       /*unsigned char *data;*/
673       int nbBytes = ec_byte_bytes(&buf);
674       if (nbBytes > nbCompressedBytes)
675       {
676          celt_warning_int ("got too many bytes:", nbBytes);
677          RESTORE_STACK;
678          return CELT_INTERNAL_ERROR;
679       }
680       /*printf ("%d\n", *nbBytes);*/
681       /*data = ec_byte_get_buffer(&buf);
682       for (i=0;i<nbBytes;i++)
683          compressed[i] = data[i];
684       for (i=nbBytes;i<nbCompressedBytes;i++)
685          compressed[i] = 0;*/
686    }
687
688    RESTORE_STACK;
689    return nbCompressedBytes;
690 }
691
692 #ifdef FIXED_POINT
693 #ifndef DISABLE_FLOAT_API
694 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
695 {
696    int j, ret;
697    const int C = CHANNELS(st->mode);
698    const int N = st->block_size;
699    VARDECL(celt_int16_t, in);
700    SAVE_STACK;
701    ALLOC(in, C*N, celt_int16_t);
702
703    for (j=0;j<C*N;j++)
704      in[j] = FLOAT2INT16(pcm[j]);
705
706    if (optional_synthesis != NULL) {
707      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
708    /*Converts backwards for inplace operation*/
709       for (j=0;j=C*N;j++)
710          optional_synthesis[j]=in[j]*(1/32768.);
711    } else {
712      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
713    }
714    RESTORE_STACK;
715    return ret;
716
717 }
718 #endif /*DISABLE_FLOAT_API*/
719 #else
720 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
721 {
722    int j, ret;
723    VARDECL(celt_sig_t, in);
724    const int C = CHANNELS(st->mode);
725    const int N = st->block_size;
726    SAVE_STACK;
727    ALLOC(in, C*N, celt_sig_t);
728    for (j=0;j<C*N;j++) {
729      in[j] = SCALEOUT(pcm[j]);
730    }
731
732    if (optional_synthesis != NULL) {
733       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
734       for (j=0;j<C*N;j++)
735          optional_synthesis[j] = FLOAT2INT16(in[j]);
736    } else {
737       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
738    }
739    RESTORE_STACK;
740    return ret;
741 }
742 #endif
743
744 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
745 {
746    va_list ap;
747    va_start(ap, request);
748    switch (request)
749    {
750       case CELT_SET_COMPLEXITY_REQUEST:
751       {
752          int value = va_arg(ap, int);
753          if (value<0 || value>10)
754             goto bad_arg;
755          if (value<=2)
756             st->pitch_enabled = 0;
757          else
758             st->pitch_enabled = 1;
759       }
760       break;
761       default:
762          goto bad_request;
763    }
764    va_end(ap);
765    return CELT_OK;
766 bad_arg:
767    va_end(ap);
768    return CELT_BAD_ARG;
769 bad_request:
770    va_end(ap);
771    return CELT_UNIMPLEMENTED;
772 }
773
774 /****************************************************************************/
775 /*                                                                          */
776 /*                                DECODER                                   */
777 /*                                                                          */
778 /****************************************************************************/
779
780
781 /** Decoder state 
782  @brief Decoder state
783  */
784 struct CELTDecoder {
785    const CELTMode *mode;
786    int frame_size;
787    int block_size;
788    int overlap;
789
790    ec_byte_buffer buf;
791    ec_enc         enc;
792
793    celt_sig_t * restrict preemph_memD;
794
795    celt_sig_t *out_mem;
796
797    celt_word16_t *oldBandE;
798    
799    int last_pitch_index;
800 };
801
802 CELTDecoder *celt_decoder_create(const CELTMode *mode)
803 {
804    int N, C;
805    CELTDecoder *st;
806
807    if (check_mode(mode) != CELT_OK)
808       return NULL;
809
810    N = mode->mdctSize;
811    C = CHANNELS(mode);
812    st = celt_alloc(sizeof(CELTDecoder));
813    
814    st->mode = mode;
815    st->frame_size = N;
816    st->block_size = N;
817    st->overlap = mode->overlap;
818
819    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
820    
821    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
822
823    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
824
825    st->last_pitch_index = 0;
826    return st;
827 }
828
829 void celt_decoder_destroy(CELTDecoder *st)
830 {
831    if (st == NULL)
832    {
833       celt_warning("NULL passed to celt_encoder_destroy");
834       return;
835    }
836    if (check_mode(st->mode) != CELT_OK)
837       return;
838
839
840    celt_free(st->out_mem);
841    
842    celt_free(st->oldBandE);
843    
844    celt_free(st->preemph_memD);
845
846    celt_free(st);
847 }
848
849 /** Handles lost packets by just copying past data with the same offset as the last
850     pitch period */
851 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
852 {
853    int c, N;
854    int pitch_index;
855    int i, len;
856    VARDECL(celt_sig_t, freq);
857    const int C = CHANNELS(st->mode);
858    int offset;
859    SAVE_STACK;
860    N = st->block_size;
861    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
862    
863    len = N+st->mode->overlap;
864 #if 0
865    pitch_index = st->last_pitch_index;
866    
867    /* Use the pitch MDCT as the "guessed" signal */
868    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
869
870 #else
871    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
872    pitch_index = MAX_PERIOD-len-pitch_index;
873    offset = MAX_PERIOD-pitch_index;
874    while (offset+len >= MAX_PERIOD)
875       offset -= pitch_index;
876    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
877    for (i=0;i<N;i++)
878       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
879 #endif
880    
881    
882    
883    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
884    /* Compute inverse MDCTs */
885    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
886
887    for (c=0;c<C;c++)
888    {
889       int j;
890       for (j=0;j<N;j++)
891       {
892          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
893                                 preemph,st->preemph_memD[c]);
894          st->preemph_memD[c] = tmp;
895          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
896       }
897    }
898    RESTORE_STACK;
899 }
900
901 #ifdef FIXED_POINT
902 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
903 {
904 #else
905 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
906 {
907 #endif
908    int i, c, N, N4;
909    int has_pitch, has_fold;
910    int pitch_index;
911    int bits;
912    ec_dec dec;
913    ec_byte_buffer buf;
914    VARDECL(celt_sig_t, freq);
915    VARDECL(celt_norm_t, X);
916    VARDECL(celt_norm_t, P);
917    VARDECL(celt_ener_t, bandE);
918    VARDECL(celt_pgain_t, gains);
919    VARDECL(int, stereo_mode);
920    VARDECL(int, fine_quant);
921    VARDECL(int, pulses);
922    VARDECL(int, offsets);
923
924    int shortBlocks;
925    int transient_time;
926    int transient_shift;
927    const int C = CHANNELS(st->mode);
928    SAVE_STACK;
929
930    if (check_mode(st->mode) != CELT_OK)
931       return CELT_INVALID_MODE;
932
933    N = st->block_size;
934    N4 = (N-st->overlap)>>1;
935
936    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
937    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
938    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
939    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
940    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
941    
942    if (check_mode(st->mode) != CELT_OK)
943    {
944       RESTORE_STACK;
945       return CELT_INVALID_MODE;
946    }
947    if (data == NULL)
948    {
949       celt_decode_lost(st, pcm);
950       RESTORE_STACK;
951       return 0;
952    }
953    
954    ec_byte_readinit(&buf,data,len);
955    ec_dec_init(&dec,&buf);
956    
957    has_pitch = ec_dec_bits(&dec, 1);
958    if (has_pitch)
959    {
960       has_fold = ec_dec_bits(&dec, 1);
961       shortBlocks = 0;
962    } else if (st->mode->nbShortMdcts > 1){
963       shortBlocks = ec_dec_bits(&dec, 1);
964       has_fold = 1;
965    } else {
966       shortBlocks = 0;
967       has_fold = 1;
968    }
969    if (shortBlocks)
970    {
971       transient_shift = ec_dec_bits(&dec, 2);
972       if (transient_shift)
973          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
974       else
975          transient_time = 0;
976    } else {
977       transient_time = -1;
978       transient_shift = 0;
979    }
980    /* Get the pitch gains */
981    
982    /* Get the pitch index */
983    if (has_pitch)
984    {
985       has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
986       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
987       st->last_pitch_index = pitch_index;
988    } else {
989       /* FIXME: We could be more intelligent here and just not compute the MDCT */
990       pitch_index = 0;
991       for (i=0;i<st->mode->nbPBands;i++)
992          gains[i] = 0;
993    }
994
995    ALLOC(fine_quant, st->mode->nbEBands, int);
996    /* Get band energies */
997    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
998    
999    ALLOC(pulses, st->mode->nbEBands, int);
1000    ALLOC(offsets, st->mode->nbEBands, int);
1001    ALLOC(stereo_mode, st->mode->nbEBands, int);
1002    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
1003
1004    for (i=0;i<st->mode->nbEBands;i++)
1005       offsets[i] = 0;
1006
1007    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1008    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1009    /*bits = ec_dec_tell(&dec, 0);
1010    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1011    
1012    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1013
1014
1015    if (has_pitch) 
1016    {
1017       VARDECL(celt_ener_t, bandEp);
1018       
1019       /* Pitch MDCT */
1020       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1021       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1022       compute_band_energies(st->mode, freq, bandEp);
1023       normalise_bands(st->mode, freq, P, bandEp);
1024    } else {
1025       for (i=0;i<C*N;i++)
1026          P[i] = 0;
1027    }
1028
1029    /* Apply pitch gains */
1030    pitch_quant_bands(st->mode, P, gains);
1031
1032    /* Decode fixed codebook and merge with pitch */
1033    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, len*8, &dec);
1034
1035    if (C==2)
1036    {
1037       renormalise_bands(st->mode, X);
1038    }
1039    /* Synthesis */
1040    denormalise_bands(st->mode, X, freq, bandE);
1041
1042
1043    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
1044    /* Compute inverse MDCTs */
1045    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1046
1047    for (c=0;c<C;c++)
1048    {
1049       int j;
1050       for (j=0;j<N;j++)
1051       {
1052          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1053                                 preemph,st->preemph_memD[c]);
1054          st->preemph_memD[c] = tmp;
1055          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1056       }
1057    }
1058
1059    {
1060       unsigned int val = 0;
1061       while (ec_dec_tell(&dec, 0) < len*8)
1062       {
1063          if (ec_dec_uint(&dec, 2) != val)
1064          {
1065             celt_warning("decode error");
1066             RESTORE_STACK;
1067             return CELT_CORRUPTED_DATA;
1068          }
1069          val = 1-val;
1070       }
1071    }
1072
1073    RESTORE_STACK;
1074    return 0;
1075    /*printf ("\n");*/
1076 }
1077
1078 #ifdef FIXED_POINT
1079 #ifndef DISABLE_FLOAT_API
1080 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1081 {
1082    int j, ret;
1083    const int C = CHANNELS(st->mode);
1084    const int N = st->block_size;
1085    VARDECL(celt_int16_t, out);
1086    SAVE_STACK;
1087    ALLOC(out, C*N, celt_int16_t);
1088
1089    ret=celt_decode(st, data, len, out);
1090
1091    for (j=0;j<C*N;j++)
1092      pcm[j]=out[j]*(1/32768.);
1093    RESTORE_STACK;
1094    return ret;
1095 }
1096 #endif /*DISABLE_FLOAT_API*/
1097 #else
1098 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1099 {
1100    int j, ret;
1101    VARDECL(celt_sig_t, out);
1102    const int C = CHANNELS(st->mode);
1103    const int N = st->block_size;
1104    SAVE_STACK;
1105    ALLOC(out, C*N, celt_sig_t);
1106
1107    ret=celt_decode_float(st, data, len, out);
1108
1109    for (j=0;j<C*N;j++)
1110      pcm[j] = FLOAT2INT16 (out[j]);
1111
1112    RESTORE_STACK;
1113    return ret;
1114 }
1115 #endif