Encoder now writes data directly in the user buffer
[opus.git] / libcelt / celt.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #define CELT_C
37
38 #include "os_support.h"
39 #include "mdct.h"
40 #include <math.h>
41 #include "celt.h"
42 #include "pitch.h"
43 #include "kiss_fftr.h"
44 #include "bands.h"
45 #include "modes.h"
46 #include "entcode.h"
47 #include "quant_pitch.h"
48 #include "quant_bands.h"
49 #include "psy.h"
50 #include "rate.h"
51 #include "stack_alloc.h"
52 #include "mathops.h"
53 #include "float_cast.h"
54 #include <stdarg.h>
55
56 static const celt_word16_t preemph = QCONST16(0.8f,15);
57
58 #ifdef FIXED_POINT
59 static const celt_word16_t transientWindow[16] = {
60      279,  1106,  2454,  4276,  6510,  9081, 11900, 14872,
61    17896, 20868, 23687, 26258, 28492, 30314, 31662, 32489};
62 #else
63 static const float transientWindow[16] = {
64    0.0085135, 0.0337639, 0.0748914, 0.1304955, 0.1986827, 0.2771308, 0.3631685, 0.4538658,
65    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
66 #endif
67
68    
69 /** Encoder state 
70  @brief Encoder state
71  */
72 struct CELTEncoder {
73    const CELTMode *mode;     /**< Mode used by the encoder */
74    int frame_size;
75    int block_size;
76    int overlap;
77    int channels;
78    
79    int pitch_enabled;
80
81    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
82    celt_sig_t    * restrict preemph_memD;
83
84    celt_sig_t *in_mem;
85    celt_sig_t *out_mem;
86
87    celt_word16_t *oldBandE;
88 #ifdef EXP_PSY
89    celt_word16_t *psy_mem;
90    struct PsyDecay psy;
91 #endif
92 };
93
94 CELTEncoder *celt_encoder_create(const CELTMode *mode)
95 {
96    int N, C;
97    CELTEncoder *st;
98
99    if (check_mode(mode) != CELT_OK)
100       return NULL;
101
102    N = mode->mdctSize;
103    C = mode->nbChannels;
104    st = celt_alloc(sizeof(CELTEncoder));
105    
106    st->mode = mode;
107    st->frame_size = N;
108    st->block_size = N;
109    st->overlap = mode->overlap;
110
111    st->pitch_enabled = 1;
112
113    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
114    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
115
116    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
117
118    st->preemph_memE = (celt_word16_t*)celt_alloc(C*sizeof(celt_word16_t));
119    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
120
121 #ifdef EXP_PSY
122    st->psy_mem = celt_alloc(MAX_PERIOD*sizeof(celt_word16_t));
123    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
124 #endif
125
126    return st;
127 }
128
129 void celt_encoder_destroy(CELTEncoder *st)
130 {
131    if (st == NULL)
132    {
133       celt_warning("NULL passed to celt_encoder_destroy");
134       return;
135    }
136    if (check_mode(st->mode) != CELT_OK)
137       return;
138
139    celt_free(st->in_mem);
140    celt_free(st->out_mem);
141    
142    celt_free(st->oldBandE);
143    
144    celt_free(st->preemph_memE);
145    celt_free(st->preemph_memD);
146    
147 #ifdef EXP_PSY
148    celt_free (st->psy_mem);
149    psydecay_clear(&st->psy);
150 #endif
151    
152    celt_free(st);
153 }
154
155 static inline celt_int16_t FLOAT2INT16(float x)
156 {
157    x = x*32768.;
158    x = MAX32(x, -32768);
159    x = MIN32(x, 32767);
160    return (celt_int16_t)float2int(x);
161 }
162
163 static inline celt_word16_t SIG2WORD16(celt_sig_t x)
164 {
165 #ifdef FIXED_POINT
166    x = PSHR32(x, SIG_SHIFT);
167    x = MAX32(x, -32768);
168    x = MIN32(x, 32767);
169    return EXTRACT16(x);
170 #else
171    return (celt_word16_t)x;
172 #endif
173 }
174
175 static int transient_analysis(celt_word32_t *in, int len, int C, int *transient_time, int *transient_shift)
176 {
177    int c, i, n;
178    celt_word32_t ratio;
179    /* FIXME: Remove the floats here */
180    VARDECL(celt_word32_t, begin);
181    SAVE_STACK;
182    ALLOC(begin, len, celt_word32_t);
183    for (i=0;i<len;i++)
184       begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
185    for (c=1;c<C;c++)
186    {
187       for (i=0;i<len;i++)
188          begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
189    }
190    for (i=1;i<len;i++)
191       begin[i] = MAX32(begin[i-1],begin[i]);
192    n = -1;
193    for (i=8;i<len-8;i++)
194    {
195       if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
196          n=i;
197    }
198    if (n<32)
199    {
200       n = -1;
201       ratio = 0;
202    } else {
203       ratio = DIV32(begin[len-1],1+begin[n-16]);
204    }
205    /*printf ("%d %f\n", n, ratio*ratio);*/
206    if (ratio < 0)
207       ratio = 0;
208    if (ratio > 1000)
209       ratio = 1000;
210    ratio *= ratio;
211    if (ratio < 50)
212       *transient_shift = 0;
213    else if (ratio < 256)
214       *transient_shift = 1;
215    else if (ratio < 4096)
216       *transient_shift = 2;
217    else
218       *transient_shift = 3;
219    *transient_time = n;
220    
221    RESTORE_STACK;
222    return ratio > 20;
223 }
224
225 /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */
226 static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * restrict in, celt_sig_t * restrict out)
227 {
228    const int C = CHANNELS(mode);
229    if (C==1 && !shortBlocks)
230    {
231       const mdct_lookup *lookup = MDCT(mode);
232       const int overlap = OVERLAP(mode);
233       mdct_forward(lookup, in, out, mode->window, overlap);
234    } else if (!shortBlocks) {
235       const mdct_lookup *lookup = MDCT(mode);
236       const int overlap = OVERLAP(mode);
237       const int N = FRAMESIZE(mode);
238       int c;
239       VARDECL(celt_word32_t, x);
240       VARDECL(celt_word32_t, tmp);
241       SAVE_STACK;
242       ALLOC(x, N+overlap, celt_word32_t);
243       ALLOC(tmp, N, celt_word32_t);
244       for (c=0;c<C;c++)
245       {
246          int j;
247          for (j=0;j<N+overlap;j++)
248             x[j] = in[C*j+c];
249          mdct_forward(lookup, x, tmp, mode->window, overlap);
250          /* Interleaving the sub-frames */
251          for (j=0;j<N;j++)
252             out[C*j+c] = tmp[j];
253       }
254       RESTORE_STACK;
255    } else {
256       const mdct_lookup *lookup = &mode->shortMdct;
257       const int overlap = mode->overlap;
258       const int N = mode->shortMdctSize;
259       int b, c;
260       VARDECL(celt_word32_t, x);
261       VARDECL(celt_word32_t, tmp);
262       SAVE_STACK;
263       ALLOC(x, N+overlap, celt_word32_t);
264       ALLOC(tmp, N, celt_word32_t);
265       for (c=0;c<C;c++)
266       {
267          int B = mode->nbShortMdcts;
268          for (b=0;b<B;b++)
269          {
270             int j;
271             for (j=0;j<N+overlap;j++)
272                x[j] = in[C*(b*N+j)+c];
273             mdct_forward(lookup, x, tmp, mode->window, overlap);
274             /* Interleaving the sub-frames */
275             for (j=0;j<N;j++)
276                out[C*(j*B+b)+c] = tmp[j];
277          }
278       }
279       RESTORE_STACK;
280    }
281 }
282
283 /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */
284 static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t *X, int transient_time, int transient_shift, celt_sig_t * restrict out_mem)
285 {
286    int c, N4;
287    const int C = CHANNELS(mode);
288    const int N = FRAMESIZE(mode);
289    const int overlap = OVERLAP(mode);
290    N4 = (N-overlap)>>1;
291    for (c=0;c<C;c++)
292    {
293       int j;
294       if (transient_shift==0 && C==1 && !shortBlocks) {
295          const mdct_lookup *lookup = MDCT(mode);
296          mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap);
297       } else if (!shortBlocks) {
298          const mdct_lookup *lookup = MDCT(mode);
299          VARDECL(celt_word32_t, x);
300          VARDECL(celt_word32_t, tmp);
301          SAVE_STACK;
302          ALLOC(x, 2*N, celt_word32_t);
303          ALLOC(tmp, N, celt_word32_t);
304          /* De-interleaving the sub-frames */
305          for (j=0;j<N;j++)
306             tmp[j] = X[C*j+c];
307          /* Prevents problems from the imdct doing the overlap-add */
308          CELT_MEMSET(x+N4, 0, N);
309          mdct_backward(lookup, tmp, x, mode->window, overlap);
310          celt_assert(transient_shift == 0);
311          /* The first and last part would need to be set to zero if we actually
312             wanted to use them. */
313          for (j=0;j<overlap;j++)
314             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
315          for (j=0;j<overlap;j++)
316             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
317          for (j=0;j<2*N4;j++)
318             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
319          RESTORE_STACK;
320       } else {
321          int b;
322          const int N2 = mode->shortMdctSize;
323          const int B = mode->nbShortMdcts;
324          const mdct_lookup *lookup = &mode->shortMdct;
325          VARDECL(celt_word32_t, x);
326          VARDECL(celt_word32_t, tmp);
327          SAVE_STACK;
328          ALLOC(x, 2*N, celt_word32_t);
329          ALLOC(tmp, N, celt_word32_t);
330          /* Prevents problems from the imdct doing the overlap-add */
331          CELT_MEMSET(x+N4, 0, N2);
332          for (b=0;b<B;b++)
333          {
334             /* De-interleaving the sub-frames */
335             for (j=0;j<N2;j++)
336                tmp[j] = X[C*(j*B+b)+c];
337             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
338          }
339          if (transient_shift > 0)
340          {
341 #ifdef FIXED_POINT
342             for (j=0;j<16;j++)
343                x[N4+transient_time+j-16] = MULT16_32_Q15(SHR16(Q15_ONE-transientWindow[j],transient_shift)+transientWindow[j], SHL32(x[N4+transient_time+j-16],transient_shift));
344             for (j=transient_time;j<N+overlap;j++)
345                x[N4+j] = SHL32(x[N4+j], transient_shift);
346 #else
347             for (j=0;j<16;j++)
348                x[N4+transient_time+j-16] *= 1+transientWindow[j]*((1<<transient_shift)-1);
349             for (j=transient_time;j<N+overlap;j++)
350                x[N4+j] *= 1<<transient_shift;
351 #endif
352          }
353          /* The first and last part would need to be set to zero if we actually
354          wanted to use them. */
355          for (j=0;j<overlap;j++)
356             out_mem[C*(MAX_PERIOD-N)+C*j+c] += x[j+N4];
357          for (j=0;j<overlap;j++)
358             out_mem[C*(MAX_PERIOD)+C*(overlap-j-1)+c] = x[2*N-j-N4-1];
359          for (j=0;j<2*N4;j++)
360             out_mem[C*(MAX_PERIOD-N)+C*(j+overlap)+c] = x[j+N4+overlap];
361          RESTORE_STACK;
362       }
363    }
364 }
365
366 #ifdef FIXED_POINT
367 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
368 {
369 #else
370 int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_sig_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
371 {
372 #endif
373    int i, c, N, N4;
374    int has_pitch;
375    int pitch_index;
376    int bits;
377    ec_byte_buffer buf;
378    ec_enc         enc;
379    celt_word32_t curr_power, pitch_power=0;
380    VARDECL(celt_sig_t, in);
381    VARDECL(celt_sig_t, freq);
382    VARDECL(celt_norm_t, X);
383    VARDECL(celt_norm_t, P);
384    VARDECL(celt_ener_t, bandE);
385    VARDECL(celt_pgain_t, gains);
386    VARDECL(int, stereo_mode);
387    VARDECL(int, fine_quant);
388    VARDECL(celt_word16_t, error);
389    VARDECL(int, pulses);
390    VARDECL(int, offsets);
391 #ifdef EXP_PSY
392    VARDECL(celt_word32_t, mask);
393 #endif
394    int shortBlocks=0;
395    int transient_time;
396    int transient_shift;
397    const int C = CHANNELS(st->mode);
398    SAVE_STACK;
399
400    if (check_mode(st->mode) != CELT_OK)
401       return CELT_INVALID_MODE;
402
403    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
404    CELT_MEMSET(compressed, 0, nbCompressedBytes);
405    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
406    ec_enc_init(&enc,&buf);
407
408    N = st->block_size;
409    N4 = (N-st->overlap)>>1;
410    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
411
412    CELT_COPY(in, st->in_mem, C*st->overlap);
413    for (c=0;c<C;c++)
414    {
415       const celt_word16_t * restrict pcmp = pcm+c;
416       celt_sig_t * restrict inp = in+C*st->overlap+c;
417       for (i=0;i<N;i++)
418       {
419          /* Apply pre-emphasis */
420          celt_sig_t tmp = SCALEIN(SHL32(EXTEND32(*pcmp), SIG_SHIFT));
421          *inp = SUB32(tmp, SHR32(MULT16_16(preemph,st->preemph_memE[c]),3));
422          st->preemph_memE[c] = SCALEIN(*pcmp);
423          inp += C;
424          pcmp += C;
425       }
426    }
427    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
428    
429    if (st->mode->nbShortMdcts > 1)
430    {
431       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
432       {
433 #ifndef FIXED_POINT
434          float gain_1;
435 #endif
436          ec_enc_bits(&enc, 0, 1); //Pitch off
437          ec_enc_bits(&enc, 1, 1); //Transient on
438          ec_enc_bits(&enc, transient_shift, 2);
439          if (transient_shift)
440             ec_enc_uint(&enc, transient_time, N+st->overlap);
441          if (transient_shift)
442          {
443 #ifdef FIXED_POINT
444             for (c=0;c<C;c++)
445                for (i=0;i<16;i++)
446                   in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
447             for (c=0;c<C;c++)
448                for (i=transient_time;i<N+st->overlap;i++)
449                   in[C*i+c] = SHR32(in[C*i+c], transient_shift);
450 #else
451             for (c=0;c<C;c++)
452                for (i=0;i<16;i++)
453                   in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
454             gain_1 = 1./(1<<transient_shift);
455             for (c=0;c<C;c++)
456                for (i=transient_time;i<N+st->overlap;i++)
457                   in[C*i+c] *= gain_1;
458 #endif
459          }
460          shortBlocks = 1;
461       } else {
462          transient_time = -1;
463          transient_shift = 0;
464          shortBlocks = 0;
465       }
466    } else {
467       transient_time = -1;
468       transient_shift = 0;
469       shortBlocks = 0;
470    }
471    /* Pitch analysis: we do it early to save on the peak stack space */
472    if (st->pitch_enabled && !shortBlocks)
473    {
474 #ifdef EXP_PSY
475       VARDECL(celt_word16_t, X);
476       ALLOC(X, MAX_PERIOD/2, celt_word16_t);
477       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
478       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD);
479 #else
480       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
481 #endif
482    }
483    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
484    
485    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
486    /* Compute MDCTs */
487    compute_mdcts(st->mode, shortBlocks, in, freq);
488
489 #ifdef EXP_PSY
490    /*CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
491    for (i=0;i<N;i++)
492       st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
493    for (c=1;c<C;c++)
494       for (i=0;i<N;i++)
495          st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
496    */
497    ALLOC(mask, N, celt_sig_t);
498    compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
499
500    /* Invert and stretch the mask to length of X 
501       For some reason, I get better results by using the sqrt instead,
502       although there's no valid reason to. Must investigate further */
503    for (i=0;i<C*N;i++)
504       mask[i] = 1/(.1+mask[i]);
505 #endif
506    
507    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
508    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
509    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
510    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
511    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
512
513    /*printf ("%f %f\n", curr_power, pitch_power);*/
514    /*int j;
515    for (j=0;j<B*N;j++)
516       printf ("%f ", X[j]);
517    for (j=0;j<B*N;j++)
518       printf ("%f ", P[j]);
519    printf ("\n");*/
520
521    /* Band normalisation */
522    compute_band_energies(st->mode, freq, bandE);
523    normalise_bands(st->mode, freq, X, bandE);
524    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
525    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
526
527    /* Compute MDCTs of the pitch part */
528    if (st->pitch_enabled && !shortBlocks)
529    {
530       /* Normalise the pitch vector as well (discard the energies) */
531       VARDECL(celt_ener_t, bandEp);
532       
533       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
534       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, celt_ener_t);
535       compute_band_energies(st->mode, freq, bandEp);
536       normalise_bands(st->mode, freq, P, bandEp);
537       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
538    }
539    curr_power = bandE[0]+bandE[1]+bandE[2];
540    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
541    if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
542    {
543       /* Simulates intensity stereo */
544       /*for (i=30;i<N*B;i++)
545          X[i*C+1] = P[i*C+1] = 0;*/
546
547       /* Pitch prediction */
548       compute_pitch_gain(st->mode, X, P, gains);
549       has_pitch = quant_pitch(gains, st->mode->nbPBands, &enc);
550       if (has_pitch)
551          ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
552       else if (st->mode->nbShortMdcts > 1)
553          ec_enc_bits(&enc, 0, 1); //Transient off
554    } else {
555       if (!shortBlocks)
556       {
557          ec_enc_bits(&enc, 0, 1); //Pitch off
558          if (st->mode->nbShortMdcts > 1)
559            ec_enc_bits(&enc, 0, 1); //Transient off
560       }
561       /* No pitch, so we just pretend we found a gain of zero */
562       for (i=0;i<st->mode->nbPBands;i++)
563          gains[i] = 0;
564       for (i=0;i<C*N;i++)
565          P[i] = 0;
566    }
567
568 #ifdef STDIN_TUNING2
569    static int fine_quant[30];
570    static int pulses[30];
571    static int init=0;
572    if (!init)
573    {
574       for (i=0;i<st->mode->nbEBands;i++)
575          scanf("%d ", &fine_quant[i]);
576       for (i=0;i<st->mode->nbEBands;i++)
577          scanf("%d ", &pulses[i]);
578       init = 1;
579    }
580 #else
581    ALLOC(fine_quant, st->mode->nbEBands, int);
582    ALLOC(pulses, st->mode->nbEBands, int);
583 #endif
584    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
585    quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &enc);
586    
587    ALLOC(offsets, st->mode->nbEBands, int);
588    ALLOC(stereo_mode, st->mode->nbEBands, int);
589    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
590
591    for (i=0;i<st->mode->nbEBands;i++)
592       offsets[i] = 0;
593    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
594 #ifndef STDIN_TUNING
595    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
596 #endif
597    /*for (i=0;i<st->mode->nbEBands;i++)
598       printf("%d ", fine_quant[i]);
599    for (i=0;i<st->mode->nbEBands;i++)
600       printf("%d ", pulses[i]);
601    printf ("\n");*/
602    /*bits = ec_enc_tell(&st->enc, 0);
603    compute_fine_allocation(st->mode, fine_quant, (20*C+nbCompressedBytes*8/5-(ec_enc_tell(&st->enc, 0)-bits))/C);*/
604    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
605
606    pitch_quant_bands(st->mode, P, gains);
607
608    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
609
610    /* Residual quantisation */
611    quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, nbCompressedBytes*8, &enc);
612    
613    if (st->pitch_enabled || optional_synthesis!=NULL)
614    {
615       if (C==2)
616          renormalise_bands(st->mode, X);
617       /* Synthesis */
618       denormalise_bands(st->mode, X, freq, bandE);
619       
620       
621       CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
622       
623       compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
624       /* De-emphasis and put everything back at the right place in the synthesis history */
625       if (optional_synthesis != NULL) {
626          for (c=0;c<C;c++)
627          {
628             int j;
629             for (j=0;j<N;j++)
630             {
631                celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
632                                    preemph,st->preemph_memD[c]);
633                st->preemph_memD[c] = tmp;
634                optional_synthesis[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
635             }
636          }
637       }
638    }
639    /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
640    /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
641       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
642    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
643    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
644    {
645       int val = 0;
646       while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
647       {
648          ec_enc_uint(&enc, val, 2);
649          val = 1-val;
650       }
651    }
652    ec_enc_done(&enc);
653    {
654       /*unsigned char *data;*/
655       int nbBytes = ec_byte_bytes(&buf);
656       if (nbBytes > nbCompressedBytes)
657       {
658          celt_warning_int ("got too many bytes:", nbBytes);
659          RESTORE_STACK;
660          return CELT_INTERNAL_ERROR;
661       }
662       /*printf ("%d\n", *nbBytes);*/
663       /*data = ec_byte_get_buffer(&buf);
664       for (i=0;i<nbBytes;i++)
665          compressed[i] = data[i];
666       for (i=nbBytes;i<nbCompressedBytes;i++)
667          compressed[i] = 0;*/
668    }
669
670    RESTORE_STACK;
671    return nbCompressedBytes;
672 }
673
674 #ifdef FIXED_POINT
675 #ifndef DISABLE_FLOAT_API
676 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
677 {
678    int j, ret;
679    const int C = CHANNELS(st->mode);
680    const int N = st->block_size;
681    VARDECL(celt_int16_t, in);
682    SAVE_STACK;
683    ALLOC(in, C*N, celt_int16_t);
684
685    for (j=0;j<C*N;j++)
686      in[j] = FLOAT2INT16(pcm[j]);
687
688    if (optional_synthesis != NULL) {
689      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
690    /*Converts backwards for inplace operation*/
691       for (j=0;j=C*N;j++)
692          optional_synthesis[j]=in[j]*(1/32768.);
693    } else {
694      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
695    }
696    RESTORE_STACK;
697    return ret;
698
699 }
700 #endif /*DISABLE_FLOAT_API*/
701 #else
702 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
703 {
704    int j, ret;
705    VARDECL(celt_sig_t, in);
706    const int C = CHANNELS(st->mode);
707    const int N = st->block_size;
708    SAVE_STACK;
709    ALLOC(in, C*N, celt_sig_t);
710    for (j=0;j<C*N;j++) {
711      in[j] = SCALEOUT(pcm[j]);
712    }
713
714    if (optional_synthesis != NULL) {
715       ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes);
716       for (j=0;j<C*N;j++)
717          optional_synthesis[j] = FLOAT2INT16(in[j]);
718    } else {
719       ret = celt_encode_float(st,in,NULL,compressed,nbCompressedBytes);
720    }
721    RESTORE_STACK;
722    return ret;
723 }
724 #endif
725
726 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
727 {
728    va_list ap;
729    va_start(ap, request);
730    switch (request)
731    {
732       case CELT_SET_COMPLEXITY_REQUEST:
733       {
734          int value = va_arg(ap, int);
735          if (value<0 || value>10)
736             goto bad_arg;
737          if (value<=2)
738             st->pitch_enabled = 0;
739          else
740             st->pitch_enabled = 1;
741       }
742       break;
743       default:
744          goto bad_request;
745    }
746    va_end(ap);
747    return CELT_OK;
748 bad_arg:
749    va_end(ap);
750    return CELT_BAD_ARG;
751 bad_request:
752    va_end(ap);
753    return CELT_UNIMPLEMENTED;
754 }
755
756 /****************************************************************************/
757 /*                                                                          */
758 /*                                DECODER                                   */
759 /*                                                                          */
760 /****************************************************************************/
761
762
763 /** Decoder state 
764  @brief Decoder state
765  */
766 struct CELTDecoder {
767    const CELTMode *mode;
768    int frame_size;
769    int block_size;
770    int overlap;
771
772    ec_byte_buffer buf;
773    ec_enc         enc;
774
775    celt_sig_t * restrict preemph_memD;
776
777    celt_sig_t *out_mem;
778
779    celt_word16_t *oldBandE;
780    
781    int last_pitch_index;
782 };
783
784 CELTDecoder *celt_decoder_create(const CELTMode *mode)
785 {
786    int N, C;
787    CELTDecoder *st;
788
789    if (check_mode(mode) != CELT_OK)
790       return NULL;
791
792    N = mode->mdctSize;
793    C = CHANNELS(mode);
794    st = celt_alloc(sizeof(CELTDecoder));
795    
796    st->mode = mode;
797    st->frame_size = N;
798    st->block_size = N;
799    st->overlap = mode->overlap;
800
801    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
802    
803    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
804
805    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
806
807    st->last_pitch_index = 0;
808    return st;
809 }
810
811 void celt_decoder_destroy(CELTDecoder *st)
812 {
813    if (st == NULL)
814    {
815       celt_warning("NULL passed to celt_encoder_destroy");
816       return;
817    }
818    if (check_mode(st->mode) != CELT_OK)
819       return;
820
821
822    celt_free(st->out_mem);
823    
824    celt_free(st->oldBandE);
825    
826    celt_free(st->preemph_memD);
827
828    celt_free(st);
829 }
830
831 /** Handles lost packets by just copying past data with the same offset as the last
832     pitch period */
833 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
834 {
835    int c, N;
836    int pitch_index;
837    int i, len;
838    VARDECL(celt_sig_t, freq);
839    const int C = CHANNELS(st->mode);
840    int offset;
841    SAVE_STACK;
842    N = st->block_size;
843    ALLOC(freq,C*N, celt_sig_t);         /**< Interleaved signal MDCTs */
844    
845    len = N+st->mode->overlap;
846 #if 0
847    pitch_index = st->last_pitch_index;
848    
849    /* Use the pitch MDCT as the "guessed" signal */
850    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
851
852 #else
853    find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
854    pitch_index = MAX_PERIOD-len-pitch_index;
855    offset = MAX_PERIOD-pitch_index;
856    while (offset+len >= MAX_PERIOD)
857       offset -= pitch_index;
858    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
859    for (i=0;i<N;i++)
860       freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
861 #endif
862    
863    
864    
865    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N));
866    /* Compute inverse MDCTs */
867    compute_inv_mdcts(st->mode, 0, freq, -1, 1, st->out_mem);
868
869    for (c=0;c<C;c++)
870    {
871       int j;
872       for (j=0;j<N;j++)
873       {
874          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
875                                 preemph,st->preemph_memD[c]);
876          st->preemph_memD[c] = tmp;
877          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
878       }
879    }
880    RESTORE_STACK;
881 }
882
883 #ifdef FIXED_POINT
884 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
885 {
886 #else
887 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
888 {
889 #endif
890    int i, c, N, N4;
891    int has_pitch, has_fold;
892    int pitch_index;
893    int bits;
894    ec_dec dec;
895    ec_byte_buffer buf;
896    VARDECL(celt_sig_t, freq);
897    VARDECL(celt_norm_t, X);
898    VARDECL(celt_norm_t, P);
899    VARDECL(celt_ener_t, bandE);
900    VARDECL(celt_pgain_t, gains);
901    VARDECL(int, stereo_mode);
902    VARDECL(int, fine_quant);
903    VARDECL(int, pulses);
904    VARDECL(int, offsets);
905
906    int shortBlocks;
907    int transient_time;
908    int transient_shift;
909    const int C = CHANNELS(st->mode);
910    SAVE_STACK;
911
912    if (check_mode(st->mode) != CELT_OK)
913       return CELT_INVALID_MODE;
914
915    N = st->block_size;
916    N4 = (N-st->overlap)>>1;
917
918    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
919    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
920    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
921    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
922    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
923    
924    if (check_mode(st->mode) != CELT_OK)
925    {
926       RESTORE_STACK;
927       return CELT_INVALID_MODE;
928    }
929    if (data == NULL)
930    {
931       celt_decode_lost(st, pcm);
932       RESTORE_STACK;
933       return 0;
934    }
935    
936    ec_byte_readinit(&buf,data,len);
937    ec_dec_init(&dec,&buf);
938    
939    has_pitch = ec_dec_bits(&dec, 1);
940    if (has_pitch)
941    {
942       has_fold = ec_dec_bits(&dec, 1);
943       shortBlocks = 0;
944    } else if (st->mode->nbShortMdcts > 1){
945       shortBlocks = ec_dec_bits(&dec, 1);
946       has_fold = 1;
947    } else {
948       shortBlocks = 0;
949       has_fold = 1;
950    }
951    if (shortBlocks)
952    {
953       transient_shift = ec_dec_bits(&dec, 2);
954       if (transient_shift)
955          transient_time = ec_dec_uint(&dec, N+st->mode->overlap);
956       else
957          transient_time = 0;
958    } else {
959       transient_time = -1;
960       transient_shift = 0;
961    }
962    /* Get the pitch gains */
963    
964    /* Get the pitch index */
965    if (has_pitch)
966    {
967       has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
968       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
969       st->last_pitch_index = pitch_index;
970    } else {
971       /* FIXME: We could be more intelligent here and just not compute the MDCT */
972       pitch_index = 0;
973       for (i=0;i<st->mode->nbPBands;i++)
974          gains[i] = 0;
975    }
976
977    ALLOC(fine_quant, st->mode->nbEBands, int);
978    /* Get band energies */
979    unquant_coarse_energy(st->mode, bandE, st->oldBandE, len*8/3, st->mode->prob, &dec);
980    
981    ALLOC(pulses, st->mode->nbEBands, int);
982    ALLOC(offsets, st->mode->nbEBands, int);
983    ALLOC(stereo_mode, st->mode->nbEBands, int);
984    stereo_decision(st->mode, X, stereo_mode, st->mode->nbEBands);
985
986    for (i=0;i<st->mode->nbEBands;i++)
987       offsets[i] = 0;
988
989    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
990    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
991    /*bits = ec_dec_tell(&dec, 0);
992    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
993    
994    unquant_fine_energy(st->mode, bandE, st->oldBandE, fine_quant, &dec);
995
996
997    if (has_pitch) 
998    {
999       VARDECL(celt_ener_t, bandEp);
1000       
1001       /* Pitch MDCT */
1002       compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, freq);
1003       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
1004       compute_band_energies(st->mode, freq, bandEp);
1005       normalise_bands(st->mode, freq, P, bandEp);
1006    } else {
1007       for (i=0;i<C*N;i++)
1008          P[i] = 0;
1009    }
1010
1011    /* Apply pitch gains */
1012    pitch_quant_bands(st->mode, P, gains);
1013
1014    /* Decode fixed codebook and merge with pitch */
1015    unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, len*8, &dec);
1016
1017    if (C==2)
1018    {
1019       renormalise_bands(st->mode, X);
1020    }
1021    /* Synthesis */
1022    denormalise_bands(st->mode, X, freq, bandE);
1023
1024
1025    CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
1026    /* Compute inverse MDCTs */
1027    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
1028
1029    for (c=0;c<C;c++)
1030    {
1031       int j;
1032       for (j=0;j<N;j++)
1033       {
1034          celt_sig_t tmp = MAC16_32_Q15(st->out_mem[C*(MAX_PERIOD-N)+C*j+c],
1035                                 preemph,st->preemph_memD[c]);
1036          st->preemph_memD[c] = tmp;
1037          pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
1038       }
1039    }
1040
1041    {
1042       unsigned int val = 0;
1043       while (ec_dec_tell(&dec, 0) < len*8)
1044       {
1045          if (ec_dec_uint(&dec, 2) != val)
1046          {
1047             celt_warning("decode error");
1048             RESTORE_STACK;
1049             return CELT_CORRUPTED_DATA;
1050          }
1051          val = 1-val;
1052       }
1053    }
1054
1055    RESTORE_STACK;
1056    return 0;
1057    /*printf ("\n");*/
1058 }
1059
1060 #ifdef FIXED_POINT
1061 #ifndef DISABLE_FLOAT_API
1062 int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
1063 {
1064    int j, ret;
1065    const int C = CHANNELS(st->mode);
1066    const int N = st->block_size;
1067    VARDECL(celt_int16_t, out);
1068    SAVE_STACK;
1069    ALLOC(out, C*N, celt_int16_t);
1070
1071    ret=celt_decode(st, data, len, out);
1072
1073    for (j=0;j<C*N;j++)
1074      pcm[j]=out[j]*(1/32768.);
1075    RESTORE_STACK;
1076    return ret;
1077 }
1078 #endif /*DISABLE_FLOAT_API*/
1079 #else
1080 int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
1081 {
1082    int j, ret;
1083    VARDECL(celt_sig_t, out);
1084    const int C = CHANNELS(st->mode);
1085    const int N = st->block_size;
1086    SAVE_STACK;
1087    ALLOC(out, C*N, celt_sig_t);
1088
1089    ret=celt_decode_float(st, data, len, out);
1090
1091    for (j=0;j<C*N;j++)
1092      pcm[j] = FLOAT2INT16 (out[j]);
1093
1094    RESTORE_STACK;
1095    return ret;
1096 }
1097 #endif