Fixed stereo regression introduced in 05686a5d6e366d3a067c39f1b8567def7baa450d
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "kiss_fftr.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    kiss_fftr_cfg fft;
67    struct PsyDecay psy;
68    
69    float *window;
70    float *in_mem;
71    float *mdct_overlap;
72    float *out_mem;
73
74    float *oldBandE;
75    
76    struct alloc_data alloc;
77 };
78
79
80
81 CELTEncoder *celt_encoder_new(const CELTMode *mode)
82 {
83    int i, N, B, C, N4;
84    N = mode->mdctSize;
85    B = mode->nbMdctBlocks;
86    C = mode->nbChannels;
87    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
88    
89    st->mode = mode;
90    st->frame_size = B*N;
91    st->block_size = N;
92    st->nb_blocks  = B;
93    st->overlap = mode->overlap;
94    st->Fs = 44100;
95
96    N4 = (N-st->overlap)/2;
97    ec_byte_writeinit(&st->buf);
98    ec_enc_init(&st->enc,&st->buf);
99
100    mdct_init(&st->mdct_lookup, 2*N);
101    st->fft = kiss_fftr_alloc(MAX_PERIOD*C, 0, 0);
102    psydecay_init(&st->psy, MAX_PERIOD*C/2, st->Fs);
103    
104    st->window = celt_alloc(2*N*sizeof(float));
105    st->in_mem = celt_alloc(N*C*sizeof(float));
106    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
107    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
108    for (i=0;i<2*N;i++)
109       st->window[i] = 0;
110    for (i=0;i<st->overlap;i++)
111       st->window[N4+i] = st->window[2*N-N4-i-1] 
112             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
113    for (i=0;i<2*N4;i++)
114       st->window[N-N4+i] = 1;
115    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
116
117    st->preemph = 0.8;
118    st->preemph_memE = celt_alloc(C*sizeof(float));;
119    st->preemph_memD = celt_alloc(C*sizeof(float));;
120
121    alloc_init(&st->alloc, st->mode);
122    return st;
123 }
124
125 void celt_encoder_destroy(CELTEncoder *st)
126 {
127    if (st == NULL)
128    {
129       celt_warning("NULL passed to celt_encoder_destroy");
130       return;
131    }
132    ec_byte_writeclear(&st->buf);
133
134    mdct_clear(&st->mdct_lookup);
135    kiss_fft_free(st->fft);
136    psydecay_clear(&st->psy);
137
138    celt_free(st->window);
139    celt_free(st->in_mem);
140    celt_free(st->mdct_overlap);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148    alloc_clear(&st->alloc);
149
150    celt_free(st);
151 }
152
153
154 static float compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
155 {
156    int i, c;
157    float E = 1e-15;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<B;i++)
161       {
162          int j;
163          float x[2*N];
164          float tmp[N];
165          for (j=0;j<2*N;j++)
166          {
167             x[j] = window[j]*in[C*i*N+C*j+c];
168             E += x[j]*x[j];
169          }
170          mdct_forward(mdct_lookup, x, tmp);
171          /* Interleaving the sub-frames */
172          for (j=0;j<N;j++)
173             out[C*B*j+C*i+c] = tmp[j];
174       }
175    }
176    return E;
177 }
178
179 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
180 {
181    int i, c, N4;
182    N4 = (N-overlap)/2;
183    for (c=0;c<C;c++)
184    {
185       for (i=0;i<B;i++)
186       {
187          int j;
188          float x[2*N];
189          float tmp[N];
190          /* De-interleaving the sub-frames */
191          for (j=0;j<N;j++)
192             tmp[j] = X[C*B*j+C*i+c];
193          mdct_backward(mdct_lookup, tmp, x);
194          for (j=0;j<2*N;j++)
195             x[j] = window[j]*x[j];
196          for (j=0;j<overlap;j++)
197             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
198          for (j=0;j<2*N4;j++)
199             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
200          for (j=0;j<overlap;j++)
201             mdct_overlap[C*j+c] = x[N+N4+j];
202       }
203    }
204 }
205
206 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
207 {
208    int i, c, N, B, C, N4;
209    int has_pitch;
210    N = st->block_size;
211    B = st->nb_blocks;
212    C = st->mode->nbChannels;
213    float in[(B+1)*C*N];
214
215    float X[B*C*N];         /**< Interleaved signal MDCTs */
216    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
217    float mask[B*C*N];      /**< Masking curve */
218    float bandE[st->mode->nbEBands*C];
219    float gains[st->mode->nbPBands];
220    int pitch_index;
221    float curr_power, pitch_power;
222    
223    N4 = (N-st->overlap)/2;
224
225    for (c=0;c<C;c++)
226    {
227       for (i=0;i<N4;i++)
228          in[C*i+c] = 0;
229       for (i=0;i<st->overlap;i++)
230          in[C*(i+N4)+c] = st->in_mem[C*i+c];
231       for (i=0;i<B*N;i++)
232       {
233          float tmp = pcm[C*i+c];
234          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
235          st->preemph_memE[c] = tmp;
236       }
237       for (i=N*(B+1)-N4;i<N*(B+1);i++)
238          in[C*i+c] = 0;
239       for (i=0;i<st->overlap;i++)
240          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
241    }
242    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
243    /* Compute MDCTs */
244    curr_power = compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
245
246 #if 0 /* Mask disabled until it can be made to do something useful */
247    compute_mdct_masking(X, mask, B*C*N, st->Fs);
248
249    /* Invert and stretch the mask to length of X 
250       For some reason, I get better results by using the sqrt instead,
251       although there's no valid reason to. Must investigate further */
252    for (i=0;i<B*C*N;i++)
253       mask[i] = 1/(.1+mask[i]);
254 #else
255    for (i=0;i<B*C*N;i++)
256       mask[i] = 1;
257 #endif
258    /* Pitch analysis */
259    for (c=0;c<C;c++)
260    {
261       for (i=0;i<N;i++)
262       {
263          in[C*i+c] *= st->window[i];
264          in[C*(B*N+i)+c] *= st->window[N+i];
265       }
266    }
267    find_spectral_pitch(st->fft, &st->psy, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
268    
269    /* Compute MDCTs of the pitch part */
270    pitch_power = compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
271    
272    //printf ("%f %f\n", curr_power, pitch_power);
273    /*int j;
274    for (j=0;j<B*N;j++)
275       printf ("%f ", X[j]);
276    for (j=0;j<B*N;j++)
277       printf ("%f ", P[j]);
278    printf ("\n");*/
279
280    /* Band normalisation */
281    compute_band_energies(st->mode, X, bandE);
282    normalise_bands(st->mode, X, bandE);
283    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
284    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
285
286    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
287
288    if (C==2)
289    {
290       stereo_mix(st->mode, X, bandE, 1);
291    }
292
293    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
294    if (curr_power + 1e5f < 10.f*pitch_power)
295    {
296       /* Normalise the pitch vector as well (discard the energies) */
297       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
298       compute_band_energies(st->mode, P, bandEp);
299       normalise_bands(st->mode, P, bandEp);
300
301       if (C==2)
302          stereo_mix(st->mode, P, bandE, 1);
303       /* Simulates intensity stereo */
304       //for (i=30;i<N*B;i++)
305       //   X[i*C+1] = P[i*C+1] = 0;
306
307       /* Pitch prediction */
308       compute_pitch_gain(st->mode, X, P, gains, bandE);
309       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
310       if (has_pitch)
311          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
312    } else {
313       /* No pitch, so we just pretend we found a gain of zero */
314       for (i=0;i<st->mode->nbPBands;i++)
315          gains[i] = 0;
316       ec_enc_uint(&st->enc, 0, 128);
317       for (i=0;i<B*C*N;i++)
318          P[i] = 0;
319    }
320    
321
322    pitch_quant_bands(st->mode, X, P, gains);
323
324    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
325    /* Compute residual that we're going to encode */
326    for (i=0;i<B*C*N;i++)
327       X[i] -= P[i];
328
329    /*float sum=0;
330    for (i=0;i<B*N;i++)
331       sum += X[i]*X[i];
332    printf ("%f\n", sum);*/
333    /* Residual quantisation */
334    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
335    
336    if (C==2)
337       stereo_mix(st->mode, X, bandE, -1);
338
339    renormalise_bands(st->mode, X);
340    /* Synthesis */
341    denormalise_bands(st->mode, X, bandE);
342
343
344    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
345
346    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
347    /* De-emphasis and put everything back at the right place in the synthesis history */
348    for (c=0;c<C;c++)
349    {
350       for (i=0;i<B;i++)
351       {
352          int j;
353          for (j=0;j<N;j++)
354          {
355             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
356             st->preemph_memD[c] = tmp;
357             if (tmp > 32767) tmp = 32767;
358             if (tmp < -32767) tmp = -32767;
359             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
360          }
361       }
362    }
363    
364    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 16)
365       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
366    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
367    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
368    {
369       int val = 0;
370       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
371       {
372          ec_enc_uint(&st->enc, val, 2);
373          val = 1-val;
374       }
375    }
376    ec_enc_done(&st->enc);
377    {
378       unsigned char *data;
379       int nbBytes = ec_byte_bytes(&st->buf);
380       if (nbBytes > nbCompressedBytes)
381       {
382          celt_warning_int ("got too many bytes:", nbBytes);
383          return CELT_INTERNAL_ERROR;
384       }
385       //printf ("%d\n", *nbBytes);
386       data = ec_byte_get_buffer(&st->buf);
387       for (i=0;i<nbBytes;i++)
388          compressed[i] = data[i];
389       for (;i<nbCompressedBytes;i++)
390          compressed[i] = 0;
391    }
392    /* Reset the packing for the next encoding */
393    ec_byte_reset(&st->buf);
394    ec_enc_init(&st->enc,&st->buf);
395
396    return nbCompressedBytes;
397 }
398
399
400 /****************************************************************************/
401 /*                                                                          */
402 /*                                DECODER                                   */
403 /*                                                                          */
404 /****************************************************************************/
405
406
407
408 struct CELTDecoder {
409    const CELTMode *mode;
410    int frame_size;
411    int block_size;
412    int nb_blocks;
413    int overlap;
414
415    ec_byte_buffer buf;
416    ec_enc         enc;
417
418    float preemph;
419    float *preemph_memD;
420    
421    mdct_lookup mdct_lookup;
422    
423    float *window;
424    float *mdct_overlap;
425    float *out_mem;
426
427    float *oldBandE;
428    
429    int last_pitch_index;
430    
431    struct alloc_data alloc;
432 };
433
434 CELTDecoder *celt_decoder_new(const CELTMode *mode)
435 {
436    int i, N, B, C, N4;
437    N = mode->mdctSize;
438    B = mode->nbMdctBlocks;
439    C = mode->nbChannels;
440    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
441    
442    st->mode = mode;
443    st->frame_size = B*N;
444    st->block_size = N;
445    st->nb_blocks  = B;
446    st->overlap = mode->overlap;
447
448    N4 = (N-st->overlap)/2;
449    
450    mdct_init(&st->mdct_lookup, 2*N);
451    
452    st->window = celt_alloc(2*N*sizeof(float));
453    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
454    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
455
456    for (i=0;i<2*N;i++)
457       st->window[i] = 0;
458    for (i=0;i<st->overlap;i++)
459       st->window[N4+i] = st->window[2*N-N4-i-1] 
460             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
461    for (i=0;i<2*N4;i++)
462       st->window[N-N4+i] = 1;
463    
464    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
465
466    st->preemph = 0.8;
467    st->preemph_memD = celt_alloc(C*sizeof(float));;
468
469    st->last_pitch_index = 0;
470    alloc_init(&st->alloc, st->mode);
471
472    return st;
473 }
474
475 void celt_decoder_destroy(CELTDecoder *st)
476 {
477    if (st == NULL)
478    {
479       celt_warning("NULL passed to celt_encoder_destroy");
480       return;
481    }
482
483    mdct_clear(&st->mdct_lookup);
484
485    celt_free(st->window);
486    celt_free(st->mdct_overlap);
487    celt_free(st->out_mem);
488    
489    celt_free(st->oldBandE);
490    
491    celt_free(st->preemph_memD);
492
493    alloc_clear(&st->alloc);
494
495    celt_free(st);
496 }
497
498 static void celt_decode_lost(CELTDecoder *st, short *pcm)
499 {
500    int i, c, N, B, C;
501    N = st->block_size;
502    B = st->nb_blocks;
503    C = st->mode->nbChannels;
504    float X[C*B*N];         /**< Interleaved signal MDCTs */
505    int pitch_index;
506    
507    pitch_index = st->last_pitch_index;
508    
509    /* Use the pitch MDCT as the "guessed" signal */
510    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
511
512    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
513    /* Compute inverse MDCTs */
514    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
515
516    for (c=0;c<C;c++)
517    {
518       for (i=0;i<B;i++)
519       {
520          int j;
521          for (j=0;j<N;j++)
522          {
523             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
524             st->preemph_memD[c] = tmp;
525             if (tmp > 32767) tmp = 32767;
526             if (tmp < -32767) tmp = -32767;
527             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
528          }
529       }
530    }
531 }
532
533 int celt_decode(CELTDecoder *st, unsigned char *data, int len, celt_int16_t *pcm)
534 {
535    int i, c, N, B, C;
536    int has_pitch;
537    N = st->block_size;
538    B = st->nb_blocks;
539    C = st->mode->nbChannels;
540    
541    float X[C*B*N];         /**< Interleaved signal MDCTs */
542    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
543    float bandE[st->mode->nbEBands*C];
544    float gains[st->mode->nbPBands];
545    int pitch_index;
546    ec_dec dec;
547    ec_byte_buffer buf;
548    
549    if (data == NULL)
550    {
551       celt_decode_lost(st, pcm);
552       return 0;
553    }
554    
555    ec_byte_readinit(&buf,data,len);
556    ec_dec_init(&dec,&buf);
557    
558    /* Get band energies */
559    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
560    
561    /* Get the pitch gains */
562    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
563    
564    /* Get the pitch index */
565    if (has_pitch)
566    {
567       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
568       st->last_pitch_index = pitch_index;
569    } else {
570       /* FIXME: We could be more intelligent here and just not compute the MDCT */
571       pitch_index = 0;
572    }
573    
574    /* Pitch MDCT */
575    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
576
577    {
578       float bandEp[st->mode->nbEBands*C];
579       compute_band_energies(st->mode, P, bandEp);
580       normalise_bands(st->mode, P, bandEp);
581    }
582
583    if (C==2)
584       stereo_mix(st->mode, P, bandE, 1);
585
586    /* Apply pitch gains */
587    pitch_quant_bands(st->mode, X, P, gains);
588
589    /* Decode fixed codebook and merge with pitch */
590    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
591
592    if (C==2)
593       stereo_mix(st->mode, X, bandE, -1);
594
595    renormalise_bands(st->mode, X);
596    
597    /* Synthesis */
598    denormalise_bands(st->mode, X, bandE);
599
600
601    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
602    /* Compute inverse MDCTs */
603    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
604
605    for (c=0;c<C;c++)
606    {
607       for (i=0;i<B;i++)
608       {
609          int j;
610          for (j=0;j<N;j++)
611          {
612             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
613             st->preemph_memD[c] = tmp;
614             if (tmp > 32767) tmp = 32767;
615             if (tmp < -32767) tmp = -32767;
616             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
617          }
618       }
619    }
620
621    {
622       int val = 0;
623       while (ec_dec_tell(&dec, 0) < len*8)
624       {
625          if (ec_dec_uint(&dec, 2) != val)
626          {
627             celt_warning("decode error");
628             return CELT_CORRUPTED_DATA;
629          }
630          val = 1-val;
631       }
632    }
633
634    return 0;
635    //printf ("\n");
636 }
637