Decays corresponding to the psychoacoustic spreading function are now
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "kiss_fftr.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    kiss_fftr_cfg fft;
67    struct PsyDecay psy;
68    
69    float *window;
70    float *in_mem;
71    float *mdct_overlap;
72    float *out_mem;
73
74    float *oldBandE;
75    
76    struct alloc_data alloc;
77 };
78
79
80
81 CELTEncoder *celt_encoder_new(const CELTMode *mode)
82 {
83    int i, N, B, C, N4;
84    N = mode->mdctSize;
85    B = mode->nbMdctBlocks;
86    C = mode->nbChannels;
87    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
88    
89    st->mode = mode;
90    st->frame_size = B*N;
91    st->block_size = N;
92    st->nb_blocks  = B;
93    st->overlap = mode->overlap;
94    st->Fs = 44100;
95
96    N4 = (N-st->overlap)/2;
97    ec_byte_writeinit(&st->buf);
98    ec_enc_init(&st->enc,&st->buf);
99
100    mdct_init(&st->mdct_lookup, 2*N);
101    st->fft = kiss_fftr_alloc(MAX_PERIOD*C, 0, 0);
102    psydecay_init(&st->psy, MAX_PERIOD/2, st->Fs);
103    
104    st->window = celt_alloc(2*N*sizeof(float));
105    st->in_mem = celt_alloc(N*C*sizeof(float));
106    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
107    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
108    for (i=0;i<2*N;i++)
109       st->window[i] = 0;
110    for (i=0;i<st->overlap;i++)
111       st->window[N4+i] = st->window[2*N-N4-i-1] 
112             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
113    for (i=0;i<2*N4;i++)
114       st->window[N-N4+i] = 1;
115    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
116
117    st->preemph = 0.8;
118    st->preemph_memE = celt_alloc(C*sizeof(float));;
119    st->preemph_memD = celt_alloc(C*sizeof(float));;
120
121    alloc_init(&st->alloc, st->mode);
122    return st;
123 }
124
125 void celt_encoder_destroy(CELTEncoder *st)
126 {
127    if (st == NULL)
128    {
129       celt_warning("NULL passed to celt_encoder_destroy");
130       return;
131    }
132    ec_byte_writeclear(&st->buf);
133
134    mdct_clear(&st->mdct_lookup);
135    kiss_fft_free(st->fft);
136    psydecay_clear(&st->psy);
137
138    celt_free(st->window);
139    celt_free(st->in_mem);
140    celt_free(st->mdct_overlap);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148    alloc_clear(&st->alloc);
149
150    celt_free(st);
151 }
152
153
154 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
155 {
156    int i, c;
157    for (c=0;c<C;c++)
158    {
159       for (i=0;i<B;i++)
160       {
161          int j;
162          float x[2*N];
163          float tmp[N];
164          for (j=0;j<2*N;j++)
165             x[j] = window[j]*in[C*i*N+C*j+c];
166          mdct_forward(mdct_lookup, x, tmp);
167          /* Interleaving the sub-frames */
168          for (j=0;j<N;j++)
169             out[C*B*j+C*i+c] = tmp[j];
170       }
171    }
172 }
173
174 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
175 {
176    int i, c, N4;
177    N4 = (N-overlap)/2;
178    for (c=0;c<C;c++)
179    {
180       for (i=0;i<B;i++)
181       {
182          int j;
183          float x[2*N];
184          float tmp[N];
185          /* De-interleaving the sub-frames */
186          for (j=0;j<N;j++)
187             tmp[j] = X[C*B*j+C*i+c];
188          mdct_backward(mdct_lookup, tmp, x);
189          for (j=0;j<2*N;j++)
190             x[j] = window[j]*x[j];
191          for (j=0;j<overlap;j++)
192             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
193          for (j=0;j<2*N4;j++)
194             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
195          for (j=0;j<overlap;j++)
196             mdct_overlap[C*j+c] = x[N+N4+j];
197       }
198    }
199 }
200
201 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
202 {
203    int i, c, N, B, C, N4;
204    N = st->block_size;
205    B = st->nb_blocks;
206    C = st->mode->nbChannels;
207    float in[(B+1)*C*N];
208
209    float X[B*C*N];         /**< Interleaved signal MDCTs */
210    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
211    float mask[B*C*N];      /**< Masking curve */
212    float bandE[st->mode->nbEBands*C];
213    float gains[st->mode->nbPBands];
214    int pitch_index;
215
216    N4 = (N-st->overlap)/2;
217
218    for (c=0;c<C;c++)
219    {
220       for (i=0;i<N4;i++)
221          in[C*i+c] = 0;
222       for (i=0;i<st->overlap;i++)
223          in[C*(i+N4)+c] = st->in_mem[C*i+c];
224       for (i=0;i<B*N;i++)
225       {
226          float tmp = pcm[C*i+c];
227          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
228          st->preemph_memE[c] = tmp;
229       }
230       for (i=N*(B+1)-N4;i<N*(B+1);i++)
231          in[C*i+c] = 0;
232       for (i=0;i<st->overlap;i++)
233          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
234    }
235    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
236    /* Compute MDCTs */
237    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
238
239 #if 0 /* Mask disabled until it can be made to do something useful */
240    compute_mdct_masking(X, mask, B*C*N, st->Fs);
241
242    /* Invert and stretch the mask to length of X 
243       For some reason, I get better results by using the sqrt instead,
244       although there's no valid reason to. Must investigate further */
245    for (i=0;i<B*C*N;i++)
246       mask[i] = 1/(.1+mask[i]);
247 #else
248    for (i=0;i<B*C*N;i++)
249       mask[i] = 1;
250 #endif
251    /* Pitch analysis */
252    for (c=0;c<C;c++)
253    {
254       for (i=0;i<N;i++)
255       {
256          in[C*i+c] *= st->window[i];
257          in[C*(B*N+i)+c] *= st->window[N+i];
258       }
259    }
260    find_spectral_pitch(st->fft, &st->psy, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
261    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
262    
263    /* Compute MDCTs of the pitch part */
264    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
265    
266    /*int j;
267    for (j=0;j<B*N;j++)
268       printf ("%f ", X[j]);
269    for (j=0;j<B*N;j++)
270       printf ("%f ", P[j]);
271    printf ("\n");*/
272
273    /* Band normalisation */
274    compute_band_energies(st->mode, X, bandE);
275    normalise_bands(st->mode, X, bandE);
276    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
277    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
278
279    /* Normalise the pitch vector as well (discard the energies) */
280    {
281       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
282       compute_band_energies(st->mode, P, bandEp);
283       normalise_bands(st->mode, P, bandEp);
284    }
285
286    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
287
288    if (C==2)
289    {
290       stereo_mix(st->mode, X, bandE, 1);
291       stereo_mix(st->mode, P, bandE, 1);
292    }
293    /* Simulates intensity stereo */
294    //for (i=30;i<N*B;i++)
295    //   X[i*C+1] = P[i*C+1] = 0;
296    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
297
298    /* Pitch prediction */
299    compute_pitch_gain(st->mode, X, P, gains, bandE);
300    quant_pitch(gains, st->mode->nbPBands, &st->enc);
301    pitch_quant_bands(st->mode, X, P, gains);
302
303    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
304    /* Compute residual that we're going to encode */
305    for (i=0;i<B*C*N;i++)
306       X[i] -= P[i];
307
308    /*float sum=0;
309    for (i=0;i<B*N;i++)
310       sum += X[i]*X[i];
311    printf ("%f\n", sum);*/
312    /* Residual quantisation */
313    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
314    
315    if (C==2)
316       stereo_mix(st->mode, X, bandE, -1);
317
318    renormalise_bands(st->mode, X);
319    /* Synthesis */
320    denormalise_bands(st->mode, X, bandE);
321
322
323    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
324
325    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
326    /* De-emphasis and put everything back at the right place in the synthesis history */
327    for (c=0;c<C;c++)
328    {
329       for (i=0;i<B;i++)
330       {
331          int j;
332          for (j=0;j<N;j++)
333          {
334             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
335             st->preemph_memD[c] = tmp;
336             if (tmp > 32767) tmp = 32767;
337             if (tmp < -32767) tmp = -32767;
338             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
339          }
340       }
341    }
342    
343    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
344    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
345    {
346       int val = 0;
347       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
348       {
349          ec_enc_uint(&st->enc, val, 2);
350          val = 1-val;
351       }
352    }
353    ec_enc_done(&st->enc);
354    {
355       unsigned char *data;
356       int nbBytes = ec_byte_bytes(&st->buf);
357       if (nbBytes > nbCompressedBytes)
358       {
359          celt_warning_int ("got too many bytes:", nbBytes);
360          return CELT_INTERNAL_ERROR;
361       }
362       //printf ("%d\n", *nbBytes);
363       data = ec_byte_get_buffer(&st->buf);
364       for (i=0;i<nbBytes;i++)
365          compressed[i] = data[i];
366       for (;i<nbCompressedBytes;i++)
367          compressed[i] = 0;
368    }
369    /* Reset the packing for the next encoding */
370    ec_byte_reset(&st->buf);
371    ec_enc_init(&st->enc,&st->buf);
372
373    return nbCompressedBytes;
374 }
375
376
377 /****************************************************************************/
378 /*                                                                          */
379 /*                                DECODER                                   */
380 /*                                                                          */
381 /****************************************************************************/
382
383
384
385 struct CELTDecoder {
386    const CELTMode *mode;
387    int frame_size;
388    int block_size;
389    int nb_blocks;
390    int overlap;
391
392    ec_byte_buffer buf;
393    ec_enc         enc;
394
395    float preemph;
396    float *preemph_memD;
397    
398    mdct_lookup mdct_lookup;
399    
400    float *window;
401    float *mdct_overlap;
402    float *out_mem;
403
404    float *oldBandE;
405    
406    int last_pitch_index;
407    
408    struct alloc_data alloc;
409 };
410
411 CELTDecoder *celt_decoder_new(const CELTMode *mode)
412 {
413    int i, N, B, C, N4;
414    N = mode->mdctSize;
415    B = mode->nbMdctBlocks;
416    C = mode->nbChannels;
417    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
418    
419    st->mode = mode;
420    st->frame_size = B*N;
421    st->block_size = N;
422    st->nb_blocks  = B;
423    st->overlap = mode->overlap;
424
425    N4 = (N-st->overlap)/2;
426    
427    mdct_init(&st->mdct_lookup, 2*N);
428    
429    st->window = celt_alloc(2*N*sizeof(float));
430    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
431    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
432
433    for (i=0;i<2*N;i++)
434       st->window[i] = 0;
435    for (i=0;i<st->overlap;i++)
436       st->window[N4+i] = st->window[2*N-N4-i-1] 
437             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
438    for (i=0;i<2*N4;i++)
439       st->window[N-N4+i] = 1;
440    
441    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
442
443    st->preemph = 0.8;
444    st->preemph_memD = celt_alloc(C*sizeof(float));;
445
446    st->last_pitch_index = 0;
447    alloc_init(&st->alloc, st->mode);
448
449    return st;
450 }
451
452 void celt_decoder_destroy(CELTDecoder *st)
453 {
454    if (st == NULL)
455    {
456       celt_warning("NULL passed to celt_encoder_destroy");
457       return;
458    }
459
460    mdct_clear(&st->mdct_lookup);
461
462    celt_free(st->window);
463    celt_free(st->mdct_overlap);
464    celt_free(st->out_mem);
465    
466    celt_free(st->oldBandE);
467    
468    celt_free(st->preemph_memD);
469
470    alloc_clear(&st->alloc);
471
472    celt_free(st);
473 }
474
475 static void celt_decode_lost(CELTDecoder *st, short *pcm)
476 {
477    int i, c, N, B, C;
478    N = st->block_size;
479    B = st->nb_blocks;
480    C = st->mode->nbChannels;
481    float X[C*B*N];         /**< Interleaved signal MDCTs */
482    int pitch_index;
483    
484    pitch_index = st->last_pitch_index;
485    
486    /* Use the pitch MDCT as the "guessed" signal */
487    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
488
489    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
490    /* Compute inverse MDCTs */
491    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
492
493    for (c=0;c<C;c++)
494    {
495       for (i=0;i<B;i++)
496       {
497          int j;
498          for (j=0;j<N;j++)
499          {
500             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
501             st->preemph_memD[c] = tmp;
502             if (tmp > 32767) tmp = 32767;
503             if (tmp < -32767) tmp = -32767;
504             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
505          }
506       }
507    }
508 }
509
510 int celt_decode(CELTDecoder *st, char *data, int len, celt_int16_t *pcm)
511 {
512    int i, c, N, B, C;
513    N = st->block_size;
514    B = st->nb_blocks;
515    C = st->mode->nbChannels;
516    
517    float X[C*B*N];         /**< Interleaved signal MDCTs */
518    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
519    float bandE[st->mode->nbEBands*C];
520    float gains[st->mode->nbPBands];
521    int pitch_index;
522    ec_dec dec;
523    ec_byte_buffer buf;
524    
525    if (data == NULL)
526    {
527       celt_decode_lost(st, pcm);
528       return 0;
529    }
530    
531    ec_byte_readinit(&buf,data,len);
532    ec_dec_init(&dec,&buf);
533    
534    /* Get the pitch index */
535    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
536    st->last_pitch_index = pitch_index;
537    
538    /* Get band energies */
539    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
540    
541    /* Pitch MDCT */
542    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
543
544    {
545       float bandEp[st->mode->nbEBands];
546       compute_band_energies(st->mode, P, bandEp);
547       normalise_bands(st->mode, P, bandEp);
548    }
549
550    if (C==2)
551       stereo_mix(st->mode, P, bandE, 1);
552
553    /* Get the pitch gains */
554    unquant_pitch(gains, st->mode->nbPBands, &dec);
555
556    /* Apply pitch gains */
557    pitch_quant_bands(st->mode, X, P, gains);
558
559    /* Decode fixed codebook and merge with pitch */
560    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
561
562    if (C==2)
563       stereo_mix(st->mode, X, bandE, -1);
564
565    renormalise_bands(st->mode, X);
566    
567    /* Synthesis */
568    denormalise_bands(st->mode, X, bandE);
569
570
571    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
572    /* Compute inverse MDCTs */
573    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
574
575    for (c=0;c<C;c++)
576    {
577       for (i=0;i<B;i++)
578       {
579          int j;
580          for (j=0;j<N;j++)
581          {
582             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
583             st->preemph_memD[c] = tmp;
584             if (tmp > 32767) tmp = 32767;
585             if (tmp < -32767) tmp = -32767;
586             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
587          }
588       }
589    }
590
591    {
592       int val = 0;
593       while (ec_dec_tell(&dec, 0) < len*8)
594       {
595          if (ec_dec_uint(&dec, 2) != val)
596          {
597             celt_warning("decode error");
598             return CELT_CORRUPTED_DATA;
599          }
600          val = 1-val;
601       }
602    }
603
604    return 0;
605    //printf ("\n");
606 }
607