bit of cleaning up, default sampling rate
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    void *fft;
67    
68    float *window;
69    float *in_mem;
70    float *mdct_overlap;
71    float *out_mem;
72
73    float *oldBandE;
74    
75    struct alloc_data alloc;
76 };
77
78
79
80 CELTEncoder *celt_encoder_new(const CELTMode *mode)
81 {
82    int i, N, B, C, N4;
83    N = mode->mdctSize;
84    B = mode->nbMdctBlocks;
85    C = mode->nbChannels;
86    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
87    
88    st->mode = mode;
89    st->frame_size = B*N;
90    st->block_size = N;
91    st->nb_blocks  = B;
92    st->overlap = mode->overlap;
93    st->Fs = 44100;
94
95    N4 = (N-st->overlap)/2;
96    ec_byte_writeinit(&st->buf);
97    ec_enc_init(&st->enc,&st->buf);
98
99    mdct_init(&st->mdct_lookup, 2*N);
100    st->fft = spx_fft_init(MAX_PERIOD*C);
101    
102    st->window = celt_alloc(2*N*sizeof(float));
103    st->in_mem = celt_alloc(N*C*sizeof(float));
104    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
105    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
106    for (i=0;i<2*N;i++)
107       st->window[i] = 0;
108    for (i=0;i<st->overlap;i++)
109       st->window[N4+i] = st->window[2*N-N4-i-1] 
110             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
111    for (i=0;i<2*N4;i++)
112       st->window[N-N4+i] = 1;
113    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
114
115    st->preemph = 0.8;
116    st->preemph_memE = celt_alloc(C*sizeof(float));;
117    st->preemph_memD = celt_alloc(C*sizeof(float));;
118
119    alloc_init(&st->alloc, st->mode);
120    return st;
121 }
122
123 void celt_encoder_destroy(CELTEncoder *st)
124 {
125    if (st == NULL)
126    {
127       celt_warning("NULL passed to celt_encoder_destroy");
128       return;
129    }
130    ec_byte_writeclear(&st->buf);
131
132    mdct_clear(&st->mdct_lookup);
133    spx_fft_destroy(st->fft);
134
135    celt_free(st->window);
136    celt_free(st->in_mem);
137    celt_free(st->mdct_overlap);
138    celt_free(st->out_mem);
139    
140    celt_free(st->oldBandE);
141    alloc_clear(&st->alloc);
142
143    celt_free(st);
144 }
145
146
147 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
148 {
149    int i, c;
150    for (c=0;c<C;c++)
151    {
152       for (i=0;i<B;i++)
153       {
154          int j;
155          float x[2*N];
156          float tmp[N];
157          for (j=0;j<2*N;j++)
158             x[j] = window[j]*in[C*i*N+C*j+c];
159          mdct_forward(mdct_lookup, x, tmp);
160          /* Interleaving the sub-frames */
161          for (j=0;j<N;j++)
162             out[C*B*j+C*i+c] = tmp[j];
163       }
164    }
165 }
166
167 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
168 {
169    int i, c, N4;
170    N4 = (N-overlap)/2;
171    for (c=0;c<C;c++)
172    {
173       for (i=0;i<B;i++)
174       {
175          int j;
176          float x[2*N];
177          float tmp[N];
178          /* De-interleaving the sub-frames */
179          for (j=0;j<N;j++)
180             tmp[j] = X[C*B*j+C*i+c];
181          mdct_backward(mdct_lookup, tmp, x);
182          for (j=0;j<2*N;j++)
183             x[j] = window[j]*x[j];
184          for (j=0;j<overlap;j++)
185             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
186          for (j=0;j<2*N4;j++)
187             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
188          for (j=0;j<overlap;j++)
189             mdct_overlap[C*j+c] = x[N+N4+j];
190       }
191    }
192 }
193
194 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
195 {
196    int i, c, N, B, C, N4;
197    N = st->block_size;
198    B = st->nb_blocks;
199    C = st->mode->nbChannels;
200    float in[(B+1)*C*N];
201
202    float X[B*C*N];         /**< Interleaved signal MDCTs */
203    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
204    float mask[B*C*N];      /**< Masking curve */
205    float bandE[st->mode->nbEBands*C];
206    float gains[st->mode->nbPBands];
207    int pitch_index;
208
209    N4 = (N-st->overlap)/2;
210
211    for (c=0;c<C;c++)
212    {
213       for (i=0;i<N4;i++)
214          in[C*i+c] = 0;
215       for (i=0;i<st->overlap;i++)
216          in[C*(i+N4)+c] = st->in_mem[C*i+c];
217       for (i=0;i<B*N;i++)
218       {
219          float tmp = pcm[C*i+c];
220          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
221          st->preemph_memE[c] = tmp;
222       }
223       for (i=N*(B+1)-N4;i<N*(B+1);i++)
224          in[C*i+c] = 0;
225       for (i=0;i<st->overlap;i++)
226          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
227    }
228    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
229    /* Compute MDCTs */
230    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
231
232    compute_mdct_masking(X, mask, B*C*N, st->Fs);
233
234    /* Invert and stretch the mask to length of X 
235       For some reason, I get better results by using the sqrt instead,
236       although there's no valid reason to. Must investigate further */
237    for (i=0;i<B*C*N;i++)
238       mask[i] = 1/(.1+mask[i]);
239
240    /* Pitch analysis */
241    for (c=0;c<C;c++)
242    {
243       for (i=0;i<N;i++)
244       {
245          in[C*i+c] *= st->window[i];
246          in[C*(B*N+i)+c] *= st->window[N+i];
247       }
248    }
249    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
250    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
251    
252    /* Compute MDCTs of the pitch part */
253    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
254    
255    /*int j;
256    for (j=0;j<B*N;j++)
257       printf ("%f ", X[j]);
258    for (j=0;j<B*N;j++)
259       printf ("%f ", P[j]);
260    printf ("\n");*/
261
262    /* Band normalisation */
263    compute_band_energies(st->mode, X, bandE);
264    normalise_bands(st->mode, X, bandE);
265    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
266    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
267
268    /* Normalise the pitch vector as well (discard the energies) */
269    {
270       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
271       compute_band_energies(st->mode, P, bandEp);
272       normalise_bands(st->mode, P, bandEp);
273    }
274
275    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
276
277    if (C==2)
278    {
279       stereo_mix(st->mode, X, bandE, 1);
280       stereo_mix(st->mode, P, bandE, 1);
281    }
282    /* Simulates intensity stereo */
283    //for (i=30;i<N*B;i++)
284    //   X[i*C+1] = P[i*C+1] = 0;
285    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
286
287    /* Pitch prediction */
288    compute_pitch_gain(st->mode, X, P, gains, bandE);
289    quant_pitch(gains, st->mode->nbPBands, &st->enc);
290    pitch_quant_bands(st->mode, X, P, gains);
291
292    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
293    /* Compute residual that we're going to encode */
294    for (i=0;i<B*C*N;i++)
295       X[i] -= P[i];
296
297    /*float sum=0;
298    for (i=0;i<B*N;i++)
299       sum += X[i]*X[i];
300    printf ("%f\n", sum);*/
301    /* Residual quantisation */
302    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
303    
304    if (C==2)
305       stereo_mix(st->mode, X, bandE, -1);
306
307    renormalise_bands(st->mode, X);
308    /* Synthesis */
309    denormalise_bands(st->mode, X, bandE);
310
311
312    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
313
314    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
315    /* De-emphasis and put everything back at the right place in the synthesis history */
316    for (c=0;c<C;c++)
317    {
318       for (i=0;i<B;i++)
319       {
320          int j;
321          for (j=0;j<N;j++)
322          {
323             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
324             st->preemph_memD[c] = tmp;
325             if (tmp > 32767) tmp = 32767;
326             if (tmp < -32767) tmp = -32767;
327             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
328          }
329       }
330    }
331    
332    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
333    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
334    {
335       int val = 0;
336       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
337       {
338          ec_enc_uint(&st->enc, val, 2);
339          val = 1-val;
340       }
341    }
342    ec_enc_done(&st->enc);
343    {
344       unsigned char *data;
345       int nbBytes = ec_byte_bytes(&st->buf);
346       if (nbBytes != nbCompressedBytes)
347       {
348          if (nbBytes > nbCompressedBytes)
349             celt_warning("got too many bytes");
350          else
351             celt_warning("not enough bytes");
352          return CELT_INTERNAL_ERROR;
353       }
354       //printf ("%d\n", *nbBytes);
355       data = ec_byte_get_buffer(&st->buf);
356       for (i=0;i<nbBytes;i++)
357          compressed[i] = data[i];
358    }
359    /* Reset the packing for the next encoding */
360    ec_byte_reset(&st->buf);
361    ec_enc_init(&st->enc,&st->buf);
362
363    return nbCompressedBytes;
364 }
365
366
367 /****************************************************************************/
368 /*                                                                          */
369 /*                                DECODER                                   */
370 /*                                                                          */
371 /****************************************************************************/
372
373
374
375 struct CELTDecoder {
376    const CELTMode *mode;
377    int frame_size;
378    int block_size;
379    int nb_blocks;
380    int overlap;
381
382    ec_byte_buffer buf;
383    ec_enc         enc;
384
385    float preemph;
386    float *preemph_memD;
387    
388    mdct_lookup mdct_lookup;
389    
390    float *window;
391    float *mdct_overlap;
392    float *out_mem;
393
394    float *oldBandE;
395    
396    int last_pitch_index;
397    
398    struct alloc_data alloc;
399 };
400
401 CELTDecoder *celt_decoder_new(const CELTMode *mode)
402 {
403    int i, N, B, C, N4;
404    N = mode->mdctSize;
405    B = mode->nbMdctBlocks;
406    C = mode->nbChannels;
407    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
408    
409    st->mode = mode;
410    st->frame_size = B*N;
411    st->block_size = N;
412    st->nb_blocks  = B;
413    st->overlap = mode->overlap;
414
415    N4 = (N-st->overlap)/2;
416    
417    mdct_init(&st->mdct_lookup, 2*N);
418    
419    st->window = celt_alloc(2*N*sizeof(float));
420    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
421    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
422
423    for (i=0;i<2*N;i++)
424       st->window[i] = 0;
425    for (i=0;i<st->overlap;i++)
426       st->window[N4+i] = st->window[2*N-N4-i-1] 
427             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
428    for (i=0;i<2*N4;i++)
429       st->window[N-N4+i] = 1;
430    
431    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
432
433    st->preemph = 0.8;
434    st->preemph_memD = celt_alloc(C*sizeof(float));;
435
436    st->last_pitch_index = 0;
437    alloc_init(&st->alloc, st->mode);
438
439    return st;
440 }
441
442 void celt_decoder_destroy(CELTDecoder *st)
443 {
444    if (st == NULL)
445    {
446       celt_warning("NULL passed to celt_encoder_destroy");
447       return;
448    }
449
450    mdct_clear(&st->mdct_lookup);
451
452    celt_free(st->window);
453    celt_free(st->mdct_overlap);
454    celt_free(st->out_mem);
455    
456    celt_free(st->oldBandE);
457    alloc_clear(&st->alloc);
458
459    celt_free(st);
460 }
461
462 static void celt_decode_lost(CELTDecoder *st, short *pcm)
463 {
464    int i, c, N, B, C;
465    N = st->block_size;
466    B = st->nb_blocks;
467    C = st->mode->nbChannels;
468    float X[C*B*N];         /**< Interleaved signal MDCTs */
469    int pitch_index;
470    
471    pitch_index = st->last_pitch_index;
472    
473    /* Use the pitch MDCT as the "guessed" signal */
474    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
475
476    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
477    /* Compute inverse MDCTs */
478    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
479
480    for (c=0;c<C;c++)
481    {
482       for (i=0;i<B;i++)
483       {
484          int j;
485          for (j=0;j<N;j++)
486          {
487             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
488             st->preemph_memD[c] = tmp;
489             if (tmp > 32767) tmp = 32767;
490             if (tmp < -32767) tmp = -32767;
491             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
492          }
493       }
494    }
495 }
496
497 int celt_decode(CELTDecoder *st, char *data, int len, celt_int16_t *pcm)
498 {
499    int i, c, N, B, C;
500    N = st->block_size;
501    B = st->nb_blocks;
502    C = st->mode->nbChannels;
503    
504    float X[C*B*N];         /**< Interleaved signal MDCTs */
505    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
506    float bandE[st->mode->nbEBands*C];
507    float gains[st->mode->nbPBands];
508    int pitch_index;
509    ec_dec dec;
510    ec_byte_buffer buf;
511    
512    if (data == NULL)
513    {
514       celt_decode_lost(st, pcm);
515       return 0;
516    }
517    
518    ec_byte_readinit(&buf,data,len);
519    ec_dec_init(&dec,&buf);
520    
521    /* Get the pitch index */
522    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
523    st->last_pitch_index = pitch_index;
524    
525    /* Get band energies */
526    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
527    
528    /* Pitch MDCT */
529    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
530
531    {
532       float bandEp[st->mode->nbEBands];
533       compute_band_energies(st->mode, P, bandEp);
534       normalise_bands(st->mode, P, bandEp);
535    }
536
537    if (C==2)
538       stereo_mix(st->mode, P, bandE, 1);
539
540    /* Get the pitch gains */
541    unquant_pitch(gains, st->mode->nbPBands, &dec);
542
543    /* Apply pitch gains */
544    pitch_quant_bands(st->mode, X, P, gains);
545
546    /* Decode fixed codebook and merge with pitch */
547    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
548
549    if (C==2)
550       stereo_mix(st->mode, X, bandE, -1);
551
552    renormalise_bands(st->mode, X);
553    
554    /* Synthesis */
555    denormalise_bands(st->mode, X, bandE);
556
557
558    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
559    /* Compute inverse MDCTs */
560    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
561
562    for (c=0;c<C;c++)
563    {
564       for (i=0;i<B;i++)
565       {
566          int j;
567          for (j=0;j<N;j++)
568          {
569             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
570             st->preemph_memD[c] = tmp;
571             if (tmp > 32767) tmp = 32767;
572             if (tmp < -32767) tmp = -32767;
573             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
574          }
575       }
576    }
577
578    {
579       int val = 0;
580       while (ec_dec_tell(&dec, 0) < len*8)
581       {
582          if (ec_dec_uint(&dec, 2) != val)
583          {
584             celt_warning("decode error");
585             return CELT_CORRUPTED_DATA;
586          }
587          val = 1-val;
588       }
589    }
590
591    return 0;
592    //printf ("\n");
593 }
594