Added a DCT in time direction when multiple MDCTs are used within the same
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44
45 #define MAX_PERIOD 1024
46
47
48 struct CELTEncoder {
49    const CELTMode *mode;
50    int frame_size;
51    int block_size;
52    int nb_blocks;
53    int channels;
54    int Fs;
55    
56    ec_byte_buffer buf;
57    ec_enc         enc;
58
59    float preemph;
60    float *preemph_memE;
61    float *preemph_memD;
62    
63    mdct_lookup mdct_lookup;
64    void *fft;
65    
66    float *window;
67    float *in_mem;
68    float *mdct_overlap;
69    float *out_mem;
70
71    float *oldBandE;
72 };
73
74
75
76 CELTEncoder *celt_encoder_new(const CELTMode *mode)
77 {
78    int i, N, B, C;
79    N = mode->mdctSize;
80    B = mode->nbMdctBlocks;
81    C = mode->nbChannels;
82    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
83    
84    st->mode = mode;
85    st->frame_size = B*N;
86    st->block_size = N;
87    st->nb_blocks  = B;
88    st->Fs = 44100;
89    
90    ec_byte_writeinit(&st->buf);
91    ec_enc_init(&st->enc,&st->buf);
92
93    mdct_init(&st->mdct_lookup, 2*N);
94    st->fft = spx_fft_init(MAX_PERIOD*C);
95    
96    st->window = celt_alloc(2*N*sizeof(float));
97    st->in_mem = celt_alloc(N*C*sizeof(float));
98    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
99    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
100    for (i=0;i<N;i++)
101       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
102    
103    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
104
105    st->preemph = 0.8;
106    st->preemph_memE = celt_alloc(C*sizeof(float));;
107    st->preemph_memD = celt_alloc(C*sizeof(float));;
108
109    return st;
110 }
111
112 void celt_encoder_destroy(CELTEncoder *st)
113 {
114    if (st == NULL)
115    {
116       celt_warning("NULL passed to celt_encoder_destroy");
117       return;
118    }
119    ec_byte_writeclear(&st->buf);
120
121    mdct_clear(&st->mdct_lookup);
122    spx_fft_destroy(st->fft);
123
124    celt_free(st->window);
125    celt_free(st->in_mem);
126    celt_free(st->mdct_overlap);
127    celt_free(st->out_mem);
128    
129    celt_free(st->oldBandE);
130    celt_free(st);
131 }
132
133 static void haar1(float *X, int N, int stride)
134 {
135    int i, k;
136    for (k=0;k<stride;k++)
137    {
138       for (i=k;i<N*stride;i+=2*stride)
139       {
140          float a, b;
141          a = X[i];
142          b = X[i+stride];
143          X[i] = .707107f*(a+b);
144          X[i+stride] = .707107f*(a-b);
145       }
146    }
147 }
148
149 static void time_dct(float *X, int N, int B, int stride)
150 {
151    switch (B)
152    {
153       case 1:
154          break;
155       case 2:
156          haar1(X, B*N, stride);
157          break;
158       default:
159          celt_warning("time_dct not defined for B > 2");
160    };
161 }
162
163 static void time_idct(float *X, int N, int B, int stride)
164 {
165    switch (B)
166    {
167       case 1:
168          break;
169       case 2:
170          haar1(X, B*N, stride);
171          break;
172       default:
173          celt_warning("time_dct not defined for B > 2");
174    };
175 }
176
177 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
178 {
179    int i, c;
180    for (c=0;c<C;c++)
181    {
182       for (i=0;i<B;i++)
183       {
184          int j;
185          float x[2*N];
186          float tmp[N];
187          for (j=0;j<2*N;j++)
188             x[j] = window[j]*in[C*i*N+C*j+c];
189          mdct_forward(mdct_lookup, x, tmp);
190          /* Interleaving the sub-frames */
191          for (j=0;j<N;j++)
192             out[C*B*j+C*i+c] = tmp[j];
193       }
194    }
195 }
196
197 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int B, int C)
198 {
199    int i, c;
200    for (c=0;c<C;c++)
201    {
202       for (i=0;i<B;i++)
203       {
204          int j;
205          float x[2*N];
206          float tmp[N];
207          /* De-interleaving the sub-frames */
208          for (j=0;j<N;j++)
209             tmp[j] = X[C*B*j+C*i+c];
210          mdct_backward(mdct_lookup, tmp, x);
211          for (j=0;j<2*N;j++)
212             x[j] = window[j]*x[j];
213          for (j=0;j<N;j++)
214             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[j]+mdct_overlap[C*j+c];
215          for (j=0;j<N;j++)
216             mdct_overlap[C*j+c] = x[N+j];
217       }
218    }
219 }
220
221 int celt_encode(CELTEncoder *st, short *pcm)
222 {
223    int i, c, N, B, C;
224    N = st->block_size;
225    B = st->nb_blocks;
226    C = st->mode->nbChannels;
227    float in[(B+1)*C*N];
228    
229    float X[B*C*N];         /**< Interleaved signal MDCTs */
230    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
231    float mask[B*C*N];      /**< Masking curve */
232    float bandE[st->mode->nbEBands];
233    float gains[st->mode->nbPBands];
234    int pitch_index;
235    
236    for (c=0;c<C;c++)
237    {
238       for (i=0;i<N;i++)
239          in[C*i+c] = st->in_mem[C*i+c];
240       for (;i<(B+1)*N;i++)
241       {
242          float tmp = pcm[C*(i-N)+c];
243          in[C*i+c] = tmp - st->preemph*st->preemph_memE[c];
244          st->preemph_memE[c] = tmp;
245       }
246       for (i=0;i<N;i++)
247          st->in_mem[C*i+c] = in[C*(B*N+i)+c];
248    }
249    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
250    /* Compute MDCTs */
251    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
252    
253    compute_masking(X, mask, B*C*N, st->Fs);
254    /* Invert and stretch the mask to length of X */
255    for (i=B*C*N-1;i>=0;i--)
256       mask[i] = 1/(1+mask[i>>1]);
257    
258    /* Pitch analysis */
259    for (c=0;c<C;c++)
260    {
261       for (i=0;i<N;i++)
262       {
263          in[C*i+c] *= st->window[i];
264          in[C*(B*N+i)+c] *= st->window[N+i];
265       }
266    }
267    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
268    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
269    
270    /* Compute MDCTs of the pitch part */
271    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
272    
273    /*int j;
274    for (j=0;j<B*N;j++)
275       printf ("%f ", X[j]);
276    for (j=0;j<B*N;j++)
277       printf ("%f ", P[j]);
278    printf ("\n");*/
279    if (C==2)
280    {
281       haar1(X, B*N*C, 1);
282       haar1(P, B*N*C, 1);
283    }
284    
285    time_dct(X, N, B, C);
286    time_dct(P, N, B, C);
287
288    /* Band normalisation */
289    compute_band_energies(st->mode, X, bandE);
290    normalise_bands(st->mode, X, bandE);
291    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
292    
293    {
294       float bandEp[st->mode->nbEBands];
295       compute_band_energies(st->mode, P, bandEp);
296       normalise_bands(st->mode, P, bandEp);
297    }
298    
299    band_rotation(st->mode, X, -1);
300    band_rotation(st->mode, P, -1);
301    
302    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
303    
304    /* Pitch prediction */
305    compute_pitch_gain(st->mode, X, P, gains, bandE);
306    quant_pitch(gains, st->mode->nbPBands, &st->enc);
307    pitch_quant_bands(st->mode, X, P, gains);
308
309    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
310    /* Subtract the pitch prediction from the signal to encode */
311    for (i=0;i<B*C*N;i++)
312       X[i] -= P[i];
313
314    /*float sum=0;
315    for (i=0;i<B*N;i++)
316       sum += X[i]*X[i];
317    printf ("%f\n", sum);*/
318    /* Residual quantisation */
319    quant_bands(st->mode, X, P, mask, &st->enc);
320    
321    if (0) {//This is just for debugging
322       ec_enc_done(&st->enc);
323       ec_dec dec;
324       ec_byte_readinit(&st->buf,ec_byte_get_buffer(&st->buf),ec_byte_bytes(&st->buf));
325       ec_dec_init(&dec,&st->buf);
326
327       unquant_bands(st->mode, X, P, &dec);
328       //printf ("\n");
329    }
330    
331    band_rotation(st->mode, X, 1);
332
333    /* Synthesis */
334    denormalise_bands(st->mode, X, bandE);
335
336    time_idct(X, N, B, C);
337    if (C==2)
338       haar1(X, B*N*C, 1);
339    
340    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
341    /* Compute inverse MDCTs */
342    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
343
344    for (c=0;c<C;c++)
345    {
346       for (i=0;i<B;i++)
347       {
348          int j;
349          for (j=0;j<N;j++)
350          {
351             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
352             st->preemph_memD[c] = tmp;
353             if (tmp > 32767) tmp = 32767;
354             if (tmp < -32767) tmp = -32767;
355             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
356          }
357       }
358    }
359    return 0;
360 }
361
362 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
363 {
364    char *data;
365    ec_enc_done(&st->enc);
366    *nbBytes = ec_byte_bytes(&st->buf);
367    data = ec_byte_get_buffer(&st->buf);
368    //printf ("%d\n", *nbBytes);
369    
370    /* Reset the packing for the next encoding */
371    ec_byte_reset(&st->buf);
372    ec_enc_init(&st->enc,&st->buf);
373
374    return data;
375 }
376
377
378 /****************************************************************************/
379 /*                                                                          */
380 /*                                DECODER                                   */
381 /*                                                                          */
382 /****************************************************************************/
383
384
385
386 struct CELTDecoder {
387    const CELTMode *mode;
388    int frame_size;
389    int block_size;
390    int nb_blocks;
391    
392    ec_byte_buffer buf;
393    ec_enc         enc;
394
395    float preemph;
396    float *preemph_memD;
397    
398    mdct_lookup mdct_lookup;
399    
400    float *window;
401    float *mdct_overlap;
402    float *out_mem;
403
404    float *oldBandE;
405    
406    int last_pitch_index;
407 };
408
409 CELTDecoder *celt_decoder_new(const CELTMode *mode)
410 {
411    int i, N, B, C;
412    N = mode->mdctSize;
413    B = mode->nbMdctBlocks;
414    C = mode->nbChannels;
415    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
416    
417    st->mode = mode;
418    st->frame_size = B*N;
419    st->block_size = N;
420    st->nb_blocks  = B;
421    
422    mdct_init(&st->mdct_lookup, 2*N);
423    
424    st->window = celt_alloc(2*N*sizeof(float));
425    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
426    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
427    for (i=0;i<N;i++)
428       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
429    
430    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
431
432    st->preemph = 0.8;
433    st->preemph_memD = celt_alloc(C*sizeof(float));;
434
435    st->last_pitch_index = 0;
436    return st;
437 }
438
439 void celt_decoder_destroy(CELTDecoder *st)
440 {
441    if (st == NULL)
442    {
443       celt_warning("NULL passed to celt_encoder_destroy");
444       return;
445    }
446
447    mdct_clear(&st->mdct_lookup);
448
449    celt_free(st->window);
450    celt_free(st->mdct_overlap);
451    celt_free(st->out_mem);
452    
453    celt_free(st->oldBandE);
454    celt_free(st);
455 }
456
457 static void celt_decode_lost(CELTDecoder *st, short *pcm)
458 {
459    int i, c, N, B, C;
460    N = st->block_size;
461    B = st->nb_blocks;
462    C = st->mode->nbChannels;
463    float X[C*B*N];         /**< Interleaved signal MDCTs */
464    int pitch_index;
465    
466    pitch_index = st->last_pitch_index;
467    
468    /* Use the pitch MDCT as the "guessed" signal */
469    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
470
471    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
472    /* Compute inverse MDCTs */
473    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
474
475    for (c=0;c<C;c++)
476    {
477       for (i=0;i<B;i++)
478       {
479          int j;
480          for (j=0;j<N;j++)
481          {
482             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
483             st->preemph_memD[c] = tmp;
484             if (tmp > 32767) tmp = 32767;
485             if (tmp < -32767) tmp = -32767;
486             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
487          }
488       }
489    }
490 }
491
492 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
493 {
494    int i, c, N, B, C;
495    N = st->block_size;
496    B = st->nb_blocks;
497    C = st->mode->nbChannels;
498    
499    float X[C*B*N];         /**< Interleaved signal MDCTs */
500    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
501    float bandE[st->mode->nbEBands];
502    float gains[st->mode->nbPBands];
503    int pitch_index;
504    ec_dec dec;
505    ec_byte_buffer buf;
506    
507    if (data == NULL)
508    {
509       celt_decode_lost(st, pcm);
510       return 0;
511    }
512    
513    ec_byte_readinit(&buf,data,len);
514    ec_dec_init(&dec,&buf);
515    
516    /* Get the pitch index */
517    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
518    st->last_pitch_index = pitch_index;
519    
520    /* Get band energies */
521    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
522    
523    /* Pitch MDCT */
524    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
525
526    if (C==2)
527       haar1(P, B*N*C, 1);
528    time_dct(P, N, B, C);
529
530    {
531       float bandEp[st->mode->nbEBands];
532       compute_band_energies(st->mode, P, bandEp);
533       normalise_bands(st->mode, P, bandEp);
534    }
535    band_rotation(st->mode, P, -1);
536
537    /* Get the pitch gains */
538    unquant_pitch(gains, st->mode->nbPBands, &dec);
539
540    /* Apply pitch gains */
541    pitch_quant_bands(st->mode, X, P, gains);
542
543    /* Decode fixed codebook and merge with pitch */
544    unquant_bands(st->mode, X, P, &dec);
545
546    band_rotation(st->mode, X, 1);
547    
548    /* Synthesis */
549    denormalise_bands(st->mode, X, bandE);
550
551    time_idct(X, N, B, C);
552    if (C==2)
553       haar1(X, B*N*C, 1);
554
555    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
556    /* Compute inverse MDCTs */
557    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
558
559    for (c=0;c<C;c++)
560    {
561       for (i=0;i<B;i++)
562       {
563          int j;
564          for (j=0;j<N;j++)
565          {
566             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
567             st->preemph_memD[c] = tmp;
568             if (tmp > 32767) tmp = 32767;
569             if (tmp < -32767) tmp = -32767;
570             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
571          }
572       }
573    }
574    return 0;
575    //printf ("\n");
576 }
577