Stereo decoding working again (fixed a few issues in the encoder at the same
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44
45 #define MAX_PERIOD 1024
46
47
48 struct CELTEncoder {
49    const CELTMode *mode;
50    int frame_size;
51    int block_size;
52    int nb_blocks;
53    int overlap;
54    int channels;
55    int Fs;
56    
57    ec_byte_buffer buf;
58    ec_enc         enc;
59
60    float preemph;
61    float *preemph_memE;
62    float *preemph_memD;
63    
64    mdct_lookup mdct_lookup;
65    void *fft;
66    
67    float *window;
68    float *in_mem;
69    float *mdct_overlap;
70    float *out_mem;
71
72    float *oldBandE;
73 };
74
75
76
77 CELTEncoder *celt_encoder_new(const CELTMode *mode)
78 {
79    int i, N, B, C, N4;
80    N = mode->mdctSize;
81    B = mode->nbMdctBlocks;
82    C = mode->nbChannels;
83    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
84    
85    st->mode = mode;
86    st->frame_size = B*N;
87    st->block_size = N;
88    st->nb_blocks  = B;
89    st->overlap = mode->overlap;
90    st->Fs = 44100;
91
92    N4 = (N-st->overlap)/2;
93    ec_byte_writeinit(&st->buf);
94    ec_enc_init(&st->enc,&st->buf);
95
96    mdct_init(&st->mdct_lookup, 2*N);
97    st->fft = spx_fft_init(MAX_PERIOD*C);
98    
99    st->window = celt_alloc(2*N*sizeof(float));
100    st->in_mem = celt_alloc(N*C*sizeof(float));
101    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
102    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
103    for (i=0;i<2*N;i++)
104       st->window[i] = 0;
105    for (i=0;i<st->overlap;i++)
106       st->window[N4+i] = st->window[2*N-N4-i-1] 
107             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
108    for (i=0;i<2*N4;i++)
109       st->window[N-N4+i] = 1;
110    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
111
112    st->preemph = 0.8;
113    st->preemph_memE = celt_alloc(C*sizeof(float));;
114    st->preemph_memD = celt_alloc(C*sizeof(float));;
115
116    return st;
117 }
118
119 void celt_encoder_destroy(CELTEncoder *st)
120 {
121    if (st == NULL)
122    {
123       celt_warning("NULL passed to celt_encoder_destroy");
124       return;
125    }
126    ec_byte_writeclear(&st->buf);
127
128    mdct_clear(&st->mdct_lookup);
129    spx_fft_destroy(st->fft);
130
131    celt_free(st->window);
132    celt_free(st->in_mem);
133    celt_free(st->mdct_overlap);
134    celt_free(st->out_mem);
135    
136    celt_free(st->oldBandE);
137    celt_free(st);
138 }
139
140 static void haar1(float *X, int N, int stride)
141 {
142    int i, k;
143    for (k=0;k<stride;k++)
144    {
145       for (i=k;i<N*stride;i+=2*stride)
146       {
147          float a, b;
148          a = X[i];
149          b = X[i+stride];
150          X[i] = .707107f*(a+b);
151          X[i+stride] = .707107f*(a-b);
152       }
153    }
154 }
155
156 static void time_dct(float *X, int N, int B, int stride)
157 {
158    switch (B)
159    {
160       case 1:
161          break;
162       case 2:
163          haar1(X, B*N, stride);
164          break;
165       default:
166          celt_warning("time_dct not defined for B > 2");
167    };
168 }
169
170 static void time_idct(float *X, int N, int B, int stride)
171 {
172    switch (B)
173    {
174       case 1:
175          break;
176       case 2:
177          haar1(X, B*N, stride);
178          break;
179       default:
180          celt_warning("time_dct not defined for B > 2");
181    };
182 }
183
184 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
185 {
186    int i, c;
187    for (c=0;c<C;c++)
188    {
189       for (i=0;i<B;i++)
190       {
191          int j;
192          float x[2*N];
193          float tmp[N];
194          for (j=0;j<2*N;j++)
195             x[j] = window[j]*in[C*i*N+C*j+c];
196          mdct_forward(mdct_lookup, x, tmp);
197          /* Interleaving the sub-frames */
198          for (j=0;j<N;j++)
199             out[C*B*j+C*i+c] = tmp[j];
200       }
201    }
202 }
203
204 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
205 {
206    int i, c, N4;
207    N4 = (N-overlap)/2;
208    for (c=0;c<C;c++)
209    {
210       for (i=0;i<B;i++)
211       {
212          int j;
213          float x[2*N];
214          float tmp[N];
215          /* De-interleaving the sub-frames */
216          for (j=0;j<N;j++)
217             tmp[j] = X[C*B*j+C*i+c];
218          mdct_backward(mdct_lookup, tmp, x);
219          for (j=0;j<2*N;j++)
220             x[j] = window[j]*x[j];
221          for (j=0;j<overlap;j++)
222             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
223          for (j=0;j<2*N4;j++)
224             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
225          for (j=0;j<overlap;j++)
226             mdct_overlap[C*j+c] = x[N+N4+j];
227       }
228    }
229 }
230
231 int celt_encode(CELTEncoder *st, short *pcm)
232 {
233    int i, c, N, B, C, N4;
234    N = st->block_size;
235    B = st->nb_blocks;
236    C = st->mode->nbChannels;
237    float in[(B+1)*C*N];
238
239    float X[B*C*N];         /**< Interleaved signal MDCTs */
240    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
241    float mask[B*C*N];      /**< Masking curve */
242    float bandE[st->mode->nbEBands*C];
243    float gains[st->mode->nbPBands];
244    int pitch_index;
245
246    N4 = (N-st->overlap)/2;
247
248    for (c=0;c<C;c++)
249    {
250       for (i=0;i<N4;i++)
251          in[C*i+c] = 0;
252       for (i=0;i<st->overlap;i++)
253          in[C*(i+N4)+c] = st->in_mem[C*i+c];
254       for (i=0;i<B*N;i++)
255       {
256          float tmp = pcm[C*i+c];
257          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
258          st->preemph_memE[c] = tmp;
259       }
260       for (i=N*(B+1)-N4;i<N*(B+1);i++)
261          in[C*i+c] = 0;
262       for (i=0;i<st->overlap;i++)
263          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
264    }
265    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
266    /* Compute MDCTs */
267    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
268
269    compute_mdct_masking(X, mask, B*C*N, st->Fs);
270
271    /* Invert and stretch the mask to length of X 
272       For some reason, I get better results by using the sqrt instead,
273       although there's no valid reason to. Must investigate further */
274    for (i=0;i<B*C*N;i++)
275       mask[i] = 1/(.1+mask[i]);
276
277    /* Pitch analysis */
278    for (c=0;c<C;c++)
279    {
280       for (i=0;i<N;i++)
281       {
282          in[C*i+c] *= st->window[i];
283          in[C*(B*N+i)+c] *= st->window[N+i];
284       }
285    }
286    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
287    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
288    
289    /* Compute MDCTs of the pitch part */
290    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
291    
292    /*int j;
293    for (j=0;j<B*N;j++)
294       printf ("%f ", X[j]);
295    for (j=0;j<B*N;j++)
296       printf ("%f ", P[j]);
297    printf ("\n");*/
298
299    /* Band normalisation */
300    compute_band_energies(st->mode, X, bandE);
301    normalise_bands(st->mode, X, bandE);
302    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
303    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
304
305    /* Normalise the pitch vector as well (discard the energies) */
306    {
307       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
308       compute_band_energies(st->mode, P, bandEp);
309       normalise_bands(st->mode, P, bandEp);
310    }
311
312    if (C==2)
313    {
314       haar1(X, B*N*C, 1);
315       haar1(P, B*N*C, 1);
316    }
317    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
318    time_dct(X, N, B, C);
319    time_dct(P, N, B, C);
320
321    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
322
323    /* Pitch prediction */
324    compute_pitch_gain(st->mode, X, P, gains, bandE);
325    quant_pitch(gains, st->mode->nbPBands, &st->enc);
326    pitch_quant_bands(st->mode, X, P, gains);
327
328    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
329    /* Compute residual that we're going to encode */
330    for (i=0;i<B*C*N;i++)
331       X[i] -= P[i];
332
333    /*float sum=0;
334    for (i=0;i<B*N;i++)
335       sum += X[i]*X[i];
336    printf ("%f\n", sum);*/
337    /* Residual quantisation */
338    quant_bands(st->mode, X, P, mask, &st->enc);
339    
340    time_idct(X, N, B, C);
341    if (C==2)
342       haar1(X, B*N*C, 1);
343
344    renormalise_bands(st->mode, X);
345    /* Synthesis */
346    denormalise_bands(st->mode, X, bandE);
347
348
349    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
350
351    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
352    /* De-emphasis and put everything back at the right place in the synthesis history */
353    for (c=0;c<C;c++)
354    {
355       for (i=0;i<B;i++)
356       {
357          int j;
358          for (j=0;j<N;j++)
359          {
360             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
361             st->preemph_memD[c] = tmp;
362             if (tmp > 32767) tmp = 32767;
363             if (tmp < -32767) tmp = -32767;
364             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
365          }
366       }
367    }
368    return 0;
369 }
370
371 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
372 {
373    char *data;
374    ec_enc_done(&st->enc);
375    *nbBytes = ec_byte_bytes(&st->buf);
376    data = ec_byte_get_buffer(&st->buf);
377    //printf ("%d\n", *nbBytes);
378    
379    /* Reset the packing for the next encoding */
380    ec_byte_reset(&st->buf);
381    ec_enc_init(&st->enc,&st->buf);
382
383    return data;
384 }
385
386
387 /****************************************************************************/
388 /*                                                                          */
389 /*                                DECODER                                   */
390 /*                                                                          */
391 /****************************************************************************/
392
393
394
395 struct CELTDecoder {
396    const CELTMode *mode;
397    int frame_size;
398    int block_size;
399    int nb_blocks;
400    int overlap;
401
402    ec_byte_buffer buf;
403    ec_enc         enc;
404
405    float preemph;
406    float *preemph_memD;
407    
408    mdct_lookup mdct_lookup;
409    
410    float *window;
411    float *mdct_overlap;
412    float *out_mem;
413
414    float *oldBandE;
415    
416    int last_pitch_index;
417 };
418
419 CELTDecoder *celt_decoder_new(const CELTMode *mode)
420 {
421    int i, N, B, C, N4;
422    N = mode->mdctSize;
423    B = mode->nbMdctBlocks;
424    C = mode->nbChannels;
425    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
426    
427    st->mode = mode;
428    st->frame_size = B*N;
429    st->block_size = N;
430    st->nb_blocks  = B;
431    st->overlap = mode->overlap;
432
433    N4 = (N-st->overlap)/2;
434    
435    mdct_init(&st->mdct_lookup, 2*N);
436    
437    st->window = celt_alloc(2*N*sizeof(float));
438    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
439    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
440
441    for (i=0;i<2*N;i++)
442       st->window[i] = 0;
443    for (i=0;i<st->overlap;i++)
444       st->window[N4+i] = st->window[2*N-N4-i-1] 
445             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
446    for (i=0;i<2*N4;i++)
447       st->window[N-N4+i] = 1;
448    
449    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
450
451    st->preemph = 0.8;
452    st->preemph_memD = celt_alloc(C*sizeof(float));;
453
454    st->last_pitch_index = 0;
455    return st;
456 }
457
458 void celt_decoder_destroy(CELTDecoder *st)
459 {
460    if (st == NULL)
461    {
462       celt_warning("NULL passed to celt_encoder_destroy");
463       return;
464    }
465
466    mdct_clear(&st->mdct_lookup);
467
468    celt_free(st->window);
469    celt_free(st->mdct_overlap);
470    celt_free(st->out_mem);
471    
472    celt_free(st->oldBandE);
473    celt_free(st);
474 }
475
476 static void celt_decode_lost(CELTDecoder *st, short *pcm)
477 {
478    int i, c, N, B, C;
479    N = st->block_size;
480    B = st->nb_blocks;
481    C = st->mode->nbChannels;
482    float X[C*B*N];         /**< Interleaved signal MDCTs */
483    int pitch_index;
484    
485    pitch_index = st->last_pitch_index;
486    
487    /* Use the pitch MDCT as the "guessed" signal */
488    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
489
490    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
491    /* Compute inverse MDCTs */
492    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
493
494    for (c=0;c<C;c++)
495    {
496       for (i=0;i<B;i++)
497       {
498          int j;
499          for (j=0;j<N;j++)
500          {
501             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
502             st->preemph_memD[c] = tmp;
503             if (tmp > 32767) tmp = 32767;
504             if (tmp < -32767) tmp = -32767;
505             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
506          }
507       }
508    }
509 }
510
511 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
512 {
513    int i, c, N, B, C;
514    N = st->block_size;
515    B = st->nb_blocks;
516    C = st->mode->nbChannels;
517    
518    float X[C*B*N];         /**< Interleaved signal MDCTs */
519    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
520    float bandE[st->mode->nbEBands*C];
521    float gains[st->mode->nbPBands];
522    int pitch_index;
523    ec_dec dec;
524    ec_byte_buffer buf;
525    
526    if (data == NULL)
527    {
528       celt_decode_lost(st, pcm);
529       return 0;
530    }
531    
532    ec_byte_readinit(&buf,data,len);
533    ec_dec_init(&dec,&buf);
534    
535    /* Get the pitch index */
536    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
537    st->last_pitch_index = pitch_index;
538    
539    /* Get band energies */
540    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
541    
542    /* Pitch MDCT */
543    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
544
545    {
546       float bandEp[st->mode->nbEBands];
547       compute_band_energies(st->mode, P, bandEp);
548       normalise_bands(st->mode, P, bandEp);
549    }
550
551    if (C==2)
552       haar1(P, B*N*C, 1);
553    time_dct(P, N, B, C);
554
555    /* Get the pitch gains */
556    unquant_pitch(gains, st->mode->nbPBands, &dec);
557
558    /* Apply pitch gains */
559    pitch_quant_bands(st->mode, X, P, gains);
560
561    /* Decode fixed codebook and merge with pitch */
562    unquant_bands(st->mode, X, P, &dec);
563
564    time_idct(X, N, B, C);
565    if (C==2)
566       haar1(X, B*N*C, 1);
567
568    renormalise_bands(st->mode, X);
569    
570    /* Synthesis */
571    denormalise_bands(st->mode, X, bandE);
572
573
574    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
575    /* Compute inverse MDCTs */
576    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
577
578    for (c=0;c<C;c++)
579    {
580       for (i=0;i<B;i++)
581       {
582          int j;
583          for (j=0;j<N;j++)
584          {
585             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
586             st->preemph_memD[c] = tmp;
587             if (tmp > 32767) tmp = 32767;
588             if (tmp < -32767) tmp = -32767;
589             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
590          }
591       }
592    }
593    return 0;
594    //printf ("\n");
595 }
596