Minor pitch handling cleanups.
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53 #include "float_cast.h"
54 #include <stdarg.h>
55
56 static const celt_word16_t preemph = QCONST16(0.8f,15);
57
58 #ifdef FIXED_POINT
59 static const celt_word16_t transientWindow[16] = {
60      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
61    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
62 #else
63 static const float transientWindow[16] = {
64    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
65    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
66 #endif
67
68    
69 /** Encoder state 
70  @brief Encoder state
71  */
72 struct CELTEncoder {
73    const CELTMode *mode;     /**< Mode used by the encoder */
74    int frame_size;
75    int block_size;
76    int overlap;
77    int channels;
78    
79    int pitch_enabled;
80
81    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
82    celt_sig_t    * restrict preemph_memD;
83
84    celt_sig_t *in_mem;
85    celt_sig_t *out_mem;
86
87    celt_word16_t *oldBandE;
88 #ifdef EXP_PSY
89    celt_word16_t *psy_mem;
90    struct PsyDecay psy;
91 #endif
92 };
93
94 CELTEncoder *celt_encoder_create(const CELTMode *mode)
95 {
96    int N, C;
97    CELTEncoder *st;
98
99    if (check_mode(mode) != CELT_OK)
100       return NULL;
101
102    N = mode->mdctSize;
103    C = mode->nbChannels;
104    st = celt_alloc(sizeof(CELTEncoder));
105    
106    st->mode = mode;
107    st->frame_size = N;
108    st->block_size = N;
109    st->overlap = mode->overlap;
110
111    st->pitch_enabled = 1;
112
113    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
114    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
115
116    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
117
118    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
119    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
120
121 #ifdef EXP_PSY
122    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
123    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
124 #endif
125
126    return st;
127 }
128
129 void celt_encoder_destroy(CELTEncoder *st)
130 {
131    if (st == NULL)
132    {
133       celt_warning("NULL passed to celt_encoder_destroy");
134       return;
135    }
136    if (check_mode(st->mode) != CELT_OK)
137       return;
138
139    celt_free(st->in_mem);
140    celt_free(st->out_mem);
141    
142    celt_free(st->oldBandE);
143    
144    celt_free(st->preemph_memE);
145    celt_free(st->preemph_memD);
146    
147 #ifdef EXP_PSY
148    celt_free (st->psy_mem);
149    psydecay_clear(&st->psy);
150 #endif
151    
152    celt_free(st);
153 }
154
155 static inline celt_int16_t FLOAT2INT16(float x)
156 {
157    x = x*32768.;
158    x = MAX32(x, -32768);
159    x = MIN32(x, 32767);
160    return (celt_int16_t)float2int(x);
161 }
162
163 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
164 {
165 #ifdef FIXED_POINT
166    x = PSHR32(x, SIG_SHIFT);
167    x = MAX32(x, -32768);
168    x = MIN32(x, 32767);
169    return EXTRACT16(x);
170 #else
171    return (celt_word16_t)x;
172 #endif
173 }
174
175 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
176 {
177    int c, i, n;
178    celt_word32_t ratio;
179    /* FIXME: Remove the floats here */
180    VARDECL(celt_word32_t, begin);
181    SAVE_STACK;
182    ALLOC(begin, len, celt_word32_t);
183    for (i=0;i<len;i++)
184       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
185    for (c=1;c<C;c++)
186    {
187       for (i=0;i<len;i++)
188          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
189    }
190    for (i=1;i<len;i++)
191       begin[i] = MAX32(begin[i-1],begin[i]);
192    n = -1;
193    for (i=8;i<len-8;i++)
194    {
195       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
196          n=i;
197    }
198    if (n<32)
199    {
200       n = -1;
201       ratio = 0;
202    } else {
203       ratio = DIV32(begin[len-1],1+begin[n-16]);
204    }
205    /*printf ("%d %f\n", n, ratio*ratio);*/
206    if (ratio < 0)
207       ratio = 0;
208    if (ratio > 1000)
209       ratio = 1000;
210    ratio *= ratio;
211    if (ratio < 50)
212       *transient_shift = 0;
213    else if (ratio < 256)
214       *transient_shift = 1;
215    else if (ratio < 4096)
216       *transient_shift = 2;
217    else
218       *transient_shift = 3;
219    *transient_time = n;
220    
221    RESTORE_STACK;
222    return ratio > 20;
223 }
224
225 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
226 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
227 {
228    const int C = CHANNELS(mode);
229    if (C==1 && !shortBlocks)
230    {
231       const mdct_lookup *lookup = MDCT(mode);
232       const int overlap = OVERLAP(mode);
233       mdct_forward(lookup, in, out, mode->window, overlap);
234    } else if (!shortBlocks) {
235       const mdct_lookup *lookup = MDCT(mode);
236       const int overlap = OVERLAP(mode);
237       const int N = FRAMESIZE(mode);
238       int c;
239       VARDECL(celt_word32_t, x);
240       VARDECL(celt_word32_t, tmp);
241       SAVE_STACK;
242       ALLOC(x, N+overlap, celt_word32_t);
243       ALLOC(tmp, N, celt_word32_t);
244       for (c=0;c<C;c++)
245       {
246          int j;
247          for (j=0;j<N+overlap;j++)
248             x[j] = in[C*j+c];
249          mdct_forward(lookup, x, tmp, mode->window, overlap);
250          /* Interleaving the sub-frames */
251          for (j=0;j<N;j++)
252             out[C*j+c] = tmp[j];
253       }
254       RESTORE_STACK;
255    } else {
256       const mdct_lookup *lookup = &mode->shortMdct;
257       const int overlap = mode->overlap;
258       const int N = mode->shortMdctSize;
259       int b, c;
260       VARDECL(celt_word32_t, x);
261       VARDECL(celt_word32_t, tmp);
262       SAVE_STACK;
263       ALLOC(x, N+overlap, celt_word32_t);
264       ALLOC(tmp, N, celt_word32_t);
265       for (c=0;c<C;c++)
266       {
267          int B = mode->nbShortMdcts;
268          for (b=0;b<B;b++)
269          {
270             int j;
271             for (j=0;j<N+overlap;j++)
272                x[j] = in[C*(b*N+j)+c];
273             mdct_forward(lookup, x, tmp, mode->window, overlap);
274             /* Interleaving the sub-frames */
275             for (j=0;j<N;j++)
276                out[C*(j*B+b)+c] = tmp[j];
277          }
278       }
279       RESTORE_STACK;
280    }
281 }
282
283 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
284 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
285 {
286    int c, N4;
287    const int C = CHANNELS(mode);
288    const int N = FRAMESIZE(mode);
289    const int overlap = OVERLAP(mode);
290    N4 = (N-overlap)>>1;
291    for (c=0;c<C;c++)
292    {
293       int j;
294       if (transient_shift==0 && C==1 && !shortBlocks) {
295          const mdct_lookup *lookup = MDCT(mode);
296          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
297       } else if (!shortBlocks) {
298          const mdct_lookup *lookup = MDCT(mode);
299          VARDECL(celt_word32_t, x);
300          VARDECL(celt_word32_t, tmp);
301          SAVE_STACK;
302          ALLOC(x, 2*N, celt_word32_t);
303          ALLOC(tmp, N, celt_word32_t);
304          /* De-interleaving the sub-frames */
305          for (j=0;j<N;j++)
306             tmp[j] = X[C*j+c];
307          /* Prevents problems from the imdct doing the overlap-add */
308          CELT_MEMSET(x+N4, 0, N);
309          mdct_backward(lookup, tmp, x, mode->window, overlap);
310          celt_assert(transient_shift == 0);
311          /* The first and last part would need to be set to zero if we actually
312             wanted to use them. */
313          for (j=0;j<overlap;j++)
314             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
315          for (j=0;j<overlap;j++)
316             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
317          for (j=0;j<2*N4;j++)
318             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
319          RESTORE_STACK;
320       } else {
321          int b;
322          const int N2 = mode->shortMdctSize;
323          const int B = mode->nbShortMdcts;
324          const mdct_lookup *lookup = &mode->shortMdct;
325          VARDECL(celt_word32_t, x);
326          VARDECL(celt_word32_t, tmp);
327          SAVE_STACK;
328          ALLOC(x, 2*N, celt_word32_t);
329          ALLOC(tmp, N, celt_word32_t);
330          /* Prevents problems from the imdct doing the overlap-add */
331          CELT_MEMSET(x+N4, 0, N2);
332          for (b=0;b<B;b++)
333          {
334             /* De-interleaving the sub-frames */
335             for (j=0;j<N2;j++)
336                tmp[j] = X[C*(j*B+b)+c];
337             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
338          }
339          if (transient_shift > 0)
340          {
341 #ifdef FIXED_POINT
342             for (j=0;j<16;j++)
343                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
344             for (j=transient_time;j<N+overlap;j++)
345                x[N4+j] = SHL32(x[N4+j], transient_shift);
346 #else
347             for (j=0;j<16;j++)
348                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
349             for (j=transient_time;j<N+overlap;j++)
350                x[N4+j] *= 1<<transient_shift;
351 #endif
352          }
353          /* The first and last part would need to be set to zero if we actually
354          wanted to use them. */
355          for (j=0;j<overlap;j++)
356             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
357          for (j=0;j<overlap;j++)
358             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
359          for (j=0;j<2*N4;j++)
360             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
361          RESTORE_STACK;
362       }
363    }
364 }
365
366 #ifdef FIXED_POINT
367 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
368 {
369 #else
370 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
371 {
372 #endif
373    int i, c, N, N4;
374    int has_pitch;
375    int id;
376    int pitch_index;
377    int bits;
378    int has_fold=1;
379    ec_byte_buffer buf;
380    ec_enc         enc;
381    celt_word32_t curr_power, pitch_power=0;
382    VARDECL(celt_sig_t, in);
383    VARDECL(celt_sig_t, freq);
384    VARDECL(celt_norm_t, X);
385    VARDECL(celt_norm_t, P);
386    VARDECL(celt_ener_t, bandE);
387    VARDECL(celt_pgain_t, gains);
388    VARDECL(int, stereo_mode);
389    VARDECL(int, fine_quant);
390    VARDECL(celt_word16_t, error);
391    VARDECL(int, pulses);
392    VARDECL(int, offsets);
393 #ifdef EXP_PSY
394    VARDECL(celt_word32_t, mask);
395    VARDECL(celt_word32_t, tonality);
396    VARDECL(celt_word32_t, bandM);
397    VARDECL(celt_ener_t, bandN);
398 #endif
399    int shortBlocks=0;
400    int transient_time;
401    int transient_shift;
402    const int C = CHANNELS(st->mode);
403    SAVE_STACK;
404
405    if (check_mode(st->mode) != CELT_OK)
406       return CELT_INVALID_MODE;
407
408    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
409    CELT_MEMSET(compressed, 0, nbCompressedBytes);
410    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
411    ec_enc_init(&enc,&buf);
412
413    N = st->block_size;
414    N4 = (N-st->overlap)>>1;
415    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
416
417    CELT_COPY(in, st->in_mem, C*st->overlap);
418    for (c=0;c<C;c++)
419    {
420       const celt_word16_t * restrict pcmp = pcm+c;
421       celt_sig_t * restrict inp = in+C*st->overlap+c;
422       for (i=0;i<N;i++)
423       {
424          /* Apply pre-emphasis */
425          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
426          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
427          st->preemph_memE[c] = SCALEIN(*pcmp);
428          inp += C;
429          pcmp += C;
430       }
431    }
432    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
433    
434    /* Transient handling */
435    if (st->mode->nbShortMdcts > 1)
436    {
437       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
438       {
439 #ifndef FIXED_POINT
440          float gain_1;
441 #endif
442          ec_enc_bits(&enc, 0, 1); //Pitch off
443          ec_enc_bits(&enc, 1, 1); //Transient on
444          ec_enc_bits(&enc, transient_shift, 2);
445          if (transient_shift)
446             ec_enc_uint(&enc, transient_time, N+st->overlap);
447          /* Apply the inverse shaping window */
448          if (transient_shift)
449          {
450 #ifdef FIXED_POINT
451             for (c=0;c<C;c++)
452                for (i=0;i<16;i++)
453                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
454             for (c=0;c<C;c++)
455                for (i=transient_time;i<N+st->overlap;i++)
456                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
457 #else
458             for (c=0;c<C;c++)
459                for (i=0;i<16;i++)
460                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
461             gain_1 = 1./(1<<transient_shift);
462             for (c=0;c<C;c++)
463                for (i=transient_time;i<N+st->overlap;i++)
464                   in[C*i+c] *= gain_1;
465 #endif
466          }
467          shortBlocks = 1;
468       } else {
469          transient_time = -1;
470          transient_shift = 0;
471          shortBlocks = 0;
472       }
473    } else {
474       transient_time = -1;
475       transient_shift = 0;
476       shortBlocks = 0;
477    }
478
479    /* Pitch analysis: we do it early to save on the peak stack space */
480 #ifdef EXP_PSY
481    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
482    {
483       VARDECL(celt_word16_t, X);
484       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
485       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
486       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
487    }
488 #else
489    if (st->pitch_enabled && !shortBlocks)
490    {
491       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
492    }
493 #endif
494    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
495    
496    /* Compute MDCTs */
497    compute_mdcts(st->mode, shortBlocks, in, freq);
498
499 #ifdef EXP_PSY
500    ALLOC(mask, N, celt_sig_t);
501    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
502    /*for (i=0;i<256;i++)
503       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
504    printf ("\n");*/
505 #endif
506
507    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
508    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
509    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
510    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
511    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
512
513
514    /* Band normalisation */
515    compute_band_energies(st->mode, freq, bandE);
516    normalise_bands(st->mode, freq, X, bandE);
517
518 #ifdef EXP_PSY
519    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
520    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
521    compute_noise_energies(st->mode, freq, tonality, bandN);
522
523    /*for (i=0;i<st->mode->nbEBands;i++)
524       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
525    printf ("\n");*/
526    has_fold = 0;
527    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
528       if (bandN[i] < .4*bandE[i])
529          has_fold++;
530    /*printf ("%d\n", has_fold);*/
531    if (has_fold>=2)
532       has_fold = 0;
533    else
534       has_fold = 1;
535    for (i=0;i<N;i++)
536       mask[i] = sqrt(mask[i]);
537    compute_band_energies(st->mode, mask, bandM);
538    /*for (i=0;i<st->mode->nbEBands;i++)
539       printf ("%f %f ", bandE[i], bandM[i]);
540    printf ("\n");*/
541 #endif
542
543    /* Compute MDCTs of the pitch part */
544    if (st->pitch_enabled && !shortBlocks)
545    {
546       /* Normalise the pitch vector as well (discard the energies) */
547       VARDECL(celt_ener_t, bandEp);
548       
549       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
550       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
551       compute_band_energies(st->mode, freq, bandEp);
552       normalise_bands(st->mode, freq, P, bandEp);
553       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
554    }
555
556    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
557    curr_power = bandE[0]+bandE[1]+bandE[2];
558    has_pitch = 0;
559    if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
560    {
561       /* Pitch prediction */
562       compute_pitch_gain(st->mode, X, P, gains);
563       id = quant_pitch(gains, st->mode->nbPBands);
564       if (id > -1)
565          has_pitch = 1;
566    }
567    
568    if (has_pitch) 
569    {  
570       unquant_pitch(id, gains, st->mode->nbPBands);
571       ec_enc_bits(&enc, has_pitch, 1); /* Pitch flag */
572       ec_enc_bits(&enc, has_fold, 1); /* Folding flag */
573       ec_enc_bits(&enc, id, 7);
574       ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
575       pitch_quant_bands(st->mode, P, gains);
576    } else {
577       if (!shortBlocks)
578       {
579          ec_enc_bits(&enc, 0, 1); /* Pitch off */
580          if (st->mode->nbShortMdcts > 1)
581            ec_enc_bits(&enc, 0, 1); /* Transient off */
582       }
583       has_fold = 1;
584       /* No pitch, so we just pretend we found a gain of zero */
585       for (i=0;i<st->mode->nbPBands;i++)
586          gains[i] = 0;
587       for (i=0;i<C*N;i++)
588          P[i] = 0;
589    }
590
591 #ifdef STDIN_TUNING2
592    static int fine_quant[30];
593    static int pulses[30];
594    static int init=0;
595    if (!init)
596    {
597       for (i=0;i<st->mode->nbEBands;i++)
598          scanf("%d ", &fine_quant[i]);
599       for (i=0;i<st->mode->nbEBands;i++)
600          scanf("%d ", &pulses[i]);
601       init = 1;
602    }
603 #else
604    ALLOC(fine_quant, st->mode->nbEBands, int);
605    ALLOC(pulses, st->mode->nbEBands, int);
606 #endif
607
608    /* Bit allocation */
609    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
610    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &enc);
611    
612    ALLOC(offsets, st->mode->nbEBands, int);
613    ALLOC(stereo_mode, st->mode->nbEBands, int);
614    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
615
616    for (i=0;i<st->mode->nbEBands;i++)
617       offsets[i] = 0;
618    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
619 #ifndef STDIN_TUNING
620    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
621 #endif
622
623    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
624
625    /* Residual quantisation */
626    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
627
628    /* Re-synthesis of the coded audio if required */
629    if (st->pitch_enabled || optional_synthesis!=NULL)
630    {
631       if (C==2)
632          renormalise_bands(st->mode, X);
633       /* Synthesis */
634       denormalise_bands(st->mode, X, freq, bandE);
635       
636       
637       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
638       
639       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
640       /* De-emphasis and put everything back at the right place in the synthesis history */
641       if (optional_synthesis != NULL) {
642          for (c=0;c<C;c++)
643          {
644             int j;
645             for (j=0;j<N;j++)
646             {
647                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
648                                    preemph,st->preemph_memD[c]);
649                st->preemph_memD[c] = tmp;
650                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
651             }
652          }
653       }
654    }
655    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
656    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
657       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
658    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
659    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
660    {
661       int val = 0;
662       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
663       {
664          ec_enc_uint(&enc, val, 2);
665          val = 1-val;
666       }
667    }
668    ec_enc_done(&enc);
669    {
670       /*unsigned char *data;*/
671       int nbBytes = ec_byte_bytes(&buf);
672       if (nbBytes > nbCompressedBytes)
673       {
674          celt_warning_int ("got too many bytes:", nbBytes);
675          RESTORE_STACK;
676          return CELT_INTERNAL_ERROR;
677       }
678    }
679
680    RESTORE_STACK;
681    return nbCompressedBytes;
682 }
683
684 #ifdef FIXED_POINT
685 #ifndef DISABLE_FLOAT_API
686 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
687 {
688    int j, ret;
689    const int C = CHANNELS(st->mode);
690    const int N = st->block_size;
691    VARDECL(celt_int16_t, in);
692    SAVE_STACK;
693    ALLOC(in, C*N, celt_int16_t);
694
695    for (j=0;j<C*N;j++)
696      in[j] = FLOAT2INT16(pcm[j]);
697
698    if (optional_synthesis != NULL) {
699      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
700       for (j=0;j<C*N;j++)
701          optional_synthesis[j]=in[j]*(1/32768.);
702    } else {
703      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
704    }
705    RESTORE_STACK;
706    return ret;
707
708 }
709 #endif /*DISABLE_FLOAT_API*/
710 #else
711 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
712 {
713    int j, ret;
714    VARDECL(celt_sig_t, in);
715    const int C = CHANNELS(st->mode);
716    const int N = st->block_size;
717    SAVE_STACK;
718    ALLOC(in, C*N, celt_sig_t);
719    for (j=0;j<C*N;j++) {
720      in[j] = SCALEOUT(pcm[j]);
721    }
722
723    if (optional_synthesis != NULL) {
724       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
725       for (j=0;j<C*N;j++)
726          optional_synthesis[j] = FLOAT2INT16(in[j]);
727    } else {
728       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
729    }
730    RESTORE_STACK;
731    return ret;
732 }
733 #endif
734
735 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
736 {
737    va_list ap;
738    va_start(ap, request);
739    switch (request)
740    {
741       case CELT_SET_COMPLEXITY_REQUEST:
742       {
743          int value = va_arg(ap, int);
744          if (value<0 || value>10)
745             goto bad_arg;
746          if (value<=2)
747             st->pitch_enabled = 0;
748          else
749             st->pitch_enabled = 1;
750       }
751       break;
752       default:
753          goto bad_request;
754    }
755    va_end(ap);
756    return CELT_OK;
757 bad_arg:
758    va_end(ap);
759    return CELT_BAD_ARG;
760 bad_request:
761    va_end(ap);
762    return CELT_UNIMPLEMENTED;
763 }
764
765 /****************************************************************************/
766 /*                                                                          */
767 /*                                DECODER                                   */
768 /*                                                                          */
769 /****************************************************************************/
770
771
772 /** Decoder state 
773  @brief Decoder state
774  */
775 struct CELTDecoder {
776    const CELTMode *mode;
777    int frame_size;
778    int block_size;
779    int overlap;
780
781    ec_byte_buffer buf;
782    ec_enc         enc;
783
784    celt_sig_t * restrict preemph_memD;
785
786    celt_sig_t *out_mem;
787
788    celt_word16_t *oldBandE;
789    
790    int last_pitch_index;
791 };
792
793 CELTDecoder *celt_decoder_create(const CELTMode *mode)
794 {
795    int N, C;
796    CELTDecoder *st;
797
798    if (check_mode(mode) != CELT_OK)
799       return NULL;
800
801    N = mode->mdctSize;
802    C = CHANNELS(mode);
803    st = celt_alloc(sizeof(CELTDecoder));
804    
805    st->mode = mode;
806    st->frame_size = N;
807    st->block_size = N;
808    st->overlap = mode->overlap;
809
810    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
811    
812    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
813
814    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
815
816    st->last_pitch_index = 0;
817    return st;
818 }
819
820 void celt_decoder_destroy(CELTDecoder *st)
821 {
822    if (st == NULL)
823    {
824       celt_warning("NULL passed to celt_encoder_destroy");
825       return;
826    }
827    if (check_mode(st->mode) != CELT_OK)
828       return;
829
830
831    celt_free(st->out_mem);
832    
833    celt_free(st->oldBandE);
834    
835    celt_free(st->preemph_memD);
836
837    celt_free(st);
838 }
839
840 /** Handles lost packets by just copying past data with the same offset as the last
841     pitch period */
842 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
843 {
844    int c, N;
845    int pitch_index;
846    int i, len;
847    VARDECL(celt_sig_t, freq);
848    const int C = CHANNELS(st->mode);
849    int offset;
850    SAVE_STACK;
851    N = st->block_size;
852    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
853    
854    len = N+st->mode->overlap;
855 #if 0
856    pitch_index = st->last_pitch_index;
857    
858    /* Use the pitch MDCT as the "guessed" signal */
859    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
860
861 #else
862    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
863    pitch_index = MAX_PERIOD-len-pitch_index;
864    offset = MAX_PERIOD-pitch_index;
865    while (offset+len >= MAX_PERIOD)
866       offset -= pitch_index;
867    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
868    for (i=0;i<N;i++)
869       freq[i] = ADD32(EPSILON, MULT16_32_Q15(QCONST16(.9f,15),freq[i]));
870 #endif
871    
872    
873    
874    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
875    /* Compute inverse MDCTs */
876    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
877
878    for (c=0;c<C;c++)
879    {
880       int j;
881       for (j=0;j<N;j++)
882       {
883          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
884                                 preemph,st->preemph_memD[c]);
885          st->preemph_memD[c] = tmp;
886          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
887       }
888    }
889    RESTORE_STACK;
890 }
891
892 #ifdef FIXED_POINT
893 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
894 {
895 #else
896 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
897 {
898 #endif
899    int i, c, N, N4;
900    int has_pitch, has_fold;
901    int pitch_index;
902    int bits;
903    ec_dec dec;
904    ec_byte_buffer buf;
905    VARDECL(celt_sig_t, freq);
906    VARDECL(celt_norm_t, X);
907    VARDECL(celt_norm_t, P);
908    VARDECL(celt_ener_t, bandE);
909    VARDECL(celt_pgain_t, gains);
910    VARDECL(int, stereo_mode);
911    VARDECL(int, fine_quant);
912    VARDECL(int, pulses);
913    VARDECL(int, offsets);
914
915    int shortBlocks;
916    int transient_time;
917    int transient_shift;
918    const int C = CHANNELS(st->mode);
919    SAVE_STACK;
920
921    if (check_mode(st->mode) != CELT_OK)
922       return CELT_INVALID_MODE;
923
924    N = st->block_size;
925    N4 = (N-st->overlap)>>1;
926
927    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
928    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
929    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
930    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
931    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
932    
933    if (check_mode(st->mode) != CELT_OK)
934    {
935       RESTORE_STACK;
936       return CELT_INVALID_MODE;
937    }
938    if (data == NULL)
939    {
940       celt_decode_lost(st, pcm);
941       RESTORE_STACK;
942       return 0;
943    }
944    
945    ec_byte_readinit(&buf,data,len);
946    ec_dec_init(&dec,&buf);
947    
948    has_pitch = ec_dec_bits(&dec, 1);
949    if (has_pitch)
950    {
951       has_fold = ec_dec_bits(&dec, 1);
952       shortBlocks = 0;
953    } else if (st->mode->nbShortMdcts > 1){
954       shortBlocks = ec_dec_bits(&dec, 1);
955       has_fold = 1;
956    } else {
957       shortBlocks = 0;
958       has_fold = 1;
959    }
960    if (shortBlocks)
961    {
962       transient_shift = ec_dec_bits(&dec, 2);
963       if (transient_shift)
964          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
965       else
966          transient_time = 0;
967    } else {
968       transient_time = -1;
969       transient_shift = 0;
970    }
971    
972    if (has_pitch)
973    {
974       int id;
975       /* Get the pitch gains and index */
976       id = ec_dec_bits(&dec, 7);
977       unquant_pitch(id, gains, st->mode->nbPBands);
978       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
979       st->last_pitch_index = pitch_index;
980    } else {
981       pitch_index = 0;
982       for (i=0;i<st->mode->nbPBands;i++)
983          gains[i] = 0;
984    }
985
986    ALLOC(fine_quant, st->mode->nbEBands, int);
987    /* Get band energies */
988    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
989    
990    ALLOC(pulses, st->mode->nbEBands, int);
991    ALLOC(offsets, st->mode->nbEBands, int);
992    ALLOC(stereo_mode, st->mode->nbEBands, int);
993    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
994
995    for (i=0;i<st->mode->nbEBands;i++)
996       offsets[i] = 0;
997
998    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
999    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1000    /*bits = ec_dec_tell(&dec, 0);
1001    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1002    
1003    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1004
1005
1006    if (has_pitch) 
1007    {
1008       VARDECL(celt_ener_t, bandEp);
1009       
1010       /* Pitch MDCT */
1011       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1012       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1013       compute_band_energies(st->mode, freq, bandEp);
1014       normalise_bands(st->mode, freq, P, bandEp);
1015       /* Apply pitch gains */
1016       pitch_quant_bands(st->mode, P, gains);
1017    } else {
1018       for (i=0;i<C*N;i++)
1019          P[i] = 0;
1020    }
1021
1022    /* Decode fixed codebook and merge with pitch */
1023    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, has_fold, len*8, &dec);
1024
1025    if (C==2)
1026    {
1027       renormalise_bands(st->mode, X);
1028    }
1029    /* Synthesis */
1030    denormalise_bands(st->mode, X, freq, bandE);
1031
1032
1033    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
1034    /* Compute inverse MDCTs */
1035    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1036
1037    for (c=0;c<C;c++)
1038    {
1039       int j;
1040       for (j=0;j<N;j++)
1041       {
1042          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1043                                 preemph,st->preemph_memD[c]);
1044          st->preemph_memD[c] = tmp;
1045          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1046       }
1047    }
1048
1049    {
1050       unsigned int val = 0;
1051       while (ec_dec_tell(&dec, 0) < len*8)
1052       {
1053          if (ec_dec_uint(&dec, 2) != val)
1054          {
1055             celt_warning("decode error");
1056             RESTORE_STACK;
1057             return CELT_CORRUPTED_DATA;
1058          }
1059          val = 1-val;
1060       }
1061    }
1062
1063    RESTORE_STACK;
1064    return 0;
1065    /*printf ("\n");*/
1066 }
1067
1068 #ifdef FIXED_POINT
1069 #ifndef DISABLE_FLOAT_API
1070 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1071 {
1072    int j, ret;
1073    const int C = CHANNELS(st->mode);
1074    const int N = st->block_size;
1075    VARDECL(celt_int16_t, out);
1076    SAVE_STACK;
1077    ALLOC(out, C*N, celt_int16_t);
1078
1079    ret=celt_decode(st, data, len, out);
1080
1081    for (j=0;j<C*N;j++)
1082      pcm[j]=out[j]*(1/32768.);
1083    RESTORE_STACK;
1084    return ret;
1085 }
1086 #endif /*DISABLE_FLOAT_API*/
1087 #else
1088 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1089 {
1090    int j, ret;
1091    VARDECL(celt_sig_t, out);
1092    const int C = CHANNELS(st->mode);
1093    const int N = st->block_size;
1094    SAVE_STACK;
1095    ALLOC(out, C*N, celt_sig_t);
1096
1097    ret=celt_decode_float(st, data, len, out);
1098
1099    for (j=0;j<C*N;j++)
1100      pcm[j] = FLOAT2INT16 (out[j]);
1101
1102    RESTORE_STACK;
1103    return ret;
1104 }
1105 #endif