e887428a6213788dac0e1e67a49ddfdab7014cf1
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44
45 #define MAX_PERIOD 1024
46
47
48 struct CELTEncoder {
49    const CELTMode *mode;
50    int frame_size;
51    int block_size;
52    int nb_blocks;
53    int overlap;
54    int channels;
55    int Fs;
56    
57    ec_byte_buffer buf;
58    ec_enc         enc;
59
60    float preemph;
61    float *preemph_memE;
62    float *preemph_memD;
63    
64    mdct_lookup mdct_lookup;
65    void *fft;
66    
67    float *window;
68    float *in_mem;
69    float *mdct_overlap;
70    float *out_mem;
71
72    float *oldBandE;
73 };
74
75
76
77 CELTEncoder *celt_encoder_new(const CELTMode *mode)
78 {
79    int i, N, B, C, N4;
80    N = mode->mdctSize;
81    B = mode->nbMdctBlocks;
82    C = mode->nbChannels;
83    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
84    
85    st->mode = mode;
86    st->frame_size = B*N;
87    st->block_size = N;
88    st->nb_blocks  = B;
89    st->overlap = mode->overlap;
90    st->Fs = 44100;
91
92    N4 = (N-st->overlap)/2;
93    ec_byte_writeinit(&st->buf);
94    ec_enc_init(&st->enc,&st->buf);
95
96    mdct_init(&st->mdct_lookup, 2*N);
97    st->fft = spx_fft_init(MAX_PERIOD*C);
98    
99    st->window = celt_alloc(2*N*sizeof(float));
100    st->in_mem = celt_alloc(N*C*sizeof(float));
101    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
102    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
103    for (i=0;i<2*N;i++)
104       st->window[i] = 0;
105    for (i=0;i<st->overlap;i++)
106       st->window[N4+i] = st->window[2*N-N4-i-1] 
107             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
108    for (i=0;i<2*N4;i++)
109       st->window[N-N4+i] = 1;
110    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
111
112    st->preemph = 0.8;
113    st->preemph_memE = celt_alloc(C*sizeof(float));;
114    st->preemph_memD = celt_alloc(C*sizeof(float));;
115
116    return st;
117 }
118
119 void celt_encoder_destroy(CELTEncoder *st)
120 {
121    if (st == NULL)
122    {
123       celt_warning("NULL passed to celt_encoder_destroy");
124       return;
125    }
126    ec_byte_writeclear(&st->buf);
127
128    mdct_clear(&st->mdct_lookup);
129    spx_fft_destroy(st->fft);
130
131    celt_free(st->window);
132    celt_free(st->in_mem);
133    celt_free(st->mdct_overlap);
134    celt_free(st->out_mem);
135    
136    celt_free(st->oldBandE);
137    celt_free(st);
138 }
139
140 static void haar1(float *X, int N, int stride)
141 {
142    int i, k;
143    for (k=0;k<stride;k++)
144    {
145       for (i=k;i<N*stride;i+=2*stride)
146       {
147          float a, b;
148          a = X[i];
149          b = X[i+stride];
150          X[i] = .707107f*(a+b);
151          X[i+stride] = .707107f*(a-b);
152       }
153    }
154 }
155
156 static void time_dct(float *X, int N, int B, int stride)
157 {
158    switch (B)
159    {
160       case 1:
161          break;
162       case 2:
163          haar1(X, B*N, stride);
164          break;
165       default:
166          celt_warning("time_dct not defined for B > 2");
167    };
168 }
169
170 static void time_idct(float *X, int N, int B, int stride)
171 {
172    switch (B)
173    {
174       case 1:
175          break;
176       case 2:
177          haar1(X, B*N, stride);
178          break;
179       default:
180          celt_warning("time_dct not defined for B > 2");
181    };
182 }
183
184 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
185 {
186    int i, c;
187    for (c=0;c<C;c++)
188    {
189       for (i=0;i<B;i++)
190       {
191          int j;
192          float x[2*N];
193          float tmp[N];
194          for (j=0;j<2*N;j++)
195             x[j] = window[j]*in[C*i*N+C*j+c];
196          mdct_forward(mdct_lookup, x, tmp);
197          /* Interleaving the sub-frames */
198          for (j=0;j<N;j++)
199             out[C*B*j+C*i+c] = tmp[j];
200       }
201    }
202 }
203
204 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
205 {
206    int i, c, N4;
207    N4 = (N-overlap)/2;
208    for (c=0;c<C;c++)
209    {
210       for (i=0;i<B;i++)
211       {
212          int j;
213          float x[2*N];
214          float tmp[N];
215          /* De-interleaving the sub-frames */
216          for (j=0;j<N;j++)
217             tmp[j] = X[C*B*j+C*i+c];
218          mdct_backward(mdct_lookup, tmp, x);
219          for (j=0;j<2*N;j++)
220             x[j] = window[j]*x[j];
221          for (j=0;j<overlap;j++)
222             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
223          for (j=0;j<2*N4;j++)
224             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
225          for (j=0;j<overlap;j++)
226             mdct_overlap[C*j+c] = x[N+N4+j];
227       }
228    }
229 }
230
231 int celt_encode(CELTEncoder *st, short *pcm)
232 {
233    int i, c, N, B, C, N4;
234    N = st->block_size;
235    B = st->nb_blocks;
236    C = st->mode->nbChannels;
237    float in[(B+1)*C*N];
238
239    float X[B*C*N];         /**< Interleaved signal MDCTs */
240    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
241    float mask[B*C*N];      /**< Masking curve */
242    float bandE[st->mode->nbEBands*C];
243    float gains[st->mode->nbPBands];
244    int pitch_index;
245
246    N4 = (N-st->overlap)/2;
247
248    for (c=0;c<C;c++)
249    {
250       for (i=0;i<N4;i++)
251          in[C*i+c] = 0;
252       for (i=0;i<st->overlap;i++)
253          in[C*(i+N4)+c] = st->in_mem[C*i+c];
254       for (i=0;i<B*N;i++)
255       {
256          float tmp = pcm[C*i+c];
257          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
258          st->preemph_memE[c] = tmp;
259       }
260       for (i=N*(B+1)-N4;i<N*(B+1);i++)
261          in[C*i+c] = 0;
262       for (i=0;i<st->overlap;i++)
263          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
264    }
265    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
266    /* Compute MDCTs */
267    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
268
269    compute_mdct_masking(X, mask, B*C*N, st->Fs);
270
271    /* Invert and stretch the mask to length of X 
272       For some reason, I get better results by using the sqrt instead,
273       although there's no valid reason to. Must investigate further */
274    for (i=0;i<B*C*N;i++)
275       mask[i] = 1/(.1+mask[i]);
276
277    /* Pitch analysis */
278    for (c=0;c<C;c++)
279    {
280       for (i=0;i<N;i++)
281       {
282          in[C*i+c] *= st->window[i];
283          in[C*(B*N+i)+c] *= st->window[N+i];
284       }
285    }
286    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
287    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
288    
289    /* Compute MDCTs of the pitch part */
290    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
291    
292    /*int j;
293    for (j=0;j<B*N;j++)
294       printf ("%f ", X[j]);
295    for (j=0;j<B*N;j++)
296       printf ("%f ", P[j]);
297    printf ("\n");*/
298
299    /* Band normalisation */
300    compute_band_energies(st->mode, X, bandE);
301    normalise_bands(st->mode, X, bandE);
302    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
303    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
304
305    /* Normalise the pitch vector as well (discard the energies) */
306    {
307       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
308       compute_band_energies(st->mode, P, bandEp);
309       normalise_bands(st->mode, P, bandEp);
310    }
311
312    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
313
314    if (C==2)
315    {
316       stereo_mix(st->mode, X, bandE, 1);
317       stereo_mix(st->mode, P, bandE, 1);
318       //haar1(X, B*N*C, 1);
319       //haar1(P, B*N*C, 1);
320    }
321    /* Simulates intensity stereo */
322    //for (i=30;i<N*B;i++)
323    //   X[i*C+1] = P[i*C+1] = 0;
324    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
325    time_dct(X, N, B, C);
326    time_dct(P, N, B, C);
327
328
329    /* Pitch prediction */
330    compute_pitch_gain(st->mode, X, P, gains, bandE);
331    quant_pitch(gains, st->mode->nbPBands, &st->enc);
332    pitch_quant_bands(st->mode, X, P, gains);
333
334    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
335    /* Compute residual that we're going to encode */
336    for (i=0;i<B*C*N;i++)
337       X[i] -= P[i];
338
339    /*float sum=0;
340    for (i=0;i<B*N;i++)
341       sum += X[i]*X[i];
342    printf ("%f\n", sum);*/
343    /* Residual quantisation */
344    quant_bands(st->mode, X, P, mask, &st->enc);
345    
346    time_idct(X, N, B, C);
347    if (C==2)
348       //haar1(X, B*N*C, 1);
349       stereo_mix(st->mode, X, bandE, -1);
350
351    renormalise_bands(st->mode, X);
352    /* Synthesis */
353    denormalise_bands(st->mode, X, bandE);
354
355
356    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
357
358    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
359    /* De-emphasis and put everything back at the right place in the synthesis history */
360    for (c=0;c<C;c++)
361    {
362       for (i=0;i<B;i++)
363       {
364          int j;
365          for (j=0;j<N;j++)
366          {
367             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
368             st->preemph_memD[c] = tmp;
369             if (tmp > 32767) tmp = 32767;
370             if (tmp < -32767) tmp = -32767;
371             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
372          }
373       }
374    }
375    return 0;
376 }
377
378 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
379 {
380    char *data;
381    ec_enc_done(&st->enc);
382    *nbBytes = ec_byte_bytes(&st->buf);
383    data = ec_byte_get_buffer(&st->buf);
384    //printf ("%d\n", *nbBytes);
385    
386    /* Reset the packing for the next encoding */
387    ec_byte_reset(&st->buf);
388    ec_enc_init(&st->enc,&st->buf);
389
390    return data;
391 }
392
393
394 /****************************************************************************/
395 /*                                                                          */
396 /*                                DECODER                                   */
397 /*                                                                          */
398 /****************************************************************************/
399
400
401
402 struct CELTDecoder {
403    const CELTMode *mode;
404    int frame_size;
405    int block_size;
406    int nb_blocks;
407    int overlap;
408
409    ec_byte_buffer buf;
410    ec_enc         enc;
411
412    float preemph;
413    float *preemph_memD;
414    
415    mdct_lookup mdct_lookup;
416    
417    float *window;
418    float *mdct_overlap;
419    float *out_mem;
420
421    float *oldBandE;
422    
423    int last_pitch_index;
424 };
425
426 CELTDecoder *celt_decoder_new(const CELTMode *mode)
427 {
428    int i, N, B, C, N4;
429    N = mode->mdctSize;
430    B = mode->nbMdctBlocks;
431    C = mode->nbChannels;
432    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
433    
434    st->mode = mode;
435    st->frame_size = B*N;
436    st->block_size = N;
437    st->nb_blocks  = B;
438    st->overlap = mode->overlap;
439
440    N4 = (N-st->overlap)/2;
441    
442    mdct_init(&st->mdct_lookup, 2*N);
443    
444    st->window = celt_alloc(2*N*sizeof(float));
445    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
446    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
447
448    for (i=0;i<2*N;i++)
449       st->window[i] = 0;
450    for (i=0;i<st->overlap;i++)
451       st->window[N4+i] = st->window[2*N-N4-i-1] 
452             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
453    for (i=0;i<2*N4;i++)
454       st->window[N-N4+i] = 1;
455    
456    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
457
458    st->preemph = 0.8;
459    st->preemph_memD = celt_alloc(C*sizeof(float));;
460
461    st->last_pitch_index = 0;
462    return st;
463 }
464
465 void celt_decoder_destroy(CELTDecoder *st)
466 {
467    if (st == NULL)
468    {
469       celt_warning("NULL passed to celt_encoder_destroy");
470       return;
471    }
472
473    mdct_clear(&st->mdct_lookup);
474
475    celt_free(st->window);
476    celt_free(st->mdct_overlap);
477    celt_free(st->out_mem);
478    
479    celt_free(st->oldBandE);
480    celt_free(st);
481 }
482
483 static void celt_decode_lost(CELTDecoder *st, short *pcm)
484 {
485    int i, c, N, B, C;
486    N = st->block_size;
487    B = st->nb_blocks;
488    C = st->mode->nbChannels;
489    float X[C*B*N];         /**< Interleaved signal MDCTs */
490    int pitch_index;
491    
492    pitch_index = st->last_pitch_index;
493    
494    /* Use the pitch MDCT as the "guessed" signal */
495    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
496
497    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
498    /* Compute inverse MDCTs */
499    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
500
501    for (c=0;c<C;c++)
502    {
503       for (i=0;i<B;i++)
504       {
505          int j;
506          for (j=0;j<N;j++)
507          {
508             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
509             st->preemph_memD[c] = tmp;
510             if (tmp > 32767) tmp = 32767;
511             if (tmp < -32767) tmp = -32767;
512             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
513          }
514       }
515    }
516 }
517
518 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
519 {
520    int i, c, N, B, C;
521    N = st->block_size;
522    B = st->nb_blocks;
523    C = st->mode->nbChannels;
524    
525    float X[C*B*N];         /**< Interleaved signal MDCTs */
526    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
527    float bandE[st->mode->nbEBands*C];
528    float gains[st->mode->nbPBands];
529    int pitch_index;
530    ec_dec dec;
531    ec_byte_buffer buf;
532    
533    if (data == NULL)
534    {
535       celt_decode_lost(st, pcm);
536       return 0;
537    }
538    
539    ec_byte_readinit(&buf,data,len);
540    ec_dec_init(&dec,&buf);
541    
542    /* Get the pitch index */
543    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
544    st->last_pitch_index = pitch_index;
545    
546    /* Get band energies */
547    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
548    
549    /* Pitch MDCT */
550    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
551
552    {
553       float bandEp[st->mode->nbEBands];
554       compute_band_energies(st->mode, P, bandEp);
555       normalise_bands(st->mode, P, bandEp);
556    }
557
558    if (C==2)
559       //haar1(P, B*N*C, 1);
560       stereo_mix(st->mode, P, bandE, 1);
561    time_dct(P, N, B, C);
562
563    /* Get the pitch gains */
564    unquant_pitch(gains, st->mode->nbPBands, &dec);
565
566    /* Apply pitch gains */
567    pitch_quant_bands(st->mode, X, P, gains);
568
569    /* Decode fixed codebook and merge with pitch */
570    unquant_bands(st->mode, X, P, &dec);
571
572    time_idct(X, N, B, C);
573    if (C==2)
574       //haar1(X, B*N*C, 1);
575       stereo_mix(st->mode, X, bandE, -1);
576
577    renormalise_bands(st->mode, X);
578    
579    /* Synthesis */
580    denormalise_bands(st->mode, X, bandE);
581
582
583    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
584    /* Compute inverse MDCTs */
585    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
586
587    for (c=0;c<C;c++)
588    {
589       for (i=0;i<B;i++)
590       {
591          int j;
592          for (j=0;j<N;j++)
593          {
594             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
595             st->preemph_memD[c] = tmp;
596             if (tmp > 32767) tmp = 32767;
597             if (tmp < -32767) tmp = -32767;
598             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
599          }
600       }
601    }
602    return 0;
603    //printf ("\n");
604 }
605