Further simplified the API by passing the rate directly to the
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    void *fft;
67    
68    float *window;
69    float *in_mem;
70    float *mdct_overlap;
71    float *out_mem;
72
73    float *oldBandE;
74    
75    struct alloc_data alloc;
76 };
77
78
79
80 CELTEncoder *celt_encoder_new(const CELTMode *mode)
81 {
82    int i, N, B, C, N4;
83    N = mode->mdctSize;
84    B = mode->nbMdctBlocks;
85    C = mode->nbChannels;
86    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
87    
88    st->mode = mode;
89    st->frame_size = B*N;
90    st->block_size = N;
91    st->nb_blocks  = B;
92    st->overlap = mode->overlap;
93    st->Fs = 44100;
94
95    N4 = (N-st->overlap)/2;
96    ec_byte_writeinit(&st->buf);
97    ec_enc_init(&st->enc,&st->buf);
98
99    mdct_init(&st->mdct_lookup, 2*N);
100    st->fft = spx_fft_init(MAX_PERIOD*C);
101    
102    st->window = celt_alloc(2*N*sizeof(float));
103    st->in_mem = celt_alloc(N*C*sizeof(float));
104    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
105    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
106    for (i=0;i<2*N;i++)
107       st->window[i] = 0;
108    for (i=0;i<st->overlap;i++)
109       st->window[N4+i] = st->window[2*N-N4-i-1] 
110             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
111    for (i=0;i<2*N4;i++)
112       st->window[N-N4+i] = 1;
113    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
114
115    st->preemph = 0.8;
116    st->preemph_memE = celt_alloc(C*sizeof(float));;
117    st->preemph_memD = celt_alloc(C*sizeof(float));;
118
119    alloc_init(&st->alloc, st->mode);
120    return st;
121 }
122
123 void celt_encoder_destroy(CELTEncoder *st)
124 {
125    if (st == NULL)
126    {
127       celt_warning("NULL passed to celt_encoder_destroy");
128       return;
129    }
130    ec_byte_writeclear(&st->buf);
131
132    mdct_clear(&st->mdct_lookup);
133    spx_fft_destroy(st->fft);
134
135    celt_free(st->window);
136    celt_free(st->in_mem);
137    celt_free(st->mdct_overlap);
138    celt_free(st->out_mem);
139    
140    celt_free(st->oldBandE);
141    alloc_clear(&st->alloc);
142
143    celt_free(st);
144 }
145
146 static void haar1(float *X, int N, int stride)
147 {
148    int i, k;
149    for (k=0;k<stride;k++)
150    {
151       for (i=k;i<N*stride;i+=2*stride)
152       {
153          float a, b;
154          a = X[i];
155          b = X[i+stride];
156          X[i] = .707107f*(a+b);
157          X[i+stride] = .707107f*(a-b);
158       }
159    }
160 }
161
162 static void time_dct(float *X, int N, int B, int stride)
163 {
164    switch (B)
165    {
166       case 1:
167          break;
168       case 2:
169          haar1(X, B*N, stride);
170          break;
171       default:
172          celt_warning("time_dct not defined for B > 2");
173    };
174 }
175
176 static void time_idct(float *X, int N, int B, int stride)
177 {
178    switch (B)
179    {
180       case 1:
181          break;
182       case 2:
183          haar1(X, B*N, stride);
184          break;
185       default:
186          celt_warning("time_dct not defined for B > 2");
187    };
188 }
189
190 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
191 {
192    int i, c;
193    for (c=0;c<C;c++)
194    {
195       for (i=0;i<B;i++)
196       {
197          int j;
198          float x[2*N];
199          float tmp[N];
200          for (j=0;j<2*N;j++)
201             x[j] = window[j]*in[C*i*N+C*j+c];
202          mdct_forward(mdct_lookup, x, tmp);
203          /* Interleaving the sub-frames */
204          for (j=0;j<N;j++)
205             out[C*B*j+C*i+c] = tmp[j];
206       }
207    }
208 }
209
210 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
211 {
212    int i, c, N4;
213    N4 = (N-overlap)/2;
214    for (c=0;c<C;c++)
215    {
216       for (i=0;i<B;i++)
217       {
218          int j;
219          float x[2*N];
220          float tmp[N];
221          /* De-interleaving the sub-frames */
222          for (j=0;j<N;j++)
223             tmp[j] = X[C*B*j+C*i+c];
224          mdct_backward(mdct_lookup, tmp, x);
225          for (j=0;j<2*N;j++)
226             x[j] = window[j]*x[j];
227          for (j=0;j<overlap;j++)
228             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
229          for (j=0;j<2*N4;j++)
230             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
231          for (j=0;j<overlap;j++)
232             mdct_overlap[C*j+c] = x[N+N4+j];
233       }
234    }
235 }
236
237 int celt_encode(CELTEncoder *st, short *pcm, char *compressed, int nbCompressedBytes)
238 {
239    int i, c, N, B, C, N4;
240    N = st->block_size;
241    B = st->nb_blocks;
242    C = st->mode->nbChannels;
243    float in[(B+1)*C*N];
244
245    float X[B*C*N];         /**< Interleaved signal MDCTs */
246    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
247    float mask[B*C*N];      /**< Masking curve */
248    float bandE[st->mode->nbEBands*C];
249    float gains[st->mode->nbPBands];
250    int pitch_index;
251
252    N4 = (N-st->overlap)/2;
253
254    for (c=0;c<C;c++)
255    {
256       for (i=0;i<N4;i++)
257          in[C*i+c] = 0;
258       for (i=0;i<st->overlap;i++)
259          in[C*(i+N4)+c] = st->in_mem[C*i+c];
260       for (i=0;i<B*N;i++)
261       {
262          float tmp = pcm[C*i+c];
263          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
264          st->preemph_memE[c] = tmp;
265       }
266       for (i=N*(B+1)-N4;i<N*(B+1);i++)
267          in[C*i+c] = 0;
268       for (i=0;i<st->overlap;i++)
269          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
270    }
271    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
272    /* Compute MDCTs */
273    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
274
275    compute_mdct_masking(X, mask, B*C*N, st->Fs);
276
277    /* Invert and stretch the mask to length of X 
278       For some reason, I get better results by using the sqrt instead,
279       although there's no valid reason to. Must investigate further */
280    for (i=0;i<B*C*N;i++)
281       mask[i] = 1/(.1+mask[i]);
282
283    /* Pitch analysis */
284    for (c=0;c<C;c++)
285    {
286       for (i=0;i<N;i++)
287       {
288          in[C*i+c] *= st->window[i];
289          in[C*(B*N+i)+c] *= st->window[N+i];
290       }
291    }
292    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
293    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
294    
295    /* Compute MDCTs of the pitch part */
296    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
297    
298    /*int j;
299    for (j=0;j<B*N;j++)
300       printf ("%f ", X[j]);
301    for (j=0;j<B*N;j++)
302       printf ("%f ", P[j]);
303    printf ("\n");*/
304
305    /* Band normalisation */
306    compute_band_energies(st->mode, X, bandE);
307    normalise_bands(st->mode, X, bandE);
308    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
309    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
310
311    /* Normalise the pitch vector as well (discard the energies) */
312    {
313       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
314       compute_band_energies(st->mode, P, bandEp);
315       normalise_bands(st->mode, P, bandEp);
316    }
317
318    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
319
320    if (C==2)
321    {
322       stereo_mix(st->mode, X, bandE, 1);
323       stereo_mix(st->mode, P, bandE, 1);
324       //haar1(X, B*N*C, 1);
325       //haar1(P, B*N*C, 1);
326    }
327    /* Simulates intensity stereo */
328    //for (i=30;i<N*B;i++)
329    //   X[i*C+1] = P[i*C+1] = 0;
330    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
331    time_dct(X, N, B, C);
332    time_dct(P, N, B, C);
333
334
335    /* Pitch prediction */
336    compute_pitch_gain(st->mode, X, P, gains, bandE);
337    quant_pitch(gains, st->mode->nbPBands, &st->enc);
338    pitch_quant_bands(st->mode, X, P, gains);
339
340    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
341    /* Compute residual that we're going to encode */
342    for (i=0;i<B*C*N;i++)
343       X[i] -= P[i];
344
345    /*float sum=0;
346    for (i=0;i<B*N;i++)
347       sum += X[i]*X[i];
348    printf ("%f\n", sum);*/
349    /* Residual quantisation */
350    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
351    
352    time_idct(X, N, B, C);
353    if (C==2)
354       //haar1(X, B*N*C, 1);
355       stereo_mix(st->mode, X, bandE, -1);
356
357    renormalise_bands(st->mode, X);
358    /* Synthesis */
359    denormalise_bands(st->mode, X, bandE);
360
361
362    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
363
364    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
365    /* De-emphasis and put everything back at the right place in the synthesis history */
366    for (c=0;c<C;c++)
367    {
368       for (i=0;i<B;i++)
369       {
370          int j;
371          for (j=0;j<N;j++)
372          {
373             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
374             st->preemph_memD[c] = tmp;
375             if (tmp > 32767) tmp = 32767;
376             if (tmp < -32767) tmp = -32767;
377             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
378          }
379       }
380    }
381    
382    while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
383       ec_enc_uint(&st->enc, 1, 2);
384    ec_enc_done(&st->enc);
385    {
386       unsigned char *data;
387       int nbBytes = ec_byte_bytes(&st->buf);
388       //printf ("%d\n", *nbBytes);
389       data = ec_byte_get_buffer(&st->buf);
390       for (i=0;i<nbBytes;i++)
391          compressed[i] = data[i];
392       /* Fill the last byte with the right pattern so the decoder doesn't get confused
393          if the encoder didn't return enough bytes */
394       /* FIXME: This isn't quite what the decoder expects, but it's the best we can do for now */
395       if (nbBytes < nbCompressedBytes)
396       {
397          //fprintf (stderr, "smaller: %d\n", compressed[nbBytes-1]);
398          if (compressed[nbBytes-1] == 0x00)
399          {
400             compressed[i++] = 0x80;
401             //fprintf (stderr, "put 0x00\n");
402          } else if (compressed[nbBytes-1] == 0x80)
403          {
404             int k = nbBytes-1;
405             while (compressed[k-1] == 0x80)
406             {
407                k--;
408             }
409             if (compressed[k-1] == 0x00)
410             {
411                compressed[i++] = 0x80;
412                //fprintf (stderr, "special 0x00\n");
413             }
414          }
415          for (;i<nbCompressedBytes;i++)
416             compressed[i] = 0x00;
417       } else if (nbBytes < nbCompressedBytes)
418       {
419          //fprintf (stderr, "ERROR: too many bits\n");
420       }
421    }   
422    /* Reset the packing for the next encoding */
423    ec_byte_reset(&st->buf);
424    ec_enc_init(&st->enc,&st->buf);
425
426    return nbCompressedBytes;
427 }
428
429 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
430 {
431    char *data;
432    ec_enc_done(&st->enc);
433    *nbBytes = ec_byte_bytes(&st->buf);
434    data = ec_byte_get_buffer(&st->buf);
435    //printf ("%d\n", *nbBytes);
436    
437    /* Reset the packing for the next encoding */
438    ec_byte_reset(&st->buf);
439    ec_enc_init(&st->enc,&st->buf);
440
441    return data;
442 }
443
444
445 /****************************************************************************/
446 /*                                                                          */
447 /*                                DECODER                                   */
448 /*                                                                          */
449 /****************************************************************************/
450
451
452
453 struct CELTDecoder {
454    const CELTMode *mode;
455    int frame_size;
456    int block_size;
457    int nb_blocks;
458    int overlap;
459
460    ec_byte_buffer buf;
461    ec_enc         enc;
462
463    float preemph;
464    float *preemph_memD;
465    
466    mdct_lookup mdct_lookup;
467    
468    float *window;
469    float *mdct_overlap;
470    float *out_mem;
471
472    float *oldBandE;
473    
474    int last_pitch_index;
475    
476    struct alloc_data alloc;
477 };
478
479 CELTDecoder *celt_decoder_new(const CELTMode *mode)
480 {
481    int i, N, B, C, N4;
482    N = mode->mdctSize;
483    B = mode->nbMdctBlocks;
484    C = mode->nbChannels;
485    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
486    
487    st->mode = mode;
488    st->frame_size = B*N;
489    st->block_size = N;
490    st->nb_blocks  = B;
491    st->overlap = mode->overlap;
492
493    N4 = (N-st->overlap)/2;
494    
495    mdct_init(&st->mdct_lookup, 2*N);
496    
497    st->window = celt_alloc(2*N*sizeof(float));
498    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
499    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
500
501    for (i=0;i<2*N;i++)
502       st->window[i] = 0;
503    for (i=0;i<st->overlap;i++)
504       st->window[N4+i] = st->window[2*N-N4-i-1] 
505             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
506    for (i=0;i<2*N4;i++)
507       st->window[N-N4+i] = 1;
508    
509    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
510
511    st->preemph = 0.8;
512    st->preemph_memD = celt_alloc(C*sizeof(float));;
513
514    st->last_pitch_index = 0;
515    alloc_init(&st->alloc, st->mode);
516
517    return st;
518 }
519
520 void celt_decoder_destroy(CELTDecoder *st)
521 {
522    if (st == NULL)
523    {
524       celt_warning("NULL passed to celt_encoder_destroy");
525       return;
526    }
527
528    mdct_clear(&st->mdct_lookup);
529
530    celt_free(st->window);
531    celt_free(st->mdct_overlap);
532    celt_free(st->out_mem);
533    
534    celt_free(st->oldBandE);
535    alloc_clear(&st->alloc);
536
537    celt_free(st);
538 }
539
540 static void celt_decode_lost(CELTDecoder *st, short *pcm)
541 {
542    int i, c, N, B, C;
543    N = st->block_size;
544    B = st->nb_blocks;
545    C = st->mode->nbChannels;
546    float X[C*B*N];         /**< Interleaved signal MDCTs */
547    int pitch_index;
548    
549    pitch_index = st->last_pitch_index;
550    
551    /* Use the pitch MDCT as the "guessed" signal */
552    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
553
554    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
555    /* Compute inverse MDCTs */
556    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
557
558    for (c=0;c<C;c++)
559    {
560       for (i=0;i<B;i++)
561       {
562          int j;
563          for (j=0;j<N;j++)
564          {
565             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
566             st->preemph_memD[c] = tmp;
567             if (tmp > 32767) tmp = 32767;
568             if (tmp < -32767) tmp = -32767;
569             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
570          }
571       }
572    }
573 }
574
575 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
576 {
577    int i, c, N, B, C;
578    N = st->block_size;
579    B = st->nb_blocks;
580    C = st->mode->nbChannels;
581    
582    float X[C*B*N];         /**< Interleaved signal MDCTs */
583    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
584    float bandE[st->mode->nbEBands*C];
585    float gains[st->mode->nbPBands];
586    int pitch_index;
587    ec_dec dec;
588    ec_byte_buffer buf;
589    
590    if (data == NULL)
591    {
592       celt_decode_lost(st, pcm);
593       return 0;
594    }
595    
596    ec_byte_readinit(&buf,data,len);
597    ec_dec_init(&dec,&buf);
598    
599    /* Get the pitch index */
600    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
601    st->last_pitch_index = pitch_index;
602    
603    /* Get band energies */
604    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
605    
606    /* Pitch MDCT */
607    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
608
609    {
610       float bandEp[st->mode->nbEBands];
611       compute_band_energies(st->mode, P, bandEp);
612       normalise_bands(st->mode, P, bandEp);
613    }
614
615    if (C==2)
616       //haar1(P, B*N*C, 1);
617       stereo_mix(st->mode, P, bandE, 1);
618    time_dct(P, N, B, C);
619
620    /* Get the pitch gains */
621    unquant_pitch(gains, st->mode->nbPBands, &dec);
622
623    /* Apply pitch gains */
624    pitch_quant_bands(st->mode, X, P, gains);
625
626    /* Decode fixed codebook and merge with pitch */
627    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
628
629    time_idct(X, N, B, C);
630    if (C==2)
631       //haar1(X, B*N*C, 1);
632       stereo_mix(st->mode, X, bandE, -1);
633
634    renormalise_bands(st->mode, X);
635    
636    /* Synthesis */
637    denormalise_bands(st->mode, X, bandE);
638
639
640    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
641    /* Compute inverse MDCTs */
642    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
643
644    for (c=0;c<C;c++)
645    {
646       for (i=0;i<B;i++)
647       {
648          int j;
649          for (j=0;j<N;j++)
650          {
651             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
652             st->preemph_memD[c] = tmp;
653             if (tmp > 32767) tmp = 32767;
654             if (tmp < -32767) tmp = -32767;
655             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
656          }
657       }
658    }
659    
660    return 0;
661    //printf ("\n");
662 }
663