9ebd8df2dfc1c0c42957c3fd8c7bb59ed547ef6e
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53 #include "float_cast.h"
54 #include <stdarg.h>
55
56 static const celt_word16_t preemph = QCONST16(0.8f,15);
57
58 #ifdef FIXED_POINT
59 static const celt_word16_t transientWindow[16] = {
60      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
61    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
62 #else
63 static const float transientWindow[16] = {
64    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
65    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
66 #endif
67
68    
69 /** Encoder state 
70  @brief Encoder state
71  */
72 struct CELTEncoder {
73    const CELTMode *mode;     /**< Mode used by the encoder */
74    int frame_size;
75    int block_size;
76    int overlap;
77    int channels;
78    
79    int pitch_enabled;
80    int pitch_available;
81
82    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
83    celt_sig_t    * restrict preemph_memD;
84
85    celt_sig_t *in_mem;
86    celt_sig_t *out_mem;
87
88    celt_word16_t *oldBandE;
89 #ifdef EXP_PSY
90    celt_word16_t *psy_mem;
91    struct PsyDecay psy;
92 #endif
93 };
94
95 CELTEncoder *celt_encoder_create(const CELTMode *mode)
96 {
97    int N, C;
98    CELTEncoder *st;
99
100    if (check_mode(mode) != CELT_OK)
101       return NULL;
102
103    N = mode->mdctSize;
104    C = mode->nbChannels;
105    st = celt_alloc(sizeof(CELTEncoder));
106    
107    st->mode = mode;
108    st->frame_size = N;
109    st->block_size = N;
110    st->overlap = mode->overlap;
111
112    st->pitch_enabled = 1;
113    st->pitch_available = 1;
114
115    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
116    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
117
118    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
119
120    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
121    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
122
123 #ifdef EXP_PSY
124    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
125    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
126 #endif
127
128    return st;
129 }
130
131 void celt_encoder_destroy(CELTEncoder *st)
132 {
133    if (st == NULL)
134    {
135       celt_warning("NULL passed to celt_encoder_destroy");
136       return;
137    }
138    if (check_mode(st->mode) != CELT_OK)
139       return;
140
141    celt_free(st->in_mem);
142    celt_free(st->out_mem);
143    
144    celt_free(st->oldBandE);
145    
146    celt_free(st->preemph_memE);
147    celt_free(st->preemph_memD);
148    
149 #ifdef EXP_PSY
150    celt_free (st->psy_mem);
151    psydecay_clear(&st->psy);
152 #endif
153    
154    celt_free(st);
155 }
156
157 static inline celt_int16_t FLOAT2INT16(float x)
158 {
159    x = x*32768.;
160    x = MAX32(x, -32768);
161    x = MIN32(x, 32767);
162    return (celt_int16_t)float2int(x);
163 }
164
165 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
166 {
167 #ifdef FIXED_POINT
168    x = PSHR32(x, SIG_SHIFT);
169    x = MAX32(x, -32768);
170    x = MIN32(x, 32767);
171    return EXTRACT16(x);
172 #else
173    return (celt_word16_t)x;
174 #endif
175 }
176
177 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
178 {
179    int c, i, n;
180    celt_word32_t ratio;
181    /* FIXME: Remove the floats here */
182    VARDECL(celt_word32_t, begin);
183    SAVE_STACK;
184    ALLOC(begin, len, celt_word32_t);
185    for (i=0;i<len;i++)
186       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
187    for (c=1;c<C;c++)
188    {
189       for (i=0;i<len;i++)
190          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
191    }
192    for (i=1;i<len;i++)
193       begin[i] = MAX32(begin[i-1],begin[i]);
194    n = -1;
195    for (i=8;i<len-8;i++)
196    {
197       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
198          n=i;
199    }
200    if (n<32)
201    {
202       n = -1;
203       ratio = 0;
204    } else {
205       ratio = DIV32(begin[len-1],1+begin[n-16]);
206    }
207    /*printf ("%d %f\n", n, ratio*ratio);*/
208    if (ratio < 0)
209       ratio = 0;
210    if (ratio > 1000)
211       ratio = 1000;
212    ratio *= ratio;
213    if (ratio < 50)
214       *transient_shift = 0;
215    else if (ratio < 256)
216       *transient_shift = 1;
217    else if (ratio < 4096)
218       *transient_shift = 2;
219    else
220       *transient_shift = 3;
221    *transient_time = n;
222    
223    RESTORE_STACK;
224    return ratio > 20;
225 }
226
227 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
228 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
229 {
230    const int C = CHANNELS(mode);
231    if (C==1 && !shortBlocks)
232    {
233       const mdct_lookup *lookup = MDCT(mode);
234       const int overlap = OVERLAP(mode);
235       mdct_forward(lookup, in, out, mode->window, overlap);
236    } else if (!shortBlocks) {
237       const mdct_lookup *lookup = MDCT(mode);
238       const int overlap = OVERLAP(mode);
239       const int N = FRAMESIZE(mode);
240       int c;
241       VARDECL(celt_word32_t, x);
242       VARDECL(celt_word32_t, tmp);
243       SAVE_STACK;
244       ALLOC(x, N+overlap, celt_word32_t);
245       ALLOC(tmp, N, celt_word32_t);
246       for (c=0;c<C;c++)
247       {
248          int j;
249          for (j=0;j<N+overlap;j++)
250             x[j] = in[C*j+c];
251          mdct_forward(lookup, x, tmp, mode->window, overlap);
252          /* Interleaving the sub-frames */
253          for (j=0;j<N;j++)
254             out[C*j+c] = tmp[j];
255       }
256       RESTORE_STACK;
257    } else {
258       const mdct_lookup *lookup = &mode->shortMdct;
259       const int overlap = mode->overlap;
260       const int N = mode->shortMdctSize;
261       int b, c;
262       VARDECL(celt_word32_t, x);
263       VARDECL(celt_word32_t, tmp);
264       SAVE_STACK;
265       ALLOC(x, N+overlap, celt_word32_t);
266       ALLOC(tmp, N, celt_word32_t);
267       for (c=0;c<C;c++)
268       {
269          int B = mode->nbShortMdcts;
270          for (b=0;b<B;b++)
271          {
272             int j;
273             for (j=0;j<N+overlap;j++)
274                x[j] = in[C*(b*N+j)+c];
275             mdct_forward(lookup, x, tmp, mode->window, overlap);
276             /* Interleaving the sub-frames */
277             for (j=0;j<N;j++)
278                out[C*(j*B+b)+c] = tmp[j];
279          }
280       }
281       RESTORE_STACK;
282    }
283 }
284
285 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
286 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
287 {
288    int c, N4;
289    const int C = CHANNELS(mode);
290    const int N = FRAMESIZE(mode);
291    const int overlap = OVERLAP(mode);
292    N4 = (N-overlap)>>1;
293    for (c=0;c<C;c++)
294    {
295       int j;
296       if (transient_shift==0 && C==1 && !shortBlocks) {
297          const mdct_lookup *lookup = MDCT(mode);
298          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
299       } else if (!shortBlocks) {
300          const mdct_lookup *lookup = MDCT(mode);
301          VARDECL(celt_word32_t, x);
302          VARDECL(celt_word32_t, tmp);
303          SAVE_STACK;
304          ALLOC(x, 2*N, celt_word32_t);
305          ALLOC(tmp, N, celt_word32_t);
306          /* De-interleaving the sub-frames */
307          for (j=0;j<N;j++)
308             tmp[j] = X[C*j+c];
309          /* Prevents problems from the imdct doing the overlap-add */
310          CELT_MEMSET(x+N4, 0, N);
311          mdct_backward(lookup, tmp, x, mode->window, overlap);
312          celt_assert(transient_shift == 0);
313          /* The first and last part would need to be set to zero if we actually
314             wanted to use them. */
315          for (j=0;j<overlap;j++)
316             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
317          for (j=0;j<overlap;j++)
318             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
319          for (j=0;j<2*N4;j++)
320             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
321          RESTORE_STACK;
322       } else {
323          int b;
324          const int N2 = mode->shortMdctSize;
325          const int B = mode->nbShortMdcts;
326          const mdct_lookup *lookup = &mode->shortMdct;
327          VARDECL(celt_word32_t, x);
328          VARDECL(celt_word32_t, tmp);
329          SAVE_STACK;
330          ALLOC(x, 2*N, celt_word32_t);
331          ALLOC(tmp, N, celt_word32_t);
332          /* Prevents problems from the imdct doing the overlap-add */
333          CELT_MEMSET(x+N4, 0, N2);
334          for (b=0;b<B;b++)
335          {
336             /* De-interleaving the sub-frames */
337             for (j=0;j<N2;j++)
338                tmp[j] = X[C*(j*B+b)+c];
339             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
340          }
341          if (transient_shift > 0)
342          {
343 #ifdef FIXED_POINT
344             for (j=0;j<16;j++)
345                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
346             for (j=transient_time;j<N+overlap;j++)
347                x[N4+j] = SHL32(x[N4+j], transient_shift);
348 #else
349             for (j=0;j<16;j++)
350                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
351             for (j=transient_time;j<N+overlap;j++)
352                x[N4+j] *= 1<<transient_shift;
353 #endif
354          }
355          /* The first and last part would need to be set to zero if we actually
356          wanted to use them. */
357          for (j=0;j<overlap;j++)
358             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
359          for (j=0;j<overlap;j++)
360             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
361          for (j=0;j<2*N4;j++)
362             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
363          RESTORE_STACK;
364       }
365    }
366 }
367
368 #ifdef FIXED_POINT
369 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
370 {
371 #else
372 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
373 {
374 #endif
375    int i, c, N, N4;
376    int has_pitch;
377    int id;
378    int pitch_index;
379    int bits;
380    int has_fold=1;
381    ec_byte_buffer buf;
382    ec_enc         enc;
383    VARDECL(celt_sig_t, in);
384    VARDECL(celt_sig_t, freq);
385    VARDECL(celt_norm_t, X);
386    VARDECL(celt_norm_t, P);
387    VARDECL(celt_ener_t, bandE);
388    VARDECL(celt_pgain_t, gains);
389    VARDECL(int, stereo_mode);
390    VARDECL(int, fine_quant);
391    VARDECL(celt_word16_t, error);
392    VARDECL(int, pulses);
393    VARDECL(int, offsets);
394 #ifdef EXP_PSY
395    VARDECL(celt_word32_t, mask);
396    VARDECL(celt_word32_t, tonality);
397    VARDECL(celt_word32_t, bandM);
398    VARDECL(celt_ener_t, bandN);
399 #endif
400    int shortBlocks=0;
401    int transient_time;
402    int transient_shift;
403    const int C = CHANNELS(st->mode);
404    SAVE_STACK;
405
406    if (check_mode(st->mode) != CELT_OK)
407       return CELT_INVALID_MODE;
408
409    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
410    CELT_MEMSET(compressed, 0, nbCompressedBytes);
411    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
412    ec_enc_init(&enc,&buf);
413
414    N = st->block_size;
415    N4 = (N-st->overlap)>>1;
416    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
417
418    CELT_COPY(in, st->in_mem, C*st->overlap);
419    for (c=0;c<C;c++)
420    {
421       const celt_word16_t * restrict pcmp = pcm+c;
422       celt_sig_t * restrict inp = in+C*st->overlap+c;
423       for (i=0;i<N;i++)
424       {
425          /* Apply pre-emphasis */
426          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
427          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
428          st->preemph_memE[c] = SCALEIN(*pcmp);
429          inp += C;
430          pcmp += C;
431       }
432    }
433    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
434    
435    /* Transient handling */
436    if (st->mode->nbShortMdcts > 1)
437    {
438       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
439       {
440 #ifndef FIXED_POINT
441          float gain_1;
442 #endif
443          ec_enc_bits(&enc, 0, 1); //Pitch off
444          ec_enc_bits(&enc, 1, 1); //Transient on
445          ec_enc_bits(&enc, transient_shift, 2);
446          if (transient_shift)
447             ec_enc_uint(&enc, transient_time, N+st->overlap);
448          /* Apply the inverse shaping window */
449          if (transient_shift)
450          {
451 #ifdef FIXED_POINT
452             for (c=0;c<C;c++)
453                for (i=0;i<16;i++)
454                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
455             for (c=0;c<C;c++)
456                for (i=transient_time;i<N+st->overlap;i++)
457                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
458 #else
459             for (c=0;c<C;c++)
460                for (i=0;i<16;i++)
461                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
462             gain_1 = 1./(1<<transient_shift);
463             for (c=0;c<C;c++)
464                for (i=transient_time;i<N+st->overlap;i++)
465                   in[C*i+c] *= gain_1;
466 #endif
467          }
468          shortBlocks = 1;
469       } else {
470          transient_time = -1;
471          transient_shift = 0;
472          shortBlocks = 0;
473       }
474    } else {
475       transient_time = -1;
476       transient_shift = 0;
477       shortBlocks = 0;
478    }
479
480    /* Pitch analysis: we do it early to save on the peak stack space */
481    /* Don't use pitch if there isn't enough data available yet, or if we're using shortBlocks */
482    has_pitch = st->pitch_enabled && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks);
483 #ifdef EXP_PSY
484    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
485    {
486       VARDECL(celt_word16_t, X);
487       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
488       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
489       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
490    }
491 #else
492    if (has_pitch)
493    {
494       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
495    }
496 #endif
497    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
498    
499    /* Compute MDCTs */
500    compute_mdcts(st->mode, shortBlocks, in, freq);
501
502 #ifdef EXP_PSY
503    ALLOC(mask, N, celt_sig_t);
504    compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
505    /*for (i=0;i<256;i++)
506       printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
507    printf ("\n");*/
508 #endif
509
510    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
511    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
512    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
513    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
514    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
515
516
517    /* Band normalisation */
518    compute_band_energies(st->mode, freq, bandE);
519    normalise_bands(st->mode, freq, X, bandE);
520
521 #ifdef EXP_PSY
522    ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
523    ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
524    compute_noise_energies(st->mode, freq, tonality, bandN);
525
526    /*for (i=0;i<st->mode->nbEBands;i++)
527       printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
528    printf ("\n");*/
529    has_fold = 0;
530    for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
531       if (bandN[i] < .4*bandE[i])
532          has_fold++;
533    /*printf ("%d\n", has_fold);*/
534    if (has_fold>=2)
535       has_fold = 0;
536    else
537       has_fold = 1;
538    for (i=0;i<N;i++)
539       mask[i] = sqrt(mask[i]);
540    compute_band_energies(st->mode, mask, bandM);
541    /*for (i=0;i<st->mode->nbEBands;i++)
542       printf ("%f %f ", bandE[i], bandM[i]);
543    printf ("\n");*/
544 #endif
545
546    /* Compute MDCTs of the pitch part */
547    if (has_pitch)
548    {
549       celt_word32_t curr_power, pitch_power=0;
550       /* Normalise the pitch vector as well (discard the energies) */
551       VARDECL(celt_ener_t, bandEp);
552       
553       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
554       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
555       compute_band_energies(st->mode, freq, bandEp);
556       normalise_bands(st->mode, freq, P, bandEp);
557       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
558       /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
559       curr_power = bandE[0]+bandE[1]+bandE[2];
560       id=-1;
561       if ((MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
562       {
563          /* Pitch prediction */
564          compute_pitch_gain(st->mode, X, P, gains);
565          id = quant_pitch(gains, st->mode->nbPBands);
566       } 
567       if (id == -1)
568          has_pitch = 0;
569    }
570    
571    if (has_pitch) 
572    {  
573       unquant_pitch(id, gains, st->mode->nbPBands);
574       ec_enc_bits(&enc, has_pitch, 1); /* Pitch flag */
575       ec_enc_bits(&enc, has_fold, 1); /* Folding flag */
576       ec_enc_bits(&enc, id, 7);
577       ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
578       pitch_quant_bands(st->mode, P, gains);
579    } else {
580       if (!shortBlocks)
581       {
582          ec_enc_bits(&enc, 0, 1); /* Pitch off */
583          if (st->mode->nbShortMdcts > 1)
584            ec_enc_bits(&enc, 0, 1); /* Transient off */
585       }
586       has_fold = 1;
587       /* No pitch, so we just pretend we found a gain of zero */
588       for (i=0;i<st->mode->nbPBands;i++)
589          gains[i] = 0;
590       for (i=0;i<C*N;i++)
591          P[i] = 0;
592    }
593
594 #ifdef STDIN_TUNING2
595    static int fine_quant[30];
596    static int pulses[30];
597    static int init=0;
598    if (!init)
599    {
600       for (i=0;i<st->mode->nbEBands;i++)
601          scanf("%d ", &fine_quant[i]);
602       for (i=0;i<st->mode->nbEBands;i++)
603          scanf("%d ", &pulses[i]);
604       init = 1;
605    }
606 #else
607    ALLOC(fine_quant, st->mode->nbEBands, int);
608    ALLOC(pulses, st->mode->nbEBands, int);
609 #endif
610
611    /* Bit allocation */
612    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
613    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &enc);
614    
615    ALLOC(offsets, st->mode->nbEBands, int);
616    ALLOC(stereo_mode, st->mode->nbEBands, int);
617    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
618
619    for (i=0;i<st->mode->nbEBands;i++)
620       offsets[i] = 0;
621    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
622 #ifndef STDIN_TUNING
623    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
624 #endif
625
626    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
627
628    /* Residual quantisation */
629    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
630
631    /* Re-synthesis of the coded audio if required */
632    if (st->pitch_available>0 || optional_synthesis!=NULL)
633    {
634       if (st->pitch_available>0 && st->pitch_available<MAX_PERIOD)
635         st->pitch_available+=st->frame_size;
636
637       if (C==2)
638          renormalise_bands(st->mode, X);
639       /* Synthesis */
640       denormalise_bands(st->mode, X, freq, bandE);
641       
642       
643       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
644       
645       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
646       /* De-emphasis and put everything back at the right place in the synthesis history */
647       if (optional_synthesis != NULL) {
648          for (c=0;c<C;c++)
649          {
650             int j;
651             for (j=0;j<N;j++)
652             {
653                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
654                                    preemph,st->preemph_memD[c]);
655                st->preemph_memD[c] = tmp;
656                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
657             }
658          }
659       }
660    }
661    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
662    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
663       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
664    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
665    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
666    {
667       int val = 0;
668       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
669       {
670          ec_enc_uint(&enc, val, 2);
671          val = 1-val;
672       }
673    }
674    ec_enc_done(&enc);
675    {
676       /*unsigned char *data;*/
677       int nbBytes = ec_byte_bytes(&buf);
678       if (nbBytes > nbCompressedBytes)
679       {
680          celt_warning_int ("got too many bytes:", nbBytes);
681          RESTORE_STACK;
682          return CELT_INTERNAL_ERROR;
683       }
684    }
685
686    RESTORE_STACK;
687    return nbCompressedBytes;
688 }
689
690 #ifdef FIXED_POINT
691 #ifndef DISABLE_FLOAT_API
692 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
693 {
694    int j, ret;
695    const int C = CHANNELS(st->mode);
696    const int N = st->block_size;
697    VARDECL(celt_int16_t, in);
698    SAVE_STACK;
699    ALLOC(in, C*N, celt_int16_t);
700
701    for (j=0;j<C*N;j++)
702      in[j] = FLOAT2INT16(pcm[j]);
703
704    if (optional_synthesis != NULL) {
705      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
706       for (j=0;j<C*N;j++)
707          optional_synthesis[j]=in[j]*(1/32768.);
708    } else {
709      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
710    }
711    RESTORE_STACK;
712    return ret;
713
714 }
715 #endif /*DISABLE_FLOAT_API*/
716 #else
717 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
718 {
719    int j, ret;
720    VARDECL(celt_sig_t, in);
721    const int C = CHANNELS(st->mode);
722    const int N = st->block_size;
723    SAVE_STACK;
724    ALLOC(in, C*N, celt_sig_t);
725    for (j=0;j<C*N;j++) {
726      in[j] = SCALEOUT(pcm[j]);
727    }
728
729    if (optional_synthesis != NULL) {
730       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
731       for (j=0;j<C*N;j++)
732          optional_synthesis[j] = FLOAT2INT16(in[j]);
733    } else {
734       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
735    }
736    RESTORE_STACK;
737    return ret;
738 }
739 #endif
740
741 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
742 {
743    va_list ap;
744    va_start(ap, request);
745    switch (request)
746    {
747       case CELT_SET_COMPLEXITY_REQUEST:
748       {
749          int value = va_arg(ap, int);
750          if (value<0 || value>10)
751             goto bad_arg;
752          if (value<=2) {
753             st->pitch_enabled = 0; 
754             st->pitch_available = 0;
755          } else {
756               st->pitch_enabled = 1;
757               if (st->pitch_available<1)
758                 st->pitch_available = 1;
759          }   
760       }
761       break;
762       case CELT_SET_LTP_REQUEST:
763       {
764          int value = va_arg(ap, int);
765          if (value<0 || value>1 || (value==1 && st->pitch_available==0))
766             goto bad_arg;
767          if (value==0)
768             st->pitch_enabled = 0;
769          else
770             st->pitch_enabled = 1;
771       }
772       break;
773       default:
774          goto bad_request;
775    }
776    va_end(ap);
777    return CELT_OK;
778 bad_arg:
779    va_end(ap);
780    return CELT_BAD_ARG;
781 bad_request:
782    va_end(ap);
783    return CELT_UNIMPLEMENTED;
784 }
785
786 /****************************************************************************/
787 /*                                                                          */
788 /*                                DECODER                                   */
789 /*                                                                          */
790 /****************************************************************************/
791
792
793 /** Decoder state 
794  @brief Decoder state
795  */
796 struct CELTDecoder {
797    const CELTMode *mode;
798    int frame_size;
799    int block_size;
800    int overlap;
801
802    ec_byte_buffer buf;
803    ec_enc         enc;
804
805    celt_sig_t * restrict preemph_memD;
806
807    celt_sig_t *out_mem;
808
809    celt_word16_t *oldBandE;
810    
811    int last_pitch_index;
812 };
813
814 CELTDecoder *celt_decoder_create(const CELTMode *mode)
815 {
816    int N, C;
817    CELTDecoder *st;
818
819    if (check_mode(mode) != CELT_OK)
820       return NULL;
821
822    N = mode->mdctSize;
823    C = CHANNELS(mode);
824    st = celt_alloc(sizeof(CELTDecoder));
825    
826    st->mode = mode;
827    st->frame_size = N;
828    st->block_size = N;
829    st->overlap = mode->overlap;
830
831    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
832    
833    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
834
835    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
836
837    st->last_pitch_index = 0;
838    return st;
839 }
840
841 void celt_decoder_destroy(CELTDecoder *st)
842 {
843    if (st == NULL)
844    {
845       celt_warning("NULL passed to celt_encoder_destroy");
846       return;
847    }
848    if (check_mode(st->mode) != CELT_OK)
849       return;
850
851
852    celt_free(st->out_mem);
853    
854    celt_free(st->oldBandE);
855    
856    celt_free(st->preemph_memD);
857
858    celt_free(st);
859 }
860
861 /** Handles lost packets by just copying past data with the same offset as the last
862     pitch period */
863 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
864 {
865    int c, N;
866    int pitch_index;
867    int i, len;
868    VARDECL(celt_sig_t, freq);
869    const int C = CHANNELS(st->mode);
870    int offset;
871    SAVE_STACK;
872    N = st->block_size;
873    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
874    
875    len = N+st->mode->overlap;
876 #if 0
877    pitch_index = st->last_pitch_index;
878    
879    /* Use the pitch MDCT as the "guessed" signal */
880    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
881
882 #else
883    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
884    pitch_index = MAX_PERIOD-len-pitch_index;
885    offset = MAX_PERIOD-pitch_index;
886    while (offset+len >= MAX_PERIOD)
887       offset -= pitch_index;
888    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
889    for (i=0;i<N;i++)
890       freq[i] = ADD32(EPSILON, MULT16_32_Q15(QCONST16(.9f,15),freq[i]));
891 #endif
892    
893    
894    
895    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
896    /* Compute inverse MDCTs */
897    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
898
899    for (c=0;c<C;c++)
900    {
901       int j;
902       for (j=0;j<N;j++)
903       {
904          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
905                                 preemph,st->preemph_memD[c]);
906          st->preemph_memD[c] = tmp;
907          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
908       }
909    }
910    RESTORE_STACK;
911 }
912
913 #ifdef FIXED_POINT
914 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
915 {
916 #else
917 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
918 {
919 #endif
920    int i, c, N, N4;
921    int has_pitch, has_fold;
922    int pitch_index;
923    int bits;
924    ec_dec dec;
925    ec_byte_buffer buf;
926    VARDECL(celt_sig_t, freq);
927    VARDECL(celt_norm_t, X);
928    VARDECL(celt_norm_t, P);
929    VARDECL(celt_ener_t, bandE);
930    VARDECL(celt_pgain_t, gains);
931    VARDECL(int, stereo_mode);
932    VARDECL(int, fine_quant);
933    VARDECL(int, pulses);
934    VARDECL(int, offsets);
935
936    int shortBlocks;
937    int transient_time;
938    int transient_shift;
939    const int C = CHANNELS(st->mode);
940    SAVE_STACK;
941
942    if (check_mode(st->mode) != CELT_OK)
943       return CELT_INVALID_MODE;
944
945    N = st->block_size;
946    N4 = (N-st->overlap)>>1;
947
948    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
949    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
950    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
951    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
952    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
953    
954    if (check_mode(st->mode) != CELT_OK)
955    {
956       RESTORE_STACK;
957       return CELT_INVALID_MODE;
958    }
959    if (data == NULL)
960    {
961       celt_decode_lost(st, pcm);
962       RESTORE_STACK;
963       return 0;
964    }
965    
966    ec_byte_readinit(&buf,data,len);
967    ec_dec_init(&dec,&buf);
968    
969    has_pitch = ec_dec_bits(&dec, 1);
970    if (has_pitch)
971    {
972       has_fold = ec_dec_bits(&dec, 1);
973       shortBlocks = 0;
974    } else if (st->mode->nbShortMdcts > 1){
975       shortBlocks = ec_dec_bits(&dec, 1);
976       has_fold = 1;
977    } else {
978       shortBlocks = 0;
979       has_fold = 1;
980    }
981    if (shortBlocks)
982    {
983       transient_shift = ec_dec_bits(&dec, 2);
984       if (transient_shift)
985          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
986       else
987          transient_time = 0;
988    } else {
989       transient_time = -1;
990       transient_shift = 0;
991    }
992    
993    if (has_pitch)
994    {
995       int id;
996       /* Get the pitch gains and index */
997       id = ec_dec_bits(&dec, 7);
998       unquant_pitch(id, gains, st->mode->nbPBands);
999       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
1000       st->last_pitch_index = pitch_index;
1001    } else {
1002       pitch_index = 0;
1003       for (i=0;i<st->mode->nbPBands;i++)
1004          gains[i] = 0;
1005    }
1006
1007    ALLOC(fine_quant, st->mode->nbEBands, int);
1008    /* Get band energies */
1009    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
1010    
1011    ALLOC(pulses, st->mode->nbEBands, int);
1012    ALLOC(offsets, st->mode->nbEBands, int);
1013    ALLOC(stereo_mode, st->mode->nbEBands, int);
1014    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
1015
1016    for (i=0;i<st->mode->nbEBands;i++)
1017       offsets[i] = 0;
1018
1019    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
1020    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
1021    /*bits = ec_dec_tell(&dec, 0);
1022    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
1023    
1024    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
1025
1026
1027    if (has_pitch) 
1028    {
1029       VARDECL(celt_ener_t, bandEp);
1030       
1031       /* Pitch MDCT */
1032       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1033       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1034       compute_band_energies(st->mode, freq, bandEp);
1035       normalise_bands(st->mode, freq, P, bandEp);
1036       /* Apply pitch gains */
1037       pitch_quant_bands(st->mode, P, gains);
1038    } else {
1039       for (i=0;i<C*N;i++)
1040          P[i] = 0;
1041    }
1042
1043    /* Decode fixed codebook and merge with pitch */
1044    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, has_fold, len*8, &dec);
1045
1046    if (C==2)
1047    {
1048       renormalise_bands(st->mode, X);
1049    }
1050    /* Synthesis */
1051    denormalise_bands(st->mode, X, freq, bandE);
1052
1053
1054    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
1055    /* Compute inverse MDCTs */
1056    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1057
1058    for (c=0;c<C;c++)
1059    {
1060       int j;
1061       for (j=0;j<N;j++)
1062       {
1063          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1064                                 preemph,st->preemph_memD[c]);
1065          st->preemph_memD[c] = tmp;
1066          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1067       }
1068    }
1069
1070    {
1071       unsigned int val = 0;
1072       while (ec_dec_tell(&dec, 0) < len*8)
1073       {
1074          if (ec_dec_uint(&dec, 2) != val)
1075          {
1076             celt_warning("decode error");
1077             RESTORE_STACK;
1078             return CELT_CORRUPTED_DATA;
1079          }
1080          val = 1-val;
1081       }
1082    }
1083
1084    RESTORE_STACK;
1085    return 0;
1086    /*printf ("\n");*/
1087 }
1088
1089 #ifdef FIXED_POINT
1090 #ifndef DISABLE_FLOAT_API
1091 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1092 {
1093    int j, ret;
1094    const int C = CHANNELS(st->mode);
1095    const int N = st->block_size;
1096    VARDECL(celt_int16_t, out);
1097    SAVE_STACK;
1098    ALLOC(out, C*N, celt_int16_t);
1099
1100    ret=celt_decode(st, data, len, out);
1101
1102    for (j=0;j<C*N;j++)
1103      pcm[j]=out[j]*(1/32768.);
1104    RESTORE_STACK;
1105    return ret;
1106 }
1107 #endif /*DISABLE_FLOAT_API*/
1108 #else
1109 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1110 {
1111    int j, ret;
1112    VARDECL(celt_sig_t, out);
1113    const int C = CHANNELS(st->mode);
1114    const int N = st->block_size;
1115    SAVE_STACK;
1116    ALLOC(out, C*N, celt_sig_t);
1117
1118    ret=celt_decode_float(st, data, len, out);
1119
1120    for (j=0;j<C*N;j++)
1121      pcm[j] = FLOAT2INT16 (out[j]);
1122
1123    RESTORE_STACK;
1124    return ret;
1125 }
1126 #endif