Some stereo work (breaks the decoder for now)
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44
45 #define MAX_PERIOD 1024
46
47
48 struct CELTEncoder {
49    const CELTMode *mode;
50    int frame_size;
51    int block_size;
52    int nb_blocks;
53    int overlap;
54    int channels;
55    int Fs;
56    
57    ec_byte_buffer buf;
58    ec_enc         enc;
59
60    float preemph;
61    float *preemph_memE;
62    float *preemph_memD;
63    
64    mdct_lookup mdct_lookup;
65    void *fft;
66    
67    float *window;
68    float *in_mem;
69    float *mdct_overlap;
70    float *out_mem;
71
72    float *oldBandE;
73 };
74
75
76
77 CELTEncoder *celt_encoder_new(const CELTMode *mode)
78 {
79    int i, N, B, C, N4;
80    N = mode->mdctSize;
81    B = mode->nbMdctBlocks;
82    C = mode->nbChannels;
83    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
84    
85    st->mode = mode;
86    st->frame_size = B*N;
87    st->block_size = N;
88    st->nb_blocks  = B;
89    st->overlap = mode->overlap;
90    st->Fs = 44100;
91
92    N4 = (N-st->overlap)/2;
93    ec_byte_writeinit(&st->buf);
94    ec_enc_init(&st->enc,&st->buf);
95
96    mdct_init(&st->mdct_lookup, 2*N);
97    st->fft = spx_fft_init(MAX_PERIOD*C);
98    
99    st->window = celt_alloc(2*N*sizeof(float));
100    st->in_mem = celt_alloc(N*C*sizeof(float));
101    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
102    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
103    for (i=0;i<2*N;i++)
104       st->window[i] = 0;
105    for (i=0;i<st->overlap;i++)
106       st->window[N4+i] = st->window[2*N-N4-i-1] 
107             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
108    for (i=0;i<2*N4;i++)
109       st->window[N-N4+i] = 1;
110    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
111
112    st->preemph = 0.8;
113    st->preemph_memE = celt_alloc(C*sizeof(float));;
114    st->preemph_memD = celt_alloc(C*sizeof(float));;
115
116    return st;
117 }
118
119 void celt_encoder_destroy(CELTEncoder *st)
120 {
121    if (st == NULL)
122    {
123       celt_warning("NULL passed to celt_encoder_destroy");
124       return;
125    }
126    ec_byte_writeclear(&st->buf);
127
128    mdct_clear(&st->mdct_lookup);
129    spx_fft_destroy(st->fft);
130
131    celt_free(st->window);
132    celt_free(st->in_mem);
133    celt_free(st->mdct_overlap);
134    celt_free(st->out_mem);
135    
136    celt_free(st->oldBandE);
137    celt_free(st);
138 }
139
140 static void haar1(float *X, int N, int stride)
141 {
142    int i, k;
143    for (k=0;k<stride;k++)
144    {
145       for (i=k;i<N*stride;i+=2*stride)
146       {
147          float a, b;
148          a = X[i];
149          b = X[i+stride];
150          X[i] = .707107f*(a+b);
151          X[i+stride] = .707107f*(a-b);
152       }
153    }
154 }
155
156 static void time_dct(float *X, int N, int B, int stride)
157 {
158    switch (B)
159    {
160       case 1:
161          break;
162       case 2:
163          haar1(X, B*N, stride);
164          break;
165       default:
166          celt_warning("time_dct not defined for B > 2");
167    };
168 }
169
170 static void time_idct(float *X, int N, int B, int stride)
171 {
172    switch (B)
173    {
174       case 1:
175          break;
176       case 2:
177          haar1(X, B*N, stride);
178          break;
179       default:
180          celt_warning("time_dct not defined for B > 2");
181    };
182 }
183
184 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
185 {
186    int i, c;
187    for (c=0;c<C;c++)
188    {
189       for (i=0;i<B;i++)
190       {
191          int j;
192          float x[2*N];
193          float tmp[N];
194          for (j=0;j<2*N;j++)
195             x[j] = window[j]*in[C*i*N+C*j+c];
196          mdct_forward(mdct_lookup, x, tmp);
197          /* Interleaving the sub-frames */
198          for (j=0;j<N;j++)
199             out[C*B*j+C*i+c] = tmp[j];
200       }
201    }
202 }
203
204 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
205 {
206    int i, c, N4;
207    N4 = (N-overlap)/2;
208    for (c=0;c<C;c++)
209    {
210       for (i=0;i<B;i++)
211       {
212          int j;
213          float x[2*N];
214          float tmp[N];
215          /* De-interleaving the sub-frames */
216          for (j=0;j<N;j++)
217             tmp[j] = X[C*B*j+C*i+c];
218          mdct_backward(mdct_lookup, tmp, x);
219          for (j=0;j<2*N;j++)
220             x[j] = window[j]*x[j];
221          for (j=0;j<overlap;j++)
222             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
223          for (j=0;j<2*N4;j++)
224             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
225          for (j=0;j<overlap;j++)
226             mdct_overlap[C*j+c] = x[N+N4+j];
227       }
228    }
229 }
230
231 int celt_encode(CELTEncoder *st, short *pcm)
232 {
233    int i, c, N, B, C, N4;
234    N = st->block_size;
235    B = st->nb_blocks;
236    C = st->mode->nbChannels;
237    float in[(B+1)*C*N];
238
239    float X[B*C*N];         /**< Interleaved signal MDCTs */
240    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
241    float mask[B*C*N];      /**< Masking curve */
242    float bandE[st->mode->nbEBands*C];
243    float gains[st->mode->nbPBands];
244    int pitch_index;
245
246    N4 = (N-st->overlap)/2;
247
248    for (c=0;c<C;c++)
249    {
250       for (i=0;i<N4;i++)
251          in[C*i+c] = 0;
252       for (i=0;i<st->overlap;i++)
253          in[C*(i+N4)+c] = st->in_mem[C*i+c];
254       for (i=0;i<B*N;i++)
255       {
256          float tmp = pcm[C*i+c];
257          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
258          st->preemph_memE[c] = tmp;
259       }
260       for (i=N*(B+1)-N4;i<N*(B+1);i++)
261          in[C*i+c] = 0;
262       for (i=0;i<st->overlap;i++)
263          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
264    }
265    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
266    /* Compute MDCTs */
267    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
268
269    compute_mdct_masking(X, mask, B*C*N, st->Fs);
270
271    /* Invert and stretch the mask to length of X 
272       For some reason, I get better results by using the sqrt instead,
273       although there's no valid reason to. Must investigate further */
274    for (i=0;i<B*C*N;i++)
275       mask[i] = 1/(.1+mask[i]);
276
277    /* Pitch analysis */
278    for (c=0;c<C;c++)
279    {
280       for (i=0;i<N;i++)
281       {
282          in[C*i+c] *= st->window[i];
283          in[C*(B*N+i)+c] *= st->window[N+i];
284       }
285    }
286    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
287    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
288    
289    /* Compute MDCTs of the pitch part */
290    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
291    
292    /*int j;
293    for (j=0;j<B*N;j++)
294       printf ("%f ", X[j]);
295    for (j=0;j<B*N;j++)
296       printf ("%f ", P[j]);
297    printf ("\n");*/
298
299    /* Band normalisation */
300    compute_band_energies(st->mode, X, bandE);
301    normalise_bands(st->mode, X, bandE);
302    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
303    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
304
305    /* Normalise the pitch vector as well (discard the energies) */
306    {
307       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
308       compute_band_energies(st->mode, P, bandEp);
309       normalise_bands(st->mode, P, bandEp);
310    }
311
312    if (C==2)
313    {
314       haar1(X, B*N*C, 1);
315       haar1(P, B*N*C, 1);
316    }
317    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
318    time_dct(X, N, B, C);
319    time_dct(P, N, B, C);
320
321
322    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
323
324    /* Pitch prediction */
325    compute_pitch_gain(st->mode, X, P, gains, bandE);
326    quant_pitch(gains, st->mode->nbPBands, &st->enc);
327    pitch_quant_bands(st->mode, X, P, gains);
328
329    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
330    /* Compute residual that we're going to encode */
331    for (i=0;i<B*C*N;i++)
332       X[i] -= P[i];
333
334    /*float sum=0;
335    for (i=0;i<B*N;i++)
336       sum += X[i]*X[i];
337    printf ("%f\n", sum);*/
338    /* Residual quantisation */
339    quant_bands(st->mode, X, P, mask, &st->enc);
340    
341    time_idct(X, N, B, C);
342    if (C==2)
343       haar1(X, B*N*C, 1);
344
345    renormalise_bands(st->mode, X);
346    /* Synthesis */
347    denormalise_bands(st->mode, X, bandE);
348
349
350    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
351
352    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
353    /* De-emphasis and put everything back at the right place in the synthesis history */
354    for (c=0;c<C;c++)
355    {
356       for (i=0;i<B;i++)
357       {
358          int j;
359          for (j=0;j<N;j++)
360          {
361             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
362             st->preemph_memD[c] = tmp;
363             if (tmp > 32767) tmp = 32767;
364             if (tmp < -32767) tmp = -32767;
365             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
366          }
367       }
368    }
369    return 0;
370 }
371
372 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
373 {
374    char *data;
375    ec_enc_done(&st->enc);
376    *nbBytes = ec_byte_bytes(&st->buf);
377    data = ec_byte_get_buffer(&st->buf);
378    //printf ("%d\n", *nbBytes);
379    
380    /* Reset the packing for the next encoding */
381    ec_byte_reset(&st->buf);
382    ec_enc_init(&st->enc,&st->buf);
383
384    return data;
385 }
386
387
388 /****************************************************************************/
389 /*                                                                          */
390 /*                                DECODER                                   */
391 /*                                                                          */
392 /****************************************************************************/
393
394
395
396 struct CELTDecoder {
397    const CELTMode *mode;
398    int frame_size;
399    int block_size;
400    int nb_blocks;
401    int overlap;
402
403    ec_byte_buffer buf;
404    ec_enc         enc;
405
406    float preemph;
407    float *preemph_memD;
408    
409    mdct_lookup mdct_lookup;
410    
411    float *window;
412    float *mdct_overlap;
413    float *out_mem;
414
415    float *oldBandE;
416    
417    int last_pitch_index;
418 };
419
420 CELTDecoder *celt_decoder_new(const CELTMode *mode)
421 {
422    int i, N, B, C, N4;
423    N = mode->mdctSize;
424    B = mode->nbMdctBlocks;
425    C = mode->nbChannels;
426    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
427    
428    st->mode = mode;
429    st->frame_size = B*N;
430    st->block_size = N;
431    st->nb_blocks  = B;
432    st->overlap = mode->overlap;
433
434    N4 = (N-st->overlap)/2;
435    
436    mdct_init(&st->mdct_lookup, 2*N);
437    
438    st->window = celt_alloc(2*N*sizeof(float));
439    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
440    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
441
442    for (i=0;i<2*N;i++)
443       st->window[i] = 0;
444    for (i=0;i<st->overlap;i++)
445       st->window[N4+i] = st->window[2*N-N4-i-1] 
446             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
447    for (i=0;i<2*N4;i++)
448       st->window[N-N4+i] = 1;
449    
450    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
451
452    st->preemph = 0.8;
453    st->preemph_memD = celt_alloc(C*sizeof(float));;
454
455    st->last_pitch_index = 0;
456    return st;
457 }
458
459 void celt_decoder_destroy(CELTDecoder *st)
460 {
461    if (st == NULL)
462    {
463       celt_warning("NULL passed to celt_encoder_destroy");
464       return;
465    }
466
467    mdct_clear(&st->mdct_lookup);
468
469    celt_free(st->window);
470    celt_free(st->mdct_overlap);
471    celt_free(st->out_mem);
472    
473    celt_free(st->oldBandE);
474    celt_free(st);
475 }
476
477 static void celt_decode_lost(CELTDecoder *st, short *pcm)
478 {
479    int i, c, N, B, C;
480    N = st->block_size;
481    B = st->nb_blocks;
482    C = st->mode->nbChannels;
483    float X[C*B*N];         /**< Interleaved signal MDCTs */
484    int pitch_index;
485    
486    pitch_index = st->last_pitch_index;
487    
488    /* Use the pitch MDCT as the "guessed" signal */
489    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
490
491    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
492    /* Compute inverse MDCTs */
493    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
494
495    for (c=0;c<C;c++)
496    {
497       for (i=0;i<B;i++)
498       {
499          int j;
500          for (j=0;j<N;j++)
501          {
502             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
503             st->preemph_memD[c] = tmp;
504             if (tmp > 32767) tmp = 32767;
505             if (tmp < -32767) tmp = -32767;
506             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
507          }
508       }
509    }
510 }
511
512 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
513 {
514    int i, c, N, B, C;
515    N = st->block_size;
516    B = st->nb_blocks;
517    C = st->mode->nbChannels;
518    
519    float X[C*B*N];         /**< Interleaved signal MDCTs */
520    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
521    float bandE[st->mode->nbEBands];
522    float gains[st->mode->nbPBands];
523    int pitch_index;
524    ec_dec dec;
525    ec_byte_buffer buf;
526    
527    if (data == NULL)
528    {
529       celt_decode_lost(st, pcm);
530       return 0;
531    }
532    
533    ec_byte_readinit(&buf,data,len);
534    ec_dec_init(&dec,&buf);
535    
536    /* Get the pitch index */
537    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
538    st->last_pitch_index = pitch_index;
539    
540    /* Get band energies */
541    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
542    
543    /* Pitch MDCT */
544    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
545
546    if (C==2)
547       haar1(P, B*N*C, 1);
548    time_dct(P, N, B, C);
549
550    {
551       float bandEp[st->mode->nbEBands];
552       compute_band_energies(st->mode, P, bandEp);
553       normalise_bands(st->mode, P, bandEp);
554    }
555
556    /* Get the pitch gains */
557    unquant_pitch(gains, st->mode->nbPBands, &dec);
558
559    /* Apply pitch gains */
560    pitch_quant_bands(st->mode, X, P, gains);
561
562    /* Decode fixed codebook and merge with pitch */
563    unquant_bands(st->mode, X, P, &dec);
564
565    /* Synthesis */
566    denormalise_bands(st->mode, X, bandE);
567
568    time_idct(X, N, B, C);
569    if (C==2)
570       haar1(X, B*N*C, 1);
571
572    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
573    /* Compute inverse MDCTs */
574    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
575
576    for (c=0;c<C;c++)
577    {
578       for (i=0;i<B;i++)
579       {
580          int j;
581          for (j=0;j<N;j++)
582          {
583             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
584             st->preemph_memD[c] = tmp;
585             if (tmp > 32767) tmp = 32767;
586             if (tmp < -32767) tmp = -32767;
587             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
588          }
589       }
590    }
591    return 0;
592    //printf ("\n");
593 }
594