Merged the rate allocation atruct directly into the mode struct.
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "kiss_fftr.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    kiss_fftr_cfg fft;
67    struct PsyDecay psy;
68    
69    float *window;
70    float *in_mem;
71    float *mdct_overlap;
72    float *out_mem;
73
74    float *oldBandE;
75 };
76
77
78
79 CELTEncoder *celt_encoder_new(const CELTMode *mode)
80 {
81    int i, N, B, C, N4;
82    N = mode->mdctSize;
83    B = mode->nbMdctBlocks;
84    C = mode->nbChannels;
85    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
86    
87    st->mode = mode;
88    st->frame_size = B*N;
89    st->block_size = N;
90    st->nb_blocks  = B;
91    st->overlap = mode->overlap;
92    st->Fs = 44100;
93
94    N4 = (N-st->overlap)/2;
95    ec_byte_writeinit(&st->buf);
96    ec_enc_init(&st->enc,&st->buf);
97
98    mdct_init(&st->mdct_lookup, 2*N);
99    st->fft = kiss_fftr_alloc(MAX_PERIOD*C, 0, 0);
100    psydecay_init(&st->psy, MAX_PERIOD*C/2, st->Fs);
101    
102    st->window = celt_alloc(2*N*sizeof(float));
103    st->in_mem = celt_alloc(N*C*sizeof(float));
104    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
105    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
106    for (i=0;i<2*N;i++)
107       st->window[i] = 0;
108    for (i=0;i<st->overlap;i++)
109       st->window[N4+i] = st->window[2*N-N4-i-1] 
110             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
111    for (i=0;i<2*N4;i++)
112       st->window[N-N4+i] = 1;
113    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
114
115    st->preemph = 0.8;
116    st->preemph_memE = celt_alloc(C*sizeof(float));;
117    st->preemph_memD = celt_alloc(C*sizeof(float));;
118
119    return st;
120 }
121
122 void celt_encoder_destroy(CELTEncoder *st)
123 {
124    if (st == NULL)
125    {
126       celt_warning("NULL passed to celt_encoder_destroy");
127       return;
128    }
129    ec_byte_writeclear(&st->buf);
130
131    mdct_clear(&st->mdct_lookup);
132    kiss_fft_free(st->fft);
133    psydecay_clear(&st->psy);
134
135    celt_free(st->window);
136    celt_free(st->in_mem);
137    celt_free(st->mdct_overlap);
138    celt_free(st->out_mem);
139    
140    celt_free(st->oldBandE);
141    
142    celt_free(st->preemph_memE);
143    celt_free(st->preemph_memD);
144    
145    celt_free(st);
146 }
147
148
149 static float compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
150 {
151    int i, c;
152    float E = 1e-15;
153    for (c=0;c<C;c++)
154    {
155       for (i=0;i<B;i++)
156       {
157          int j;
158          float x[2*N];
159          float tmp[N];
160          for (j=0;j<2*N;j++)
161          {
162             x[j] = window[j]*in[C*i*N+C*j+c];
163             E += x[j]*x[j];
164          }
165          mdct_forward(mdct_lookup, x, tmp);
166          /* Interleaving the sub-frames */
167          for (j=0;j<N;j++)
168             out[C*B*j+C*i+c] = tmp[j];
169       }
170    }
171    return E;
172 }
173
174 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
175 {
176    int i, c, N4;
177    N4 = (N-overlap)/2;
178    for (c=0;c<C;c++)
179    {
180       for (i=0;i<B;i++)
181       {
182          int j;
183          float x[2*N];
184          float tmp[N];
185          /* De-interleaving the sub-frames */
186          for (j=0;j<N;j++)
187             tmp[j] = X[C*B*j+C*i+c];
188          mdct_backward(mdct_lookup, tmp, x);
189          for (j=0;j<2*N;j++)
190             x[j] = window[j]*x[j];
191          for (j=0;j<overlap;j++)
192             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
193          for (j=0;j<2*N4;j++)
194             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
195          for (j=0;j<overlap;j++)
196             mdct_overlap[C*j+c] = x[N+N4+j];
197       }
198    }
199 }
200
201 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
202 {
203    int i, c, N, B, C, N4;
204    int has_pitch;
205    N = st->block_size;
206    B = st->nb_blocks;
207    C = st->mode->nbChannels;
208    float in[(B+1)*C*N];
209
210    float X[B*C*N];         /**< Interleaved signal MDCTs */
211    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
212    float mask[B*C*N];      /**< Masking curve */
213    float bandE[st->mode->nbEBands*C];
214    float gains[st->mode->nbPBands];
215    int pitch_index;
216    float curr_power, pitch_power;
217    
218    N4 = (N-st->overlap)/2;
219
220    for (c=0;c<C;c++)
221    {
222       for (i=0;i<N4;i++)
223          in[C*i+c] = 0;
224       for (i=0;i<st->overlap;i++)
225          in[C*(i+N4)+c] = st->in_mem[C*i+c];
226       for (i=0;i<B*N;i++)
227       {
228          float tmp = pcm[C*i+c];
229          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
230          st->preemph_memE[c] = tmp;
231       }
232       for (i=N*(B+1)-N4;i<N*(B+1);i++)
233          in[C*i+c] = 0;
234       for (i=0;i<st->overlap;i++)
235          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
236    }
237    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
238    /* Compute MDCTs */
239    curr_power = compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
240
241 #if 0 /* Mask disabled until it can be made to do something useful */
242    compute_mdct_masking(X, mask, B*C*N, st->Fs);
243
244    /* Invert and stretch the mask to length of X 
245       For some reason, I get better results by using the sqrt instead,
246       although there's no valid reason to. Must investigate further */
247    for (i=0;i<B*C*N;i++)
248       mask[i] = 1/(.1+mask[i]);
249 #else
250    for (i=0;i<B*C*N;i++)
251       mask[i] = 1;
252 #endif
253    /* Pitch analysis */
254    for (c=0;c<C;c++)
255    {
256       for (i=0;i<N;i++)
257       {
258          in[C*i+c] *= st->window[i];
259          in[C*(B*N+i)+c] *= st->window[N+i];
260       }
261    }
262    find_spectral_pitch(st->fft, &st->psy, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
263    
264    /* Compute MDCTs of the pitch part */
265    pitch_power = compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
266    
267    //printf ("%f %f\n", curr_power, pitch_power);
268    /*int j;
269    for (j=0;j<B*N;j++)
270       printf ("%f ", X[j]);
271    for (j=0;j<B*N;j++)
272       printf ("%f ", P[j]);
273    printf ("\n");*/
274
275    /* Band normalisation */
276    compute_band_energies(st->mode, X, bandE);
277    normalise_bands(st->mode, X, bandE);
278    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
279    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
280
281    quant_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, &st->enc);
282
283    if (C==2)
284    {
285       stereo_mix(st->mode, X, bandE, 1);
286    }
287
288    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
289    if (curr_power + 1e5f < 10.f*pitch_power)
290    {
291       /* Normalise the pitch vector as well (discard the energies) */
292       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
293       compute_band_energies(st->mode, P, bandEp);
294       normalise_bands(st->mode, P, bandEp);
295
296       if (C==2)
297          stereo_mix(st->mode, P, bandE, 1);
298       /* Simulates intensity stereo */
299       //for (i=30;i<N*B;i++)
300       //   X[i*C+1] = P[i*C+1] = 0;
301
302       /* Pitch prediction */
303       compute_pitch_gain(st->mode, X, P, gains, bandE);
304       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
305       if (has_pitch)
306          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
307    } else {
308       /* No pitch, so we just pretend we found a gain of zero */
309       for (i=0;i<st->mode->nbPBands;i++)
310          gains[i] = 0;
311       ec_enc_uint(&st->enc, 0, 128);
312       for (i=0;i<B*C*N;i++)
313          P[i] = 0;
314    }
315    
316
317    pitch_quant_bands(st->mode, X, P, gains);
318
319    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
320    /* Compute residual that we're going to encode */
321    for (i=0;i<B*C*N;i++)
322       X[i] -= P[i];
323
324    /*float sum=0;
325    for (i=0;i<B*N;i++)
326       sum += X[i]*X[i];
327    printf ("%f\n", sum);*/
328    /* Residual quantisation */
329    quant_bands(st->mode, X, P, mask, nbCompressedBytes*8, &st->enc);
330    
331    if (C==2)
332       stereo_mix(st->mode, X, bandE, -1);
333
334    renormalise_bands(st->mode, X);
335    /* Synthesis */
336    denormalise_bands(st->mode, X, bandE);
337
338
339    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
340
341    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
342    /* De-emphasis and put everything back at the right place in the synthesis history */
343    for (c=0;c<C;c++)
344    {
345       for (i=0;i<B;i++)
346       {
347          int j;
348          for (j=0;j<N;j++)
349          {
350             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
351             st->preemph_memD[c] = tmp;
352             if (tmp > 32767) tmp = 32767;
353             if (tmp < -32767) tmp = -32767;
354             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
355          }
356       }
357    }
358    
359    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
360       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
361    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
362    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
363    {
364       int val = 0;
365       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
366       {
367          ec_enc_uint(&st->enc, val, 2);
368          val = 1-val;
369       }
370    }
371    ec_enc_done(&st->enc);
372    {
373       unsigned char *data;
374       int nbBytes = ec_byte_bytes(&st->buf);
375       if (nbBytes > nbCompressedBytes)
376       {
377          celt_warning_int ("got too many bytes:", nbBytes);
378          return CELT_INTERNAL_ERROR;
379       }
380       //printf ("%d\n", *nbBytes);
381       data = ec_byte_get_buffer(&st->buf);
382       for (i=0;i<nbBytes;i++)
383          compressed[i] = data[i];
384       for (;i<nbCompressedBytes;i++)
385          compressed[i] = 0;
386    }
387    /* Reset the packing for the next encoding */
388    ec_byte_reset(&st->buf);
389    ec_enc_init(&st->enc,&st->buf);
390
391    return nbCompressedBytes;
392 }
393
394
395 /****************************************************************************/
396 /*                                                                          */
397 /*                                DECODER                                   */
398 /*                                                                          */
399 /****************************************************************************/
400
401
402
403 struct CELTDecoder {
404    const CELTMode *mode;
405    int frame_size;
406    int block_size;
407    int nb_blocks;
408    int overlap;
409
410    ec_byte_buffer buf;
411    ec_enc         enc;
412
413    float preemph;
414    float *preemph_memD;
415    
416    mdct_lookup mdct_lookup;
417    
418    float *window;
419    float *mdct_overlap;
420    float *out_mem;
421
422    float *oldBandE;
423    
424    int last_pitch_index;
425 };
426
427 CELTDecoder *celt_decoder_new(const CELTMode *mode)
428 {
429    int i, N, B, C, N4;
430    N = mode->mdctSize;
431    B = mode->nbMdctBlocks;
432    C = mode->nbChannels;
433    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
434    
435    st->mode = mode;
436    st->frame_size = B*N;
437    st->block_size = N;
438    st->nb_blocks  = B;
439    st->overlap = mode->overlap;
440
441    N4 = (N-st->overlap)/2;
442    
443    mdct_init(&st->mdct_lookup, 2*N);
444    
445    st->window = celt_alloc(2*N*sizeof(float));
446    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
447    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
448
449    for (i=0;i<2*N;i++)
450       st->window[i] = 0;
451    for (i=0;i<st->overlap;i++)
452       st->window[N4+i] = st->window[2*N-N4-i-1] 
453             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
454    for (i=0;i<2*N4;i++)
455       st->window[N-N4+i] = 1;
456    
457    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
458
459    st->preemph = 0.8;
460    st->preemph_memD = celt_alloc(C*sizeof(float));;
461
462    st->last_pitch_index = 0;
463    return st;
464 }
465
466 void celt_decoder_destroy(CELTDecoder *st)
467 {
468    if (st == NULL)
469    {
470       celt_warning("NULL passed to celt_encoder_destroy");
471       return;
472    }
473
474    mdct_clear(&st->mdct_lookup);
475
476    celt_free(st->window);
477    celt_free(st->mdct_overlap);
478    celt_free(st->out_mem);
479    
480    celt_free(st->oldBandE);
481    
482    celt_free(st->preemph_memD);
483
484    celt_free(st);
485 }
486
487 static void celt_decode_lost(CELTDecoder *st, short *pcm)
488 {
489    int i, c, N, B, C;
490    N = st->block_size;
491    B = st->nb_blocks;
492    C = st->mode->nbChannels;
493    float X[C*B*N];         /**< Interleaved signal MDCTs */
494    int pitch_index;
495    
496    pitch_index = st->last_pitch_index;
497    
498    /* Use the pitch MDCT as the "guessed" signal */
499    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
500
501    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
502    /* Compute inverse MDCTs */
503    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
504
505    for (c=0;c<C;c++)
506    {
507       for (i=0;i<B;i++)
508       {
509          int j;
510          for (j=0;j<N;j++)
511          {
512             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
513             st->preemph_memD[c] = tmp;
514             if (tmp > 32767) tmp = 32767;
515             if (tmp < -32767) tmp = -32767;
516             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
517          }
518       }
519    }
520 }
521
522 int celt_decode(CELTDecoder *st, unsigned char *data, int len, celt_int16_t *pcm)
523 {
524    int i, c, N, B, C;
525    int has_pitch;
526    N = st->block_size;
527    B = st->nb_blocks;
528    C = st->mode->nbChannels;
529    
530    float X[C*B*N];         /**< Interleaved signal MDCTs */
531    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
532    float bandE[st->mode->nbEBands*C];
533    float gains[st->mode->nbPBands];
534    int pitch_index;
535    ec_dec dec;
536    ec_byte_buffer buf;
537    
538    if (data == NULL)
539    {
540       celt_decode_lost(st, pcm);
541       return 0;
542    }
543    
544    ec_byte_readinit(&buf,data,len);
545    ec_dec_init(&dec,&buf);
546    
547    /* Get band energies */
548    unquant_energy(st->mode, bandE, st->oldBandE, len*8/3, &dec);
549    
550    /* Get the pitch gains */
551    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
552    
553    /* Get the pitch index */
554    if (has_pitch)
555    {
556       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
557       st->last_pitch_index = pitch_index;
558    } else {
559       /* FIXME: We could be more intelligent here and just not compute the MDCT */
560       pitch_index = 0;
561    }
562    
563    /* Pitch MDCT */
564    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
565
566    {
567       float bandEp[st->mode->nbEBands*C];
568       compute_band_energies(st->mode, P, bandEp);
569       normalise_bands(st->mode, P, bandEp);
570    }
571
572    if (C==2)
573       stereo_mix(st->mode, P, bandE, 1);
574
575    /* Apply pitch gains */
576    pitch_quant_bands(st->mode, X, P, gains);
577
578    /* Decode fixed codebook and merge with pitch */
579    unquant_bands(st->mode, X, P, len*8, &dec);
580
581    if (C==2)
582       stereo_mix(st->mode, X, bandE, -1);
583
584    renormalise_bands(st->mode, X);
585    
586    /* Synthesis */
587    denormalise_bands(st->mode, X, bandE);
588
589
590    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
591    /* Compute inverse MDCTs */
592    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
593
594    for (c=0;c<C;c++)
595    {
596       for (i=0;i<B;i++)
597       {
598          int j;
599          for (j=0;j<N;j++)
600          {
601             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
602             st->preemph_memD[c] = tmp;
603             if (tmp > 32767) tmp = 32767;
604             if (tmp < -32767) tmp = -32767;
605             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
606          }
607       }
608    }
609
610    {
611       int val = 0;
612       while (ec_dec_tell(&dec, 0) < len*8)
613       {
614          if (ec_dec_uint(&dec, 2) != val)
615          {
616             celt_warning("decode error");
617             return CELT_CORRUPTED_DATA;
618          }
619          val = 1-val;
620       }
621    }
622
623    return 0;
624    //printf ("\n");
625 }
626