Fixed stereo version of the pitch estimator
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43
44 #define MAX_PERIOD 1024
45
46
47 struct CELTEncoder {
48    const CELTMode *mode;
49    int frame_size;
50    int block_size;
51    int nb_blocks;
52    int channels;
53    
54    ec_byte_buffer buf;
55    ec_enc         enc;
56
57    float preemph;
58    float *preemph_memE;
59    float *preemph_memD;
60    
61    mdct_lookup mdct_lookup;
62    void *fft;
63    
64    float *window;
65    float *in_mem;
66    float *mdct_overlap;
67    float *out_mem;
68
69    float *oldBandE;
70 };
71
72
73
74 CELTEncoder *celt_encoder_new(const CELTMode *mode)
75 {
76    int i, N, B, C;
77    N = mode->mdctSize;
78    B = mode->nbMdctBlocks;
79    C = mode->nbChannels;
80    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
81    
82    st->mode = mode;
83    st->frame_size = B*N;
84    st->block_size = N;
85    st->nb_blocks  = B;
86    
87    ec_byte_writeinit(&st->buf);
88    ec_enc_init(&st->enc,&st->buf);
89
90    mdct_init(&st->mdct_lookup, 2*N);
91    st->fft = spx_fft_init(MAX_PERIOD*C);
92    
93    st->window = celt_alloc(2*N*sizeof(float));
94    st->in_mem = celt_alloc(N*C*sizeof(float));
95    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
96    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
97    for (i=0;i<N;i++)
98       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
99    
100    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
101
102    st->preemph = 0.8;
103    st->preemph_memE = celt_alloc(C*sizeof(float));;
104    st->preemph_memD = celt_alloc(C*sizeof(float));;
105
106    return st;
107 }
108
109 void celt_encoder_destroy(CELTEncoder *st)
110 {
111    if (st == NULL)
112    {
113       celt_warning("NULL passed to celt_encoder_destroy");
114       return;
115    }
116    ec_byte_writeclear(&st->buf);
117
118    mdct_clear(&st->mdct_lookup);
119    spx_fft_destroy(st->fft);
120
121    celt_free(st->window);
122    celt_free(st->in_mem);
123    celt_free(st->mdct_overlap);
124    celt_free(st->out_mem);
125    
126    celt_free(st->oldBandE);
127    celt_free(st);
128 }
129
130 static void haar1(float *X, int N)
131 {
132    int i;
133    for (i=0;i<N;i+=2)
134    {
135       float a, b;
136       a = X[i];
137       b = X[i+1];
138       X[i] = .707107f*(a+b);
139       X[i+1] = .707107f*(a-b);
140    }
141 }
142
143 static void inv_haar1(float *X, int N)
144 {
145    int i;
146    for (i=0;i<N;i+=2)
147    {
148       float a, b;
149       a = X[i];
150       b = X[i+1];
151       X[i] = .707107f*(a+b);
152       X[i+1] = .707107f*(a-b);
153    }
154 }
155
156 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
157 {
158    int i, c;
159    for (c=0;c<C;c++)
160    {
161       for (i=0;i<B;i++)
162       {
163          int j;
164          float x[2*N];
165          float tmp[N];
166          for (j=0;j<2*N;j++)
167             x[j] = window[j]*in[C*i*N+C*j+c];
168          mdct_forward(mdct_lookup, x, tmp);
169          /* Interleaving the sub-frames */
170          for (j=0;j<N;j++)
171             out[C*B*j+C*i+c] = tmp[j];
172       }
173    }
174 }
175
176 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int B, int C)
177 {
178    int i, c;
179    for (c=0;c<C;c++)
180    {
181       for (i=0;i<B;i++)
182       {
183          int j;
184          float x[2*N];
185          float tmp[N];
186          /* De-interleaving the sub-frames */
187          for (j=0;j<N;j++)
188             tmp[j] = X[C*B*j+C*i+c];
189          mdct_backward(mdct_lookup, tmp, x);
190          for (j=0;j<2*N;j++)
191             x[j] = window[j]*x[j];
192          for (j=0;j<N;j++)
193             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[j]+mdct_overlap[C*j+c];
194          for (j=0;j<N;j++)
195             mdct_overlap[C*j+c] = x[N+j];
196       }
197    }
198 }
199
200 int celt_encode(CELTEncoder *st, short *pcm)
201 {
202    int i, c, N, B, C;
203    N = st->block_size;
204    B = st->nb_blocks;
205    C = st->mode->nbChannels;
206    float in[(B+1)*C*N];
207    
208    float X[B*C*N];         /**< Interleaved signal MDCTs */
209    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
210    float bandE[st->mode->nbEBands];
211    float gains[st->mode->nbPBands];
212    int pitch_index;
213    
214    for (c=0;c<C;c++)
215    {
216       for (i=0;i<N;i++)
217          in[C*i+c] = st->in_mem[C*i+c];
218       for (;i<(B+1)*N;i++)
219       {
220          float tmp = pcm[C*(i-N)+c];
221          in[C*i+c] = tmp - st->preemph*st->preemph_memE[c];
222          st->preemph_memE[c] = tmp;
223       }
224       for (i=0;i<N;i++)
225          st->in_mem[C*i+c] = in[C*(B*N+i)+c];
226    }
227    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
228    /* Compute MDCTs */
229    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
230    
231    /* Pitch analysis */
232    for (c=0;c<C;c++)
233    {
234       for (i=0;i<N;i++)
235       {
236          in[C*i+c] *= st->window[i];
237          in[C*(B*N+i)+c] *= st->window[N+i];
238       }
239    }
240    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
241    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
242    
243    /* Compute MDCTs of the pitch part */
244    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
245    
246    /*int j;
247    for (j=0;j<B*N;j++)
248       printf ("%f ", X[j]);
249    for (j=0;j<B*N;j++)
250       printf ("%f ", P[j]);
251    printf ("\n");*/
252    if (C==2)
253    {
254       haar1(X, B*N);
255       haar1(P, B*N);
256    }
257    
258    /* Band normalisation */
259    compute_band_energies(st->mode, X, bandE);
260    normalise_bands(st->mode, X, bandE);
261    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
262    
263    {
264       float bandEp[st->mode->nbEBands];
265       compute_band_energies(st->mode, P, bandEp);
266       normalise_bands(st->mode, P, bandEp);
267    }
268    
269    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
270    
271    /* Pitch prediction */
272    compute_pitch_gain(st->mode, X, P, gains, bandE);
273    quant_pitch(gains, st->mode->nbPBands, &st->enc);
274    pitch_quant_bands(st->mode, X, P, gains);
275
276    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
277    /* Subtract the pitch prediction from the signal to encode */
278    for (i=0;i<B*C*N;i++)
279       X[i] -= P[i];
280
281    /*float sum=0;
282    for (i=0;i<B*N;i++)
283       sum += X[i]*X[i];
284    printf ("%f\n", sum);*/
285    /* Residual quantisation */
286    quant_bands(st->mode, X, P, &st->enc);
287    
288    if (0) {//This is just for debugging
289       ec_enc_done(&st->enc);
290       ec_dec dec;
291       ec_byte_readinit(&st->buf,ec_byte_get_buffer(&st->buf),ec_byte_bytes(&st->buf));
292       ec_dec_init(&dec,&st->buf);
293
294       unquant_bands(st->mode, X, P, &dec);
295       //printf ("\n");
296    }
297    
298    /* Synthesis */
299    denormalise_bands(st->mode, X, bandE);
300
301    if (C==2)
302       inv_haar1(X, B*N);
303
304    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
305    /* Compute inverse MDCTs */
306    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
307
308    for (c=0;c<C;c++)
309    {
310       for (i=0;i<B;i++)
311       {
312          int j;
313          for (j=0;j<N;j++)
314          {
315             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
316             st->preemph_memD[c] = tmp;
317             if (tmp > 32767) tmp = 32767;
318             if (tmp < -32767) tmp = -32767;
319             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
320          }
321       }
322    }
323    return 0;
324 }
325
326 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
327 {
328    char *data;
329    ec_enc_done(&st->enc);
330    *nbBytes = ec_byte_bytes(&st->buf);
331    data = ec_byte_get_buffer(&st->buf);
332    //printf ("%d\n", *nbBytes);
333    
334    /* Reset the packing for the next encoding */
335    ec_byte_reset(&st->buf);
336    ec_enc_init(&st->enc,&st->buf);
337
338    return data;
339 }
340
341
342 /****************************************************************************/
343 /*                                                                          */
344 /*                                DECODER                                   */
345 /*                                                                          */
346 /****************************************************************************/
347
348
349
350 struct CELTDecoder {
351    const CELTMode *mode;
352    int frame_size;
353    int block_size;
354    int nb_blocks;
355    
356    ec_byte_buffer buf;
357    ec_enc         enc;
358
359    float preemph;
360    float *preemph_memD;
361    
362    mdct_lookup mdct_lookup;
363    
364    float *window;
365    float *mdct_overlap;
366    float *out_mem;
367
368    float *oldBandE;
369    
370    int last_pitch_index;
371 };
372
373 CELTDecoder *celt_decoder_new(const CELTMode *mode)
374 {
375    int i, N, B, C;
376    N = mode->mdctSize;
377    B = mode->nbMdctBlocks;
378    C = mode->nbChannels;
379    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
380    
381    st->mode = mode;
382    st->frame_size = B*N;
383    st->block_size = N;
384    st->nb_blocks  = B;
385    
386    mdct_init(&st->mdct_lookup, 2*N);
387    
388    st->window = celt_alloc(2*N*sizeof(float));
389    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
390    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
391    for (i=0;i<N;i++)
392       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
393    
394    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
395
396    st->preemph = 0.8;
397    st->preemph_memD = celt_alloc(C*sizeof(float));;
398
399    st->last_pitch_index = 0;
400    return st;
401 }
402
403 void celt_decoder_destroy(CELTDecoder *st)
404 {
405    if (st == NULL)
406    {
407       celt_warning("NULL passed to celt_encoder_destroy");
408       return;
409    }
410
411    mdct_clear(&st->mdct_lookup);
412
413    celt_free(st->window);
414    celt_free(st->mdct_overlap);
415    celt_free(st->out_mem);
416    
417    celt_free(st->oldBandE);
418    celt_free(st);
419 }
420
421 static void celt_decode_lost(CELTDecoder *st, short *pcm)
422 {
423    int i, c, N, B, C;
424    N = st->block_size;
425    B = st->nb_blocks;
426    C = st->mode->nbChannels;
427    float X[C*B*N];         /**< Interleaved signal MDCTs */
428    int pitch_index;
429    
430    pitch_index = st->last_pitch_index;
431    
432    /* Use the pitch MDCT as the "guessed" signal */
433    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
434
435    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
436    /* Compute inverse MDCTs */
437    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
438
439    for (c=0;c<C;c++)
440    {
441       for (i=0;i<B;i++)
442       {
443          int j;
444          for (j=0;j<N;j++)
445          {
446             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
447             st->preemph_memD[c] = tmp;
448             if (tmp > 32767) tmp = 32767;
449             if (tmp < -32767) tmp = -32767;
450             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
451          }
452       }
453    }
454 }
455
456 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
457 {
458    int i, c, N, B, C;
459    N = st->block_size;
460    B = st->nb_blocks;
461    C = st->mode->nbChannels;
462    
463    float X[C*B*N];         /**< Interleaved signal MDCTs */
464    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
465    float bandE[st->mode->nbEBands];
466    float gains[st->mode->nbPBands];
467    int pitch_index;
468    ec_dec dec;
469    ec_byte_buffer buf;
470    
471    if (data == NULL)
472    {
473       celt_decode_lost(st, pcm);
474       return 0;
475    }
476    
477    ec_byte_readinit(&buf,data,len);
478    ec_dec_init(&dec,&buf);
479    
480    /* Get the pitch index */
481    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
482    st->last_pitch_index = pitch_index;
483    
484    /* Get band energies */
485    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
486    
487    /* Pitch MDCT */
488    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
489
490    if (C==2)
491       haar1(P, B*N);
492
493    {
494       float bandEp[st->mode->nbEBands];
495       compute_band_energies(st->mode, P, bandEp);
496       normalise_bands(st->mode, P, bandEp);
497    }
498
499    /* Get the pitch gains */
500    unquant_pitch(gains, st->mode->nbPBands, &dec);
501
502    /* Apply pitch gains */
503    pitch_quant_bands(st->mode, X, P, gains);
504
505    /* Decode fixed codebook and merge with pitch */
506    unquant_bands(st->mode, X, P, &dec);
507
508    /* Synthesis */
509    denormalise_bands(st->mode, X, bandE);
510
511    if (C==2)
512       inv_haar1(X, B*N);
513
514    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
515    /* Compute inverse MDCTs */
516    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
517
518    for (c=0;c<C;c++)
519    {
520       for (i=0;i<B;i++)
521       {
522          int j;
523          for (j=0;j<N;j++)
524          {
525             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
526             st->preemph_memD[c] = tmp;
527             if (tmp > 32767) tmp = 32767;
528             if (tmp < -32767) tmp = -32767;
529             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
530          }
531       }
532    }
533    return 0;
534    //printf ("\n");
535 }
536