fixed leaked ritrev table
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "kiss_fftr.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    kiss_fftr_cfg fft;
67    
68    float *window;
69    float *in_mem;
70    float *mdct_overlap;
71    float *out_mem;
72
73    float *oldBandE;
74    
75    struct alloc_data alloc;
76 };
77
78
79
80 CELTEncoder *celt_encoder_new(const CELTMode *mode)
81 {
82    int i, N, B, C, N4;
83    N = mode->mdctSize;
84    B = mode->nbMdctBlocks;
85    C = mode->nbChannels;
86    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
87    
88    st->mode = mode;
89    st->frame_size = B*N;
90    st->block_size = N;
91    st->nb_blocks  = B;
92    st->overlap = mode->overlap;
93    st->Fs = 44100;
94
95    N4 = (N-st->overlap)/2;
96    ec_byte_writeinit(&st->buf);
97    ec_enc_init(&st->enc,&st->buf);
98
99    mdct_init(&st->mdct_lookup, 2*N);
100    st->fft = kiss_fftr_alloc(MAX_PERIOD*C, 0, 0);
101    
102    st->window = celt_alloc(2*N*sizeof(float));
103    st->in_mem = celt_alloc(N*C*sizeof(float));
104    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
105    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
106    for (i=0;i<2*N;i++)
107       st->window[i] = 0;
108    for (i=0;i<st->overlap;i++)
109       st->window[N4+i] = st->window[2*N-N4-i-1] 
110             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
111    for (i=0;i<2*N4;i++)
112       st->window[N-N4+i] = 1;
113    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
114
115    st->preemph = 0.8;
116    st->preemph_memE = celt_alloc(C*sizeof(float));;
117    st->preemph_memD = celt_alloc(C*sizeof(float));;
118
119    alloc_init(&st->alloc, st->mode);
120    return st;
121 }
122
123 void celt_encoder_destroy(CELTEncoder *st)
124 {
125    if (st == NULL)
126    {
127       celt_warning("NULL passed to celt_encoder_destroy");
128       return;
129    }
130    ec_byte_writeclear(&st->buf);
131
132    mdct_clear(&st->mdct_lookup);
133    kiss_fft_free(st->fft);
134
135    celt_free(st->window);
136    celt_free(st->in_mem);
137    celt_free(st->mdct_overlap);
138    celt_free(st->out_mem);
139    
140    celt_free(st->oldBandE);
141    
142    celt_free(st->preemph_memE);
143    celt_free(st->preemph_memD);
144    
145    alloc_clear(&st->alloc);
146
147    celt_free(st);
148 }
149
150
151 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
152 {
153    int i, c;
154    for (c=0;c<C;c++)
155    {
156       for (i=0;i<B;i++)
157       {
158          int j;
159          float x[2*N];
160          float tmp[N];
161          for (j=0;j<2*N;j++)
162             x[j] = window[j]*in[C*i*N+C*j+c];
163          mdct_forward(mdct_lookup, x, tmp);
164          /* Interleaving the sub-frames */
165          for (j=0;j<N;j++)
166             out[C*B*j+C*i+c] = tmp[j];
167       }
168    }
169 }
170
171 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
172 {
173    int i, c, N4;
174    N4 = (N-overlap)/2;
175    for (c=0;c<C;c++)
176    {
177       for (i=0;i<B;i++)
178       {
179          int j;
180          float x[2*N];
181          float tmp[N];
182          /* De-interleaving the sub-frames */
183          for (j=0;j<N;j++)
184             tmp[j] = X[C*B*j+C*i+c];
185          mdct_backward(mdct_lookup, tmp, x);
186          for (j=0;j<2*N;j++)
187             x[j] = window[j]*x[j];
188          for (j=0;j<overlap;j++)
189             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
190          for (j=0;j<2*N4;j++)
191             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
192          for (j=0;j<overlap;j++)
193             mdct_overlap[C*j+c] = x[N+N4+j];
194       }
195    }
196 }
197
198 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
199 {
200    int i, c, N, B, C, N4;
201    N = st->block_size;
202    B = st->nb_blocks;
203    C = st->mode->nbChannels;
204    float in[(B+1)*C*N];
205
206    float X[B*C*N];         /**< Interleaved signal MDCTs */
207    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
208    float mask[B*C*N];      /**< Masking curve */
209    float bandE[st->mode->nbEBands*C];
210    float gains[st->mode->nbPBands];
211    int pitch_index;
212
213    N4 = (N-st->overlap)/2;
214
215    for (c=0;c<C;c++)
216    {
217       for (i=0;i<N4;i++)
218          in[C*i+c] = 0;
219       for (i=0;i<st->overlap;i++)
220          in[C*(i+N4)+c] = st->in_mem[C*i+c];
221       for (i=0;i<B*N;i++)
222       {
223          float tmp = pcm[C*i+c];
224          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
225          st->preemph_memE[c] = tmp;
226       }
227       for (i=N*(B+1)-N4;i<N*(B+1);i++)
228          in[C*i+c] = 0;
229       for (i=0;i<st->overlap;i++)
230          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
231    }
232    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
233    /* Compute MDCTs */
234    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
235
236    compute_mdct_masking(X, mask, B*C*N, st->Fs);
237
238    /* Invert and stretch the mask to length of X 
239       For some reason, I get better results by using the sqrt instead,
240       although there's no valid reason to. Must investigate further */
241    for (i=0;i<B*C*N;i++)
242       mask[i] = 1/(.1+mask[i]);
243
244    /* Pitch analysis */
245    for (c=0;c<C;c++)
246    {
247       for (i=0;i<N;i++)
248       {
249          in[C*i+c] *= st->window[i];
250          in[C*(B*N+i)+c] *= st->window[N+i];
251       }
252    }
253    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
254    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
255    
256    /* Compute MDCTs of the pitch part */
257    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
258    
259    /*int j;
260    for (j=0;j<B*N;j++)
261       printf ("%f ", X[j]);
262    for (j=0;j<B*N;j++)
263       printf ("%f ", P[j]);
264    printf ("\n");*/
265
266    /* Band normalisation */
267    compute_band_energies(st->mode, X, bandE);
268    normalise_bands(st->mode, X, bandE);
269    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
270    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
271
272    /* Normalise the pitch vector as well (discard the energies) */
273    {
274       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
275       compute_band_energies(st->mode, P, bandEp);
276       normalise_bands(st->mode, P, bandEp);
277    }
278
279    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
280
281    if (C==2)
282    {
283       stereo_mix(st->mode, X, bandE, 1);
284       stereo_mix(st->mode, P, bandE, 1);
285    }
286    /* Simulates intensity stereo */
287    //for (i=30;i<N*B;i++)
288    //   X[i*C+1] = P[i*C+1] = 0;
289    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
290
291    /* Pitch prediction */
292    compute_pitch_gain(st->mode, X, P, gains, bandE);
293    quant_pitch(gains, st->mode->nbPBands, &st->enc);
294    pitch_quant_bands(st->mode, X, P, gains);
295
296    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
297    /* Compute residual that we're going to encode */
298    for (i=0;i<B*C*N;i++)
299       X[i] -= P[i];
300
301    /*float sum=0;
302    for (i=0;i<B*N;i++)
303       sum += X[i]*X[i];
304    printf ("%f\n", sum);*/
305    /* Residual quantisation */
306    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
307    
308    if (C==2)
309       stereo_mix(st->mode, X, bandE, -1);
310
311    renormalise_bands(st->mode, X);
312    /* Synthesis */
313    denormalise_bands(st->mode, X, bandE);
314
315
316    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
317
318    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
319    /* De-emphasis and put everything back at the right place in the synthesis history */
320    for (c=0;c<C;c++)
321    {
322       for (i=0;i<B;i++)
323       {
324          int j;
325          for (j=0;j<N;j++)
326          {
327             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
328             st->preemph_memD[c] = tmp;
329             if (tmp > 32767) tmp = 32767;
330             if (tmp < -32767) tmp = -32767;
331             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
332          }
333       }
334    }
335    
336    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
337    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
338    {
339       int val = 0;
340       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
341       {
342          ec_enc_uint(&st->enc, val, 2);
343          val = 1-val;
344       }
345    }
346    ec_enc_done(&st->enc);
347    {
348       unsigned char *data;
349       int nbBytes = ec_byte_bytes(&st->buf);
350       if (nbBytes > nbCompressedBytes)
351       {
352          celt_warning_int ("got too many bytes:", nbBytes);
353          return CELT_INTERNAL_ERROR;
354       }
355       //printf ("%d\n", *nbBytes);
356       data = ec_byte_get_buffer(&st->buf);
357       for (i=0;i<nbBytes;i++)
358          compressed[i] = data[i];
359       for (;i<nbCompressedBytes;i++)
360          compressed[i] = 0;
361    }
362    /* Reset the packing for the next encoding */
363    ec_byte_reset(&st->buf);
364    ec_enc_init(&st->enc,&st->buf);
365
366    return nbCompressedBytes;
367 }
368
369
370 /****************************************************************************/
371 /*                                                                          */
372 /*                                DECODER                                   */
373 /*                                                                          */
374 /****************************************************************************/
375
376
377
378 struct CELTDecoder {
379    const CELTMode *mode;
380    int frame_size;
381    int block_size;
382    int nb_blocks;
383    int overlap;
384
385    ec_byte_buffer buf;
386    ec_enc         enc;
387
388    float preemph;
389    float *preemph_memD;
390    
391    mdct_lookup mdct_lookup;
392    
393    float *window;
394    float *mdct_overlap;
395    float *out_mem;
396
397    float *oldBandE;
398    
399    int last_pitch_index;
400    
401    struct alloc_data alloc;
402 };
403
404 CELTDecoder *celt_decoder_new(const CELTMode *mode)
405 {
406    int i, N, B, C, N4;
407    N = mode->mdctSize;
408    B = mode->nbMdctBlocks;
409    C = mode->nbChannels;
410    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
411    
412    st->mode = mode;
413    st->frame_size = B*N;
414    st->block_size = N;
415    st->nb_blocks  = B;
416    st->overlap = mode->overlap;
417
418    N4 = (N-st->overlap)/2;
419    
420    mdct_init(&st->mdct_lookup, 2*N);
421    
422    st->window = celt_alloc(2*N*sizeof(float));
423    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
424    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
425
426    for (i=0;i<2*N;i++)
427       st->window[i] = 0;
428    for (i=0;i<st->overlap;i++)
429       st->window[N4+i] = st->window[2*N-N4-i-1] 
430             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
431    for (i=0;i<2*N4;i++)
432       st->window[N-N4+i] = 1;
433    
434    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
435
436    st->preemph = 0.8;
437    st->preemph_memD = celt_alloc(C*sizeof(float));;
438
439    st->last_pitch_index = 0;
440    alloc_init(&st->alloc, st->mode);
441
442    return st;
443 }
444
445 void celt_decoder_destroy(CELTDecoder *st)
446 {
447    if (st == NULL)
448    {
449       celt_warning("NULL passed to celt_encoder_destroy");
450       return;
451    }
452
453    mdct_clear(&st->mdct_lookup);
454
455    celt_free(st->window);
456    celt_free(st->mdct_overlap);
457    celt_free(st->out_mem);
458    
459    celt_free(st->oldBandE);
460    
461    celt_free(st->preemph_memD);
462
463    alloc_clear(&st->alloc);
464
465    celt_free(st);
466 }
467
468 static void celt_decode_lost(CELTDecoder *st, short *pcm)
469 {
470    int i, c, N, B, C;
471    N = st->block_size;
472    B = st->nb_blocks;
473    C = st->mode->nbChannels;
474    float X[C*B*N];         /**< Interleaved signal MDCTs */
475    int pitch_index;
476    
477    pitch_index = st->last_pitch_index;
478    
479    /* Use the pitch MDCT as the "guessed" signal */
480    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
481
482    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
483    /* Compute inverse MDCTs */
484    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
485
486    for (c=0;c<C;c++)
487    {
488       for (i=0;i<B;i++)
489       {
490          int j;
491          for (j=0;j<N;j++)
492          {
493             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
494             st->preemph_memD[c] = tmp;
495             if (tmp > 32767) tmp = 32767;
496             if (tmp < -32767) tmp = -32767;
497             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
498          }
499       }
500    }
501 }
502
503 int celt_decode(CELTDecoder *st, char *data, int len, celt_int16_t *pcm)
504 {
505    int i, c, N, B, C;
506    N = st->block_size;
507    B = st->nb_blocks;
508    C = st->mode->nbChannels;
509    
510    float X[C*B*N];         /**< Interleaved signal MDCTs */
511    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
512    float bandE[st->mode->nbEBands*C];
513    float gains[st->mode->nbPBands];
514    int pitch_index;
515    ec_dec dec;
516    ec_byte_buffer buf;
517    
518    if (data == NULL)
519    {
520       celt_decode_lost(st, pcm);
521       return 0;
522    }
523    
524    ec_byte_readinit(&buf,data,len);
525    ec_dec_init(&dec,&buf);
526    
527    /* Get the pitch index */
528    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
529    st->last_pitch_index = pitch_index;
530    
531    /* Get band energies */
532    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
533    
534    /* Pitch MDCT */
535    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
536
537    {
538       float bandEp[st->mode->nbEBands];
539       compute_band_energies(st->mode, P, bandEp);
540       normalise_bands(st->mode, P, bandEp);
541    }
542
543    if (C==2)
544       stereo_mix(st->mode, P, bandE, 1);
545
546    /* Get the pitch gains */
547    unquant_pitch(gains, st->mode->nbPBands, &dec);
548
549    /* Apply pitch gains */
550    pitch_quant_bands(st->mode, X, P, gains);
551
552    /* Decode fixed codebook and merge with pitch */
553    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
554
555    if (C==2)
556       stereo_mix(st->mode, X, bandE, -1);
557
558    renormalise_bands(st->mode, X);
559    
560    /* Synthesis */
561    denormalise_bands(st->mode, X, bandE);
562
563
564    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
565    /* Compute inverse MDCTs */
566    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
567
568    for (c=0;c<C;c++)
569    {
570       for (i=0;i<B;i++)
571       {
572          int j;
573          for (j=0;j<N;j++)
574          {
575             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
576             st->preemph_memD[c] = tmp;
577             if (tmp > 32767) tmp = 32767;
578             if (tmp < -32767) tmp = -32767;
579             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
580          }
581       }
582    }
583
584    {
585       int val = 0;
586       while (ec_dec_tell(&dec, 0) < len*8)
587       {
588          if (ec_dec_uint(&dec, 2) != val)
589          {
590             celt_warning("decode error");
591             return CELT_CORRUPTED_DATA;
592          }
593          val = 1-val;
594       }
595    }
596
597    return 0;
598    //printf ("\n");
599 }
600