Changed _new() to _create() in the API. Added some documentation
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "os_support.h"
37 #include "mdct.h"
38 #include <math.h>
39 #include "celt.h"
40 #include "pitch.h"
41 #include "kiss_fftr.h"
42 #include "bands.h"
43 #include "modes.h"
44 #include "entcode.h"
45 #include "quant_pitch.h"
46 #include "quant_bands.h"
47 #include "psy.h"
48 #include "rate.h"
49
50 #define MAX_PERIOD 1024
51
52 #ifndef M_PI
53 #define M_PI 3.14159263
54 #endif
55
56 struct CELTEncoder {
57    const CELTMode *mode;
58    int frame_size;
59    int block_size;
60    int nb_blocks;
61    int overlap;
62    int channels;
63    int Fs;
64    
65    ec_byte_buffer buf;
66    ec_enc         enc;
67
68    float preemph;
69    float *preemph_memE;
70    float *preemph_memD;
71    
72    mdct_lookup mdct_lookup;
73    kiss_fftr_cfg fft;
74    struct PsyDecay psy;
75    
76    float *window;
77    float *in_mem;
78    float *mdct_overlap;
79    float *out_mem;
80
81    float *oldBandE;
82 };
83
84
85
86 CELTEncoder *celt_encoder_create(const CELTMode *mode)
87 {
88    int i, N, B, C, N4;
89    CELTEncoder *st;
90    N = mode->mdctSize;
91    B = mode->nbMdctBlocks;
92    C = mode->nbChannels;
93    st = celt_alloc(sizeof(CELTEncoder));
94    
95    st->mode = mode;
96    st->frame_size = B*N;
97    st->block_size = N;
98    st->nb_blocks  = B;
99    st->overlap = mode->overlap;
100    st->Fs = 44100;
101
102    N4 = (N-st->overlap)/2;
103    ec_byte_writeinit(&st->buf);
104    ec_enc_init(&st->enc,&st->buf);
105
106    mdct_init(&st->mdct_lookup, 2*N);
107    st->fft = kiss_fftr_alloc(MAX_PERIOD*C, 0, 0);
108    psydecay_init(&st->psy, MAX_PERIOD*C/2, st->Fs);
109    
110    st->window = celt_alloc(2*N*sizeof(float));
111    st->in_mem = celt_alloc(N*C*sizeof(float));
112    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
113    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
114    for (i=0;i<2*N;i++)
115       st->window[i] = 0;
116    for (i=0;i<st->overlap;i++)
117       st->window[N4+i] = st->window[2*N-N4-i-1] 
118             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
119    for (i=0;i<2*N4;i++)
120       st->window[N-N4+i] = 1;
121    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
122
123    st->preemph = 0.8;
124    st->preemph_memE = celt_alloc(C*sizeof(float));;
125    st->preemph_memD = celt_alloc(C*sizeof(float));;
126
127    return st;
128 }
129
130 void celt_encoder_destroy(CELTEncoder *st)
131 {
132    if (st == NULL)
133    {
134       celt_warning("NULL passed to celt_encoder_destroy");
135       return;
136    }
137    ec_byte_writeclear(&st->buf);
138
139    mdct_clear(&st->mdct_lookup);
140    kiss_fft_free(st->fft);
141    psydecay_clear(&st->psy);
142
143    celt_free(st->window);
144    celt_free(st->in_mem);
145    celt_free(st->mdct_overlap);
146    celt_free(st->out_mem);
147    
148    celt_free(st->oldBandE);
149    
150    celt_free(st->preemph_memE);
151    celt_free(st->preemph_memD);
152    
153    celt_free(st);
154 }
155
156
157 static float compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
158 {
159    int i, c;
160    float E = 1e-15;
161    VARDECL(float *x);
162    VARDECL(float *tmp);
163    ALLOC(x, 2*N, float);
164    ALLOC(tmp, N, float);
165    for (c=0;c<C;c++)
166    {
167       for (i=0;i<B;i++)
168       {
169          int j;
170          for (j=0;j<2*N;j++)
171          {
172             x[j] = window[j]*in[C*i*N+C*j+c];
173             E += x[j]*x[j];
174          }
175          mdct_forward(mdct_lookup, x, tmp);
176          /* Interleaving the sub-frames */
177          for (j=0;j<N;j++)
178             out[C*B*j+C*i+c] = tmp[j];
179       }
180    }
181    return E;
182 }
183
184 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
185 {
186    int i, c, N4;
187    VARDECL(float *x);
188    VARDECL(float *tmp);
189    ALLOC(x, 2*N, float);
190    ALLOC(tmp, N, float);
191    N4 = (N-overlap)/2;
192    for (c=0;c<C;c++)
193    {
194       for (i=0;i<B;i++)
195       {
196          int j;
197          /* De-interleaving the sub-frames */
198          for (j=0;j<N;j++)
199             tmp[j] = X[C*B*j+C*i+c];
200          mdct_backward(mdct_lookup, tmp, x);
201          for (j=0;j<2*N;j++)
202             x[j] = window[j]*x[j];
203          for (j=0;j<overlap;j++)
204             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
205          for (j=0;j<2*N4;j++)
206             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
207          for (j=0;j<overlap;j++)
208             mdct_overlap[C*j+c] = x[N+N4+j];
209       }
210    }
211 }
212
213 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
214 {
215    int i, c, N, B, C, N4;
216    int has_pitch;
217    int pitch_index;
218    float curr_power, pitch_power;
219    VARDECL(float *in);
220    VARDECL(float *X);
221    VARDECL(float *P);
222    VARDECL(float *mask);
223    VARDECL(float *bandE);
224    VARDECL(float *gains);
225    N = st->block_size;
226    B = st->nb_blocks;
227    C = st->mode->nbChannels;
228    ALLOC(in, (B+1)*C*N, float);
229    ALLOC(X, B*C*N, float);         /**< Interleaved signal MDCTs */
230    ALLOC(P, B*C*N, float);         /**< Interleaved pitch MDCTs*/
231    ALLOC(mask, B*C*N, float);      /**< Masking curve */
232    ALLOC(bandE,st->mode->nbEBands*C, float);
233    ALLOC(gains,st->mode->nbPBands, float);
234    
235    N4 = (N-st->overlap)/2;
236
237    for (c=0;c<C;c++)
238    {
239       for (i=0;i<N4;i++)
240          in[C*i+c] = 0;
241       for (i=0;i<st->overlap;i++)
242          in[C*(i+N4)+c] = st->in_mem[C*i+c];
243       for (i=0;i<B*N;i++)
244       {
245          float tmp = pcm[C*i+c];
246          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
247          st->preemph_memE[c] = tmp;
248       }
249       for (i=N*(B+1)-N4;i<N*(B+1);i++)
250          in[C*i+c] = 0;
251       for (i=0;i<st->overlap;i++)
252          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
253    }
254    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
255    /* Compute MDCTs */
256    curr_power = compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
257
258 #if 0 /* Mask disabled until it can be made to do something useful */
259    compute_mdct_masking(X, mask, B*C*N, st->Fs);
260
261    /* Invert and stretch the mask to length of X 
262       For some reason, I get better results by using the sqrt instead,
263       although there's no valid reason to. Must investigate further */
264    for (i=0;i<B*C*N;i++)
265       mask[i] = 1/(.1+mask[i]);
266 #else
267    for (i=0;i<B*C*N;i++)
268       mask[i] = 1;
269 #endif
270    /* Pitch analysis */
271    for (c=0;c<C;c++)
272    {
273       for (i=0;i<N;i++)
274       {
275          in[C*i+c] *= st->window[i];
276          in[C*(B*N+i)+c] *= st->window[N+i];
277       }
278    }
279    find_spectral_pitch(st->fft, &st->psy, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
280    
281    /* Compute MDCTs of the pitch part */
282    pitch_power = compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
283    
284    /*printf ("%f %f\n", curr_power, pitch_power);*/
285    /*int j;
286    for (j=0;j<B*N;j++)
287       printf ("%f ", X[j]);
288    for (j=0;j<B*N;j++)
289       printf ("%f ", P[j]);
290    printf ("\n");*/
291
292    /* Band normalisation */
293    compute_band_energies(st->mode, X, bandE);
294    normalise_bands(st->mode, X, bandE);
295    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
296    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
297
298    quant_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, &st->enc);
299
300    if (C==2)
301    {
302       stereo_mix(st->mode, X, bandE, 1);
303    }
304
305    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
306    if (curr_power + 1e5f < 10.f*pitch_power)
307    {
308       /* Normalise the pitch vector as well (discard the energies) */
309       VARDECL(float *bandEp);
310       ALLOC(bandEp, st->mode->nbEBands*st->mode->nbChannels, float);
311       compute_band_energies(st->mode, P, bandEp);
312       normalise_bands(st->mode, P, bandEp);
313
314       if (C==2)
315          stereo_mix(st->mode, P, bandE, 1);
316       /* Simulates intensity stereo */
317       /*for (i=30;i<N*B;i++)
318          X[i*C+1] = P[i*C+1] = 0;*/
319
320       /* Pitch prediction */
321       compute_pitch_gain(st->mode, X, P, gains, bandE);
322       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
323       if (has_pitch)
324          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
325    } else {
326       /* No pitch, so we just pretend we found a gain of zero */
327       for (i=0;i<st->mode->nbPBands;i++)
328          gains[i] = 0;
329       ec_enc_uint(&st->enc, 0, 128);
330       for (i=0;i<B*C*N;i++)
331          P[i] = 0;
332    }
333    
334
335    pitch_quant_bands(st->mode, X, P, gains);
336
337    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
338    /* Compute residual that we're going to encode */
339    for (i=0;i<B*C*N;i++)
340       X[i] -= P[i];
341
342    /*float sum=0;
343    for (i=0;i<B*N;i++)
344       sum += X[i]*X[i];
345    printf ("%f\n", sum);*/
346    /* Residual quantisation */
347    quant_bands(st->mode, X, P, mask, nbCompressedBytes*8, &st->enc);
348    
349    if (C==2)
350       stereo_mix(st->mode, X, bandE, -1);
351
352    renormalise_bands(st->mode, X);
353    /* Synthesis */
354    denormalise_bands(st->mode, X, bandE);
355
356
357    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
358
359    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
360    /* De-emphasis and put everything back at the right place in the synthesis history */
361    for (c=0;c<C;c++)
362    {
363       for (i=0;i<B;i++)
364       {
365          int j;
366          for (j=0;j<N;j++)
367          {
368             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
369             st->preemph_memD[c] = tmp;
370             if (tmp > 32767) tmp = 32767;
371             if (tmp < -32767) tmp = -32767;
372             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
373          }
374       }
375    }
376    
377    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
378       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
379    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
380    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
381    {
382       int val = 0;
383       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
384       {
385          ec_enc_uint(&st->enc, val, 2);
386          val = 1-val;
387       }
388    }
389    ec_enc_done(&st->enc);
390    {
391       unsigned char *data;
392       int nbBytes = ec_byte_bytes(&st->buf);
393       if (nbBytes > nbCompressedBytes)
394       {
395          celt_warning_int ("got too many bytes:", nbBytes);
396          return CELT_INTERNAL_ERROR;
397       }
398       /*printf ("%d\n", *nbBytes);*/
399       data = ec_byte_get_buffer(&st->buf);
400       for (i=0;i<nbBytes;i++)
401          compressed[i] = data[i];
402       for (;i<nbCompressedBytes;i++)
403          compressed[i] = 0;
404    }
405    /* Reset the packing for the next encoding */
406    ec_byte_reset(&st->buf);
407    ec_enc_init(&st->enc,&st->buf);
408
409    return nbCompressedBytes;
410 }
411
412
413 /****************************************************************************/
414 /*                                                                          */
415 /*                                DECODER                                   */
416 /*                                                                          */
417 /****************************************************************************/
418
419
420
421 struct CELTDecoder {
422    const CELTMode *mode;
423    int frame_size;
424    int block_size;
425    int nb_blocks;
426    int overlap;
427
428    ec_byte_buffer buf;
429    ec_enc         enc;
430
431    float preemph;
432    float *preemph_memD;
433    
434    mdct_lookup mdct_lookup;
435    
436    float *window;
437    float *mdct_overlap;
438    float *out_mem;
439
440    float *oldBandE;
441    
442    int last_pitch_index;
443 };
444
445 CELTDecoder *celt_decoder_create(const CELTMode *mode)
446 {
447    int i, N, B, C, N4;
448    CELTDecoder *st;
449    N = mode->mdctSize;
450    B = mode->nbMdctBlocks;
451    C = mode->nbChannels;
452    st = celt_alloc(sizeof(CELTDecoder));
453    
454    st->mode = mode;
455    st->frame_size = B*N;
456    st->block_size = N;
457    st->nb_blocks  = B;
458    st->overlap = mode->overlap;
459
460    N4 = (N-st->overlap)/2;
461    
462    mdct_init(&st->mdct_lookup, 2*N);
463    
464    st->window = celt_alloc(2*N*sizeof(float));
465    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
466    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
467
468    for (i=0;i<2*N;i++)
469       st->window[i] = 0;
470    for (i=0;i<st->overlap;i++)
471       st->window[N4+i] = st->window[2*N-N4-i-1] 
472             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
473    for (i=0;i<2*N4;i++)
474       st->window[N-N4+i] = 1;
475    
476    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
477
478    st->preemph = 0.8;
479    st->preemph_memD = celt_alloc(C*sizeof(float));;
480
481    st->last_pitch_index = 0;
482    return st;
483 }
484
485 void celt_decoder_destroy(CELTDecoder *st)
486 {
487    if (st == NULL)
488    {
489       celt_warning("NULL passed to celt_encoder_destroy");
490       return;
491    }
492
493    mdct_clear(&st->mdct_lookup);
494
495    celt_free(st->window);
496    celt_free(st->mdct_overlap);
497    celt_free(st->out_mem);
498    
499    celt_free(st->oldBandE);
500    
501    celt_free(st->preemph_memD);
502
503    celt_free(st);
504 }
505
506 static void celt_decode_lost(CELTDecoder *st, short *pcm)
507 {
508    int i, c, N, B, C;
509    int pitch_index;
510    VARDECL(float *X);
511    N = st->block_size;
512    B = st->nb_blocks;
513    C = st->mode->nbChannels;
514    ALLOC(X,C*B*N, float);         /**< Interleaved signal MDCTs */
515    
516    pitch_index = st->last_pitch_index;
517    
518    /* Use the pitch MDCT as the "guessed" signal */
519    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
520
521    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
522    /* Compute inverse MDCTs */
523    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
524
525    for (c=0;c<C;c++)
526    {
527       for (i=0;i<B;i++)
528       {
529          int j;
530          for (j=0;j<N;j++)
531          {
532             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
533             st->preemph_memD[c] = tmp;
534             if (tmp > 32767) tmp = 32767;
535             if (tmp < -32767) tmp = -32767;
536             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
537          }
538       }
539    }
540 }
541
542 int celt_decode(CELTDecoder *st, unsigned char *data, int len, celt_int16_t *pcm)
543 {
544    int i, c, N, B, C;
545    int has_pitch;
546    int pitch_index;
547    ec_dec dec;
548    ec_byte_buffer buf;
549    VARDECL(float *X);
550    VARDECL(float *P);
551    VARDECL(float *bandE);
552    VARDECL(float *gains);
553    N = st->block_size;
554    B = st->nb_blocks;
555    C = st->mode->nbChannels;
556    
557    ALLOC(X, C*B*N, float);         /**< Interleaved signal MDCTs */
558    ALLOC(P, C*B*N, float);         /**< Interleaved pitch MDCTs*/
559    ALLOC(bandE, st->mode->nbEBands*C, float);
560    ALLOC(gains, st->mode->nbPBands, float);
561    
562    if (data == NULL)
563    {
564       celt_decode_lost(st, pcm);
565       return 0;
566    }
567    
568    ec_byte_readinit(&buf,data,len);
569    ec_dec_init(&dec,&buf);
570    
571    /* Get band energies */
572    unquant_energy(st->mode, bandE, st->oldBandE, len*8/3, &dec);
573    
574    /* Get the pitch gains */
575    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
576    
577    /* Get the pitch index */
578    if (has_pitch)
579    {
580       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
581       st->last_pitch_index = pitch_index;
582    } else {
583       /* FIXME: We could be more intelligent here and just not compute the MDCT */
584       pitch_index = 0;
585    }
586    
587    /* Pitch MDCT */
588    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
589
590    {
591       VARDECL(float *bandEp);
592       ALLOC(bandEp, st->mode->nbEBands*C, float);
593       compute_band_energies(st->mode, P, bandEp);
594       normalise_bands(st->mode, P, bandEp);
595    }
596
597    if (C==2)
598       stereo_mix(st->mode, P, bandE, 1);
599
600    /* Apply pitch gains */
601    pitch_quant_bands(st->mode, X, P, gains);
602
603    /* Decode fixed codebook and merge with pitch */
604    unquant_bands(st->mode, X, P, len*8, &dec);
605
606    if (C==2)
607       stereo_mix(st->mode, X, bandE, -1);
608
609    renormalise_bands(st->mode, X);
610    
611    /* Synthesis */
612    denormalise_bands(st->mode, X, bandE);
613
614
615    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
616    /* Compute inverse MDCTs */
617    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
618
619    for (c=0;c<C;c++)
620    {
621       for (i=0;i<B;i++)
622       {
623          int j;
624          for (j=0;j<N;j++)
625          {
626             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
627             st->preemph_memD[c] = tmp;
628             if (tmp > 32767) tmp = 32767;
629             if (tmp < -32767) tmp = -32767;
630             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
631          }
632       }
633    }
634
635    {
636       int val = 0;
637       while (ec_dec_tell(&dec, 0) < len*8)
638       {
639          if (ec_dec_uint(&dec, 2) != val)
640          {
641             celt_warning("decode error");
642             return CELT_CORRUPTED_DATA;
643          }
644          val = 1-val;
645       }
646    }
647
648    return 0;
649    /*printf ("\n");*/
650 }
651