Moved the application of the pitch gain to (un)quant_bands(). This doesn't
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53 #include "float_cast.h"
54 #include <stdarg.h>
55
56 static const celt_word16_t preemph = QCONST16(0.8f,15);
57
58 #ifdef FIXED_POINT
59 static const celt_word16_t transientWindow[16] = {
60      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
61    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
62 #else
63 static const float transientWindow[16] = {
64    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
65    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
66 #endif
67
68    
69 /** Encoder state 
70  @brief Encoder state
71  */
72 struct CELTEncoder {
73    const CELTMode *mode;     /**< Mode used by the encoder */
74    int frame_size;
75    int block_size;
76    int overlap;
77    int channels;
78    
79    int pitch_enabled;
80    int pitch_available;
81
82    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
83    celt_sig_t    * restrict preemph_memD;
84
85    celt_sig_t *in_mem;
86    celt_sig_t *out_mem;
87
88    celt_word16_t *oldBandE;
89 #ifdef EXP_PSY
90    celt_word16_t *psy_mem;
91    struct PsyDecay psy;
92 #endif
93 };
94
95 CELTEncoder *celt_encoder_create(const CELTMode *mode)
96 {
97    int N, C;
98    CELTEncoder *st;
99
100    if (check_mode(mode) != CELT_OK)
101       return NULL;
102
103    N = mode->mdctSize;
104    C = mode->nbChannels;
105    st = celt_alloc(sizeof(CELTEncoder));
106    
107    st->mode = mode;
108    st->frame_size = N;
109    st->block_size = N;
110    st->overlap = mode->overlap;
111
112    st->pitch_enabled = 1;
113    st->pitch_available = 1;
114
115    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
116    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
117
118    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
119
120    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
121    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
122
123 #ifdef EXP_PSY
124    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
125    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
126 #endif
127
128    return st;
129 }
130
131 void celt_encoder_destroy(CELTEncoder *st)
132 {
133    if (st == NULL)
134    {
135       celt_warning("NULL passed to celt_encoder_destroy");
136       return;
137    }
138    if (check_mode(st->mode) != CELT_OK)
139       return;
140
141    celt_free(st->in_mem);
142    celt_free(st->out_mem);
143    
144    celt_free(st->oldBandE);
145    
146    celt_free(st->preemph_memE);
147    celt_free(st->preemph_memD);
148    
149 #ifdef EXP_PSY
150    celt_free (st->psy_mem);
151    psydecay_clear(&st->psy);
152 #endif
153    
154    celt_free(st);
155 }
156
157 static inline celt_int16_t FLOAT2INT16(float x)
158 {
159    x = x*32768.;
160    x = MAX32(x, -32768);
161    x = MIN32(x, 32767);
162    return (celt_int16_t)float2int(x);
163 }
164
165 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
166 {
167 #ifdef FIXED_POINT
168    x = PSHR32(x, SIG_SHIFT);
169    x = MAX32(x, -32768);
170    x = MIN32(x, 32767);
171    return EXTRACT16(x);
172 #else
173    return (celt_word16_t)x;
174 #endif
175 }
176
177 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
178 {
179    int c, i, n;
180    celt_word32_t ratio;
181    /* FIXME: Remove the floats here */
182    VARDECL(celt_word32_t, begin);
183    SAVE_STACK;
184    ALLOC(begin, len, celt_word32_t);
185    for (i=0;i<len;i++)
186       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
187    for (c=1;c<C;c++)
188    {
189       for (i=0;i<len;i++)
190          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
191    }
192    for (i=1;i<len;i++)
193       begin[i] = MAX32(begin[i-1],begin[i]);
194    n = -1;
195    for (i=8;i<len-8;i++)
196    {
197       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
198          n=i;
199    }
200    if (n<32)
201    {
202       n = -1;
203       ratio = 0;
204    } else {
205       ratio = DIV32(begin[len-1],1+begin[n-16]);
206    }
207    /*printf ("%d %f\n", n, ratio*ratio);*/
208    if (ratio < 0)
209       ratio = 0;
210    if (ratio > 1000)
211       ratio = 1000;
212    ratio *= ratio;
213    if (ratio < 50)
214       *transient_shift = 0;
215    else if (ratio < 256)
216       *transient_shift = 1;
217    else if (ratio < 4096)
218       *transient_shift = 2;
219    else
220       *transient_shift = 3;
221    *transient_time = n;
222    
223    RESTORE_STACK;
224    return ratio > 20;
225 }
226
227 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
228 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
229 {
230    const int C = CHANNELS(mode);
231    if (C==1 && !shortBlocks)
232    {
233       const mdct_lookup *lookup = MDCT(mode);
234       const int overlap = OVERLAP(mode);
235       mdct_forward(lookup, in, out, mode->window, overlap);
236    } else if (!shortBlocks) {
237       const mdct_lookup *lookup = MDCT(mode);
238       const int overlap = OVERLAP(mode);
239       const int N = FRAMESIZE(mode);
240       int c;
241       VARDECL(celt_word32_t, x);
242       VARDECL(celt_word32_t, tmp);
243       SAVE_STACK;
244       ALLOC(x, N+overlap, celt_word32_t);
245       ALLOC(tmp, N, celt_word32_t);
246       for (c=0;c<C;c++)
247       {
248          int j;
249          for (j=0;j<N+overlap;j++)
250             x[j] = in[C*j+c];
251          mdct_forward(lookup, x, tmp, mode->window, overlap);
252          /* Interleaving the sub-frames */
253          for (j=0;j<N;j++)
254             out[C*j+c] = tmp[j];
255       }
256       RESTORE_STACK;
257    } else {
258       const mdct_lookup *lookup = &mode->shortMdct;
259       const int overlap = mode->overlap;
260       const int N = mode->shortMdctSize;
261       int b, c;
262       VARDECL(celt_word32_t, x);
263       VARDECL(celt_word32_t, tmp);
264       SAVE_STACK;
265       ALLOC(x, N+overlap, celt_word32_t);
266       ALLOC(tmp, N, celt_word32_t);
267       for (c=0;c<C;c++)
268       {
269          int B = mode->nbShortMdcts;
270          for (b=0;b<B;b++)
271          {
272             int j;
273             for (j=0;j<N+overlap;j++)
274                x[j] = in[C*(b*N+j)+c];
275             mdct_forward(lookup, x, tmp, mode->window, overlap);
276             /* Interleaving the sub-frames */
277             for (j=0;j<N;j++)
278                out[C*(j*B+b)+c] = tmp[j];
279          }
280       }
281       RESTORE_STACK;
282    }
283 }
284
285 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
286 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
287 {
288    int c, N4;
289    const int C = CHANNELS(mode);
290    const int N = FRAMESIZE(mode);
291    const int overlap = OVERLAP(mode);
292    N4 = (N-overlap)>>1;
293    for (c=0;c<C;c++)
294    {
295       int j;
296       if (transient_shift==0 && C==1 && !shortBlocks) {
297          const mdct_lookup *lookup = MDCT(mode);
298          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
299       } else if (!shortBlocks) {
300          const mdct_lookup *lookup = MDCT(mode);
301          VARDECL(celt_word32_t, x);
302          VARDECL(celt_word32_t, tmp);
303          SAVE_STACK;
304          ALLOC(x, 2*N, celt_word32_t);
305          ALLOC(tmp, N, celt_word32_t);
306          /* De-interleaving the sub-frames */
307          for (j=0;j<N;j++)
308             tmp[j] = X[C*j+c];
309          /* Prevents problems from the imdct doing the overlap-add */
310          CELT_MEMSET(x+N4, 0, N);
311          mdct_backward(lookup, tmp, x, mode->window, overlap);
312          celt_assert(transient_shift == 0);
313          /* The first and last part would need to be set to zero if we actually
314             wanted to use them. */
315          for (j=0;j<overlap;j++)
316             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
317          for (j=0;j<overlap;j++)
318             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
319          for (j=0;j<2*N4;j++)
320             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
321          RESTORE_STACK;
322       } else {
323          int b;
324          const int N2 = mode->shortMdctSize;
325          const int B = mode->nbShortMdcts;
326          const mdct_lookup *lookup = &mode->shortMdct;
327          VARDECL(celt_word32_t, x);
328          VARDECL(celt_word32_t, tmp);
329          SAVE_STACK;
330          ALLOC(x, 2*N, celt_word32_t);
331          ALLOC(tmp, N, celt_word32_t);
332          /* Prevents problems from the imdct doing the overlap-add */
333          CELT_MEMSET(x+N4, 0, N2);
334          for (b=0;b<B;b++)
335          {
336             /* De-interleaving the sub-frames */
337             for (j=0;j<N2;j++)
338                tmp[j] = X[C*(j*B+b)+c];
339             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
340          }
341          if (transient_shift > 0)
342          {
343 #ifdef FIXED_POINT
344             for (j=0;j<16;j++)
345                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
346             for (j=transient_time;j<N+overlap;j++)
347                x[N4+j] = SHL32(x[N4+j], transient_shift);
348 #else
349             for (j=0;j<16;j++)
350                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
351             for (j=transient_time;j<N+overlap;j++)
352                x[N4+j] *= 1<<transient_shift;
353 #endif
354          }
355          /* The first and last part would need to be set to zero if we actually
356          wanted to use them. */
357          for (j=0;j<overlap;j++)
358             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
359          for (j=0;j<overlap;j++)
360             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
361          for (j=0;j<2*N4;j++)
362             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
363          RESTORE_STACK;
364       }
365    }
366 }
367
368 #ifdef FIXED_POINT
369 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
370 {
371 #else
372 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
373 {
374 #endif
375    int i, c, N, N4;
376    int has_pitch;
377    int id;
378    int pitch_index;
379    int bits;
380    int has_fold=1;
381    ec_byte_buffer buf;
382    ec_enc         enc;
383    VARDECL(celt_sig_t, in);
384    VARDECL(celt_sig_t, freq);
385    VARDECL(celt_norm_t, X);
386    VARDECL(celt_norm_t, P);
387    VARDECL(celt_ener_t, bandE);
388    VARDECL(celt_pgain_t, gains);
389    VARDECL(int, stereo_mode);
390    VARDECL(int, fine_quant);
391    VARDECL(celt_word16_t, error);
392    VARDECL(int, pulses);
393    VARDECL(int, offsets);
394 #ifdef EXP_PSY
395    VARDECL(celt_word32_t, mask);
396    VARDECL(celt_word32_t, tonality);
397    VARDECL(celt_word32_t, bandM);
398    VARDECL(celt_ener_t, bandN);
399 #endif
400    int shortBlocks=0;
401    int transient_time;
402    int transient_shift;
403    const int C = CHANNELS(st->mode);
404    SAVE_STACK;
405
406    if (check_mode(st->mode) != CELT_OK)
407       return CELT_INVALID_MODE;
408
409    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
410    CELT_MEMSET(compressed, 0, nbCompressedBytes);
411    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
412    ec_enc_init(&enc,&buf);
413
414    N = st->block_size;
415    N4 = (N-st->overlap)>>1;
416    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
417
418    CELT_COPY(in, st->in_mem, C*st->overlap);
419    for (c=0;c<C;c++)
420    {
421       const celt_word16_t * restrict pcmp = pcm+c;
422       celt_sig_t * restrict inp = in+C*st->overlap+c;
423       for (i=0;i<N;i++)
424       {
425          /* Apply pre-emphasis */
426          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
427          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
428          st->preemph_memE[c] = SCALEIN(*pcmp);
429          inp += C;
430          pcmp += C;
431       }
432    }
433    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
434    
435    /* Transient handling */
436    if (st->mode->nbShortMdcts > 1)
437    {
438       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
439       {
440 #ifndef FIXED_POINT
441          float gain_1;
442 #endif
443          ec_enc_bits(&enc, 0, 1); //Pitch off
444          ec_enc_bits(&enc, 1, 1); //Transient on
445          ec_enc_bits(&enc, transient_shift, 2);
446          if (transient_shift)
447             ec_enc_uint(&enc, transient_time, N+st->overlap);
448          /* Apply the inverse shaping window */
449          if (transient_shift)
450          {
451 #ifdef FIXED_POINT
452             for (c=0;c<C;c++)
453                for (i=0;i<16;i++)
454                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
455             for (c=0;c<C;c++)
456                for (i=transient_time;i<N+st->overlap;i++)
457                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
458 #else
459             for (c=0;c<C;c++)
460                for (i=0;i<16;i++)
461                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
462             gain_1 = 1./(1<<transient_shift);
463             for (c=0;c<C;c++)
464                for (i=transient_time;i<N+st->overlap;i++)
465                   in[C*i+c] *= gain_1;
466 #endif
467          }
468          shortBlocks = 1;
469       } else {
470          transient_time = -1;
471          transient_shift = 0;
472          shortBlocks = 0;
473       }
474    } else {
475       transient_time = -1;
476       transient_shift = 0;
477       shortBlocks = 0;
478    }
479
480    /* Pitch analysis: we do it early to save on the peak stack space */
481    /* Don't use pitch if there isn't enough data available yet, or if we're using shortBlocks */
482    has_pitch = st->pitch_enabled && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks);
483 #ifdef EXP_PSY
484    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
485    {
486       VARDECL(celt_word16_t, X);
487       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
488       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
489       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
490    }
491 #else
492    if (has_pitch)
493    {
494       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
495    }
496 #endif
497    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
498    
499    /* Compute MDCTs */
500    compute_mdcts(st->mode, shortBlocks, in, freq);
501
502 #ifdef EXP_PSY
503    ALLOC(mask, N, celt_sig_t);
504    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
505    /*for (i=0;i<256;i++)
506       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
507    printf ("\n");*/
508 #endif
509
510    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
511    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
512    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
513    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
514    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
515
516
517    /* Band normalisation */
518    compute_band_energies(st->mode, freq, bandE);
519    normalise_bands(st->mode, freq, X, bandE);
520
521 #ifdef EXP_PSY
522    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
523    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
524    compute_noise_energies(st->mode, freq, tonality, bandN);
525
526    /*for (i=0;i<st->mode->nbEBands;i++)
527       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
528    printf ("\n");*/
529    has_fold = 0;
530    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
531       if (bandN[i] < .4*bandE[i])
532          has_fold++;
533    /*printf ("%d\n", has_fold);*/
534    if (has_fold>=2)
535       has_fold = 0;
536    else
537       has_fold = 1;
538    for (i=0;i<N;i++)
539       mask[i] = sqrt(mask[i]);
540    compute_band_energies(st->mode, mask, bandM);
541    /*for (i=0;i<st->mode->nbEBands;i++)
542       printf ("%f %f ", bandE[i], bandM[i]);
543    printf ("\n");*/
544 #endif
545
546    /* Compute MDCTs of the pitch part */
547    if (has_pitch)
548    {
549       celt_word32_t curr_power, pitch_power=0;
550       /* Normalise the pitch vector as well (discard the energies) */
551       VARDECL(celt_ener_t, bandEp);
552       
553       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
554       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
555       compute_band_energies(st->mode, freq, bandEp);
556       normalise_bands(st->mode, freq, P, bandEp);
557       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
558       /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
559       curr_power = bandE[0]+bandE[1]+bandE[2];
560       id=-1;
561       if ((MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
562       {
563          /* Pitch prediction */
564          compute_pitch_gain(st->mode, X, P, gains);
565          id = quant_pitch(gains, st->mode->nbPBands);
566       } 
567       if (id == -1)
568          has_pitch = 0;
569    }
570    
571    if (has_pitch) 
572    {  
573       unquant_pitch(id, gains, st->mode->nbPBands);
574       ec_enc_bits(&enc, has_pitch, 1); /* Pitch flag */
575       ec_enc_bits(&enc, has_fold, 1); /* Folding flag */
576       ec_enc_bits(&enc, id, 7);
577       ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
578    } else {
579       if (!shortBlocks)
580       {
581          ec_enc_bits(&enc, 0, 1); /* Pitch off */
582          if (st->mode->nbShortMdcts > 1)
583            ec_enc_bits(&enc, 0, 1); /* Transient off */
584       }
585       has_fold = 1;
586       /* No pitch, so we just pretend we found a gain of zero */
587       for (i=0;i<st->mode->nbPBands;i++)
588          gains[i] = 0;
589       for (i=0;i<C*N;i++)
590          P[i] = 0;
591    }
592
593 #ifdef STDIN_TUNING2
594    static int fine_quant[30];
595    static int pulses[30];
596    static int init=0;
597    if (!init)
598    {
599       for (i=0;i<st->mode->nbEBands;i++)
600          scanf("%d ", &fine_quant[i]);
601       for (i=0;i<st->mode->nbEBands;i++)
602          scanf("%d ", &pulses[i]);
603       init = 1;
604    }
605 #else
606    ALLOC(fine_quant, st->mode->nbEBands, int);
607    ALLOC(pulses, st->mode->nbEBands, int);
608 #endif
609
610    /* Bit allocation */
611    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
612    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &enc);
613    
614    ALLOC(offsets, st->mode->nbEBands, int);
615    ALLOC(stereo_mode, st->mode->nbEBands, int);
616    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
617
618    for (i=0;i<st->mode->nbEBands;i++)
619       offsets[i] = 0;
620    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
621 #ifndef STDIN_TUNING
622    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
623 #endif
624
625    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
626
627    /* Residual quantisation */
628    quant_bands(st->mode, X, P, NULL, has_pitch, gains, bandE, stereo_mode, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
629
630    /* Re-synthesis of the coded audio if required */
631    if (st->pitch_available>0 || optional_synthesis!=NULL)
632    {
633       if (st->pitch_available>0 && st->pitch_available<MAX_PERIOD)
634         st->pitch_available+=st->frame_size;
635
636       if (C==2)
637          renormalise_bands(st->mode, X);
638       /* Synthesis */
639       denormalise_bands(st->mode, X, freq, bandE);
640       
641       
642       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
643       
644       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
645       /* De-emphasis and put everything back at the right place in the synthesis history */
646       if (optional_synthesis != NULL) {
647          for (c=0;c<C;c++)
648          {
649             int j;
650             for (j=0;j<N;j++)
651             {
652                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
653                                    preemph,st->preemph_memD[c]);
654                st->preemph_memD[c] = tmp;
655                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
656             }
657          }
658       }
659    }
660    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
661    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
662       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
663    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
664    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
665    {
666       int val = 0;
667       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
668       {
669          ec_enc_uint(&enc, val, 2);
670          val = 1-val;
671       }
672    }
673    ec_enc_done(&enc);
674    {
675       /*unsigned char *data;*/
676       int nbBytes = ec_byte_bytes(&buf);
677       if (nbBytes > nbCompressedBytes)
678       {
679          celt_warning_int ("got too many bytes:", nbBytes);
680          RESTORE_STACK;
681          return CELT_INTERNAL_ERROR;
682       }
683    }
684
685    RESTORE_STACK;
686    return nbCompressedBytes;
687 }
688
689 #ifdef FIXED_POINT
690 #ifndef DISABLE_FLOAT_API
691 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
692 {
693    int j, ret;
694    const int C = CHANNELS(st->mode);
695    const int N = st->block_size;
696    VARDECL(celt_int16_t, in);
697    SAVE_STACK;
698    ALLOC(in, C*N, celt_int16_t);
699
700    for (j=0;j<C*N;j++)
701      in[j] = FLOAT2INT16(pcm[j]);
702
703    if (optional_synthesis != NULL) {
704      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
705       for (j=0;j<C*N;j++)
706          optional_synthesis[j]=in[j]*(1/32768.);
707    } else {
708      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
709    }
710    RESTORE_STACK;
711    return ret;
712
713 }
714 #endif /*DISABLE_FLOAT_API*/
715 #else
716 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
717 {
718    int j, ret;
719    VARDECL(celt_sig_t, in);
720    const int C = CHANNELS(st->mode);
721    const int N = st->block_size;
722    SAVE_STACK;
723    ALLOC(in, C*N, celt_sig_t);
724    for (j=0;j<C*N;j++) {
725      in[j] = SCALEOUT(pcm[j]);
726    }
727
728    if (optional_synthesis != NULL) {
729       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
730       for (j=0;j<C*N;j++)
731          optional_synthesis[j] = FLOAT2INT16(in[j]);
732    } else {
733       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
734    }
735    RESTORE_STACK;
736    return ret;
737 }
738 #endif
739
740 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
741 {
742    va_list ap;
743    va_start(ap, request);
744    switch (request)
745    {
746       case CELT_SET_COMPLEXITY_REQUEST:
747       {
748          int value = va_arg(ap, int);
749          if (value<0 || value>10)
750             goto bad_arg;
751          if (value<=2) {
752             st->pitch_enabled = 0; 
753             st->pitch_available = 0;
754          } else {
755               st->pitch_enabled = 1;
756               if (st->pitch_available<1)
757                 st->pitch_available = 1;
758          }   
759       }
760       break;
761       case CELT_SET_LTP_REQUEST:
762       {
763          int value = va_arg(ap, int);
764          if (value<0 || value>1 || (value==1 && st->pitch_available==0))
765             goto bad_arg;
766          if (value==0)
767             st->pitch_enabled = 0;
768          else
769             st->pitch_enabled = 1;
770       }
771       break;
772       default:
773          goto bad_request;
774    }
775    va_end(ap);
776    return CELT_OK;
777 bad_arg:
778    va_end(ap);
779    return CELT_BAD_ARG;
780 bad_request:
781    va_end(ap);
782    return CELT_UNIMPLEMENTED;
783 }
784
785 /****************************************************************************/
786 /*                                                                          */
787 /*                                DECODER                                   */
788 /*                                                                          */
789 /****************************************************************************/
790
791
792 /** Decoder state 
793  @brief Decoder state
794  */
795 struct CELTDecoder {
796    const CELTMode *mode;
797    int frame_size;
798    int block_size;
799    int overlap;
800
801    ec_byte_buffer buf;
802    ec_enc         enc;
803
804    celt_sig_t * restrict preemph_memD;
805
806    celt_sig_t *out_mem;
807
808    celt_word16_t *oldBandE;
809    
810    int last_pitch_index;
811 };
812
813 CELTDecoder *celt_decoder_create(const CELTMode *mode)
814 {
815    int N, C;
816    CELTDecoder *st;
817
818    if (check_mode(mode) != CELT_OK)
819       return NULL;
820
821    N = mode->mdctSize;
822    C = CHANNELS(mode);
823    st = celt_alloc(sizeof(CELTDecoder));
824    
825    st->mode = mode;
826    st->frame_size = N;
827    st->block_size = N;
828    st->overlap = mode->overlap;
829
830    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
831    
832    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
833
834    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
835
836    st->last_pitch_index = 0;
837    return st;
838 }
839
840 void celt_decoder_destroy(CELTDecoder *st)
841 {
842    if (st == NULL)
843    {
844       celt_warning("NULL passed to celt_encoder_destroy");
845       return;
846    }
847    if (check_mode(st->mode) != CELT_OK)
848       return;
849
850
851    celt_free(st->out_mem);
852    
853    celt_free(st->oldBandE);
854    
855    celt_free(st->preemph_memD);
856
857    celt_free(st);
858 }
859
860 /** Handles lost packets by just copying past data with the same offset as the last
861     pitch period */
862 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
863 {
864    int c, N;
865    int pitch_index;
866    int i, len;
867    VARDECL(celt_sig_t, freq);
868    const int C = CHANNELS(st->mode);
869    int offset;
870    SAVE_STACK;
871    N = st->block_size;
872    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
873    
874    len = N+st->mode->overlap;
875 #if 0
876    pitch_index = st->last_pitch_index;
877    
878    /* Use the pitch MDCT as the "guessed" signal */
879    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
880
881 #else
882    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
883    pitch_index = MAX_PERIOD-len-pitch_index;
884    offset = MAX_PERIOD-pitch_index;
885    while (offset+len >= MAX_PERIOD)
886       offset -= pitch_index;
887    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
888    for (i=0;i<N;i++)
889       freq[i] = ADD32(EPSILON, MULT16_32_Q15(QCONST16(.9f,15),freq[i]));
890 #endif
891    
892    
893    
894    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
895    /* Compute inverse MDCTs */
896    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
897
898    for (c=0;c<C;c++)
899    {
900       int j;
901       for (j=0;j<N;j++)
902       {
903          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
904                                 preemph,st->preemph_memD[c]);
905          st->preemph_memD[c] = tmp;
906          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
907       }
908    }
909    RESTORE_STACK;
910 }
911
912 #ifdef FIXED_POINT
913 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
914 {
915 #else
916 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
917 {
918 #endif
919    int i, c, N, N4;
920    int has_pitch, has_fold;
921    int pitch_index;
922    int bits;
923    ec_dec dec;
924    ec_byte_buffer buf;
925    VARDECL(celt_sig_t, freq);
926    VARDECL(celt_norm_t, X);
927    VARDECL(celt_norm_t, P);
928    VARDECL(celt_ener_t, bandE);
929    VARDECL(celt_pgain_t, gains);
930    VARDECL(int, stereo_mode);
931    VARDECL(int, fine_quant);
932    VARDECL(int, pulses);
933    VARDECL(int, offsets);
934
935    int shortBlocks;
936    int transient_time;
937    int transient_shift;
938    const int C = CHANNELS(st->mode);
939    SAVE_STACK;
940
941    if (check_mode(st->mode) != CELT_OK)
942       return CELT_INVALID_MODE;
943
944    N = st->block_size;
945    N4 = (N-st->overlap)>>1;
946
947    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
948    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
949    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
950    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
951    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
952    
953    if (check_mode(st->mode) != CELT_OK)
954    {
955       RESTORE_STACK;
956       return CELT_INVALID_MODE;
957    }
958    if (data == NULL)
959    {
960       celt_decode_lost(st, pcm);
961       RESTORE_STACK;
962       return 0;
963    }
964    
965    ec_byte_readinit(&buf,data,len);
966    ec_dec_init(&dec,&buf);
967    
968    has_pitch = ec_dec_bits(&dec, 1);
969    if (has_pitch)
970    {
971       has_fold = ec_dec_bits(&dec, 1);
972       shortBlocks = 0;
973    } else if (st->mode->nbShortMdcts > 1){
974       shortBlocks = ec_dec_bits(&dec, 1);
975       has_fold = 1;
976    } else {
977       shortBlocks = 0;
978       has_fold = 1;
979    }
980    if (shortBlocks)
981    {
982       transient_shift = ec_dec_bits(&dec, 2);
983       if (transient_shift)
984          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
985       else
986          transient_time = 0;
987    } else {
988       transient_time = -1;
989       transient_shift = 0;
990    }
991    
992    if (has_pitch)
993    {
994       int id;
995       /* Get the pitch gains and index */
996       id = ec_dec_bits(&dec, 7);
997       unquant_pitch(id, gains, st->mode->nbPBands);
998       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
999       st->last_pitch_index = pitch_index;
1000    } else {
1001       pitch_index = 0;
1002       for (i=0;i<st->mode->nbPBands;i++)
1003          gains[i] = 0;
1004    }
1005
1006    ALLOC(fine_quant, st->mode->nbEBands, int);
1007    /* Get band energies */
1008    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
1009    
1010    ALLOC(pulses, st->mode->nbEBands, int);
1011    ALLOC(offsets, st->mode->nbEBands, int);
1012    ALLOC(stereo_mode, st->mode->nbEBands, int);
1013    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
1014
1015    for (i=0;i<st->mode->nbEBands;i++)
1016       offsets[i] = 0;
1017
1018    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1019    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1020    /*bits = ec_dec_tell(&dec, 0);
1021    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1022    
1023    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1024
1025
1026    if (has_pitch) 
1027    {
1028       VARDECL(celt_ener_t, bandEp);
1029       
1030       /* Pitch MDCT */
1031       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1032       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1033       compute_band_energies(st->mode, freq, bandEp);
1034       normalise_bands(st->mode, freq, P, bandEp);
1035       /* Apply pitch gains */
1036    } else {
1037       for (i=0;i<C*N;i++)
1038          P[i] = 0;
1039    }
1040
1041    /* Decode fixed codebook and merge with pitch */
1042    unquant_bands(st->mode, X, P, has_pitch, gains, bandE, stereo_mode, pulses, shortBlocks, has_fold, len*8, &dec);
1043
1044    if (C==2)
1045    {
1046       renormalise_bands(st->mode, X);
1047    }
1048    /* Synthesis */
1049    denormalise_bands(st->mode, X, freq, bandE);
1050
1051
1052    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
1053    /* Compute inverse MDCTs */
1054    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1055
1056    for (c=0;c<C;c++)
1057    {
1058       int j;
1059       for (j=0;j<N;j++)
1060       {
1061          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1062                                 preemph,st->preemph_memD[c]);
1063          st->preemph_memD[c] = tmp;
1064          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1065       }
1066    }
1067
1068    {
1069       unsigned int val = 0;
1070       while (ec_dec_tell(&dec, 0) < len*8)
1071       {
1072          if (ec_dec_uint(&dec, 2) != val)
1073          {
1074             celt_warning("decode error");
1075             RESTORE_STACK;
1076             return CELT_CORRUPTED_DATA;
1077          }
1078          val = 1-val;
1079       }
1080    }
1081
1082    RESTORE_STACK;
1083    return 0;
1084    /*printf ("\n");*/
1085 }
1086
1087 #ifdef FIXED_POINT
1088 #ifndef DISABLE_FLOAT_API
1089 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1090 {
1091    int j, ret;
1092    const int C = CHANNELS(st->mode);
1093    const int N = st->block_size;
1094    VARDECL(celt_int16_t, out);
1095    SAVE_STACK;
1096    ALLOC(out, C*N, celt_int16_t);
1097
1098    ret=celt_decode(st, data, len, out);
1099
1100    for (j=0;j<C*N;j++)
1101      pcm[j]=out[j]*(1/32768.);
1102    RESTORE_STACK;
1103    return ret;
1104 }
1105 #endif /*DISABLE_FLOAT_API*/
1106 #else
1107 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1108 {
1109    int j, ret;
1110    VARDECL(celt_sig_t, out);
1111    const int C = CHANNELS(st->mode);
1112    const int N = st->block_size;
1113    SAVE_STACK;
1114    ALLOC(out, C*N, celt_sig_t);
1115
1116    ret=celt_decode_float(st, data, len, out);
1117
1118    for (j=0;j<C*N;j++)
1119      pcm[j] = FLOAT2INT16 (out[j]);
1120
1121    RESTORE_STACK;
1122    return ret;
1123 }
1124 #endif