Very basic packet loss concealment code
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43
44 #define MAX_PERIOD 1024
45
46
47 struct CELTEncoder {
48    const CELTMode *mode;
49    int frame_size;
50    int block_size;
51    int nb_blocks;
52       
53    ec_byte_buffer buf;
54    ec_enc         enc;
55
56    float preemph;
57    float preemph_memE;
58    float preemph_memD;
59    
60    mdct_lookup mdct_lookup;
61    void *fft;
62    
63    float *window;
64    float *in_mem;
65    float *mdct_overlap;
66    float *out_mem;
67
68    float *oldBandE;
69 };
70
71
72
73 CELTEncoder *celt_encoder_new(const CELTMode *mode)
74 {
75    int i, N, B;
76    N = mode->mdctSize;
77    B = mode->nbMdctBlocks;
78    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
79    
80    st->mode = mode;
81    st->frame_size = B*N;
82    st->block_size = N;
83    st->nb_blocks  = B;
84    
85    ec_byte_writeinit(&st->buf);
86    ec_enc_init(&st->enc,&st->buf);
87
88    mdct_init(&st->mdct_lookup, 2*N);
89    st->fft = spx_fft_init(MAX_PERIOD);
90    
91    st->window = celt_alloc(2*N*sizeof(float));
92    st->in_mem = celt_alloc(N*sizeof(float));
93    st->mdct_overlap = celt_alloc(N*sizeof(float));
94    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
95    for (i=0;i<N;i++)
96       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
97    
98    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
99
100    st->preemph = 0.8;
101    return st;
102 }
103
104 void celt_encoder_destroy(CELTEncoder *st)
105 {
106    if (st == NULL)
107    {
108       celt_warning("NULL passed to celt_encoder_destroy");
109       return;
110    }
111    ec_byte_writeclear(&st->buf);
112
113    mdct_clear(&st->mdct_lookup);
114    spx_fft_destroy(st->fft);
115
116    celt_free(st->window);
117    celt_free(st->in_mem);
118    celt_free(st->mdct_overlap);
119    celt_free(st->out_mem);
120    
121    celt_free(st->oldBandE);
122    celt_free(st);
123 }
124
125 static void haar1(float *X, int N)
126 {
127    int i;
128    for (i=0;i<N;i+=2)
129    {
130       float a, b;
131       a = X[i];
132       b = X[i+1];
133       X[i] = .707107f*(a+b);
134       X[i+1] = .707107f*(a-b);
135    }
136 }
137
138 static void inv_haar1(float *X, int N)
139 {
140    int i;
141    for (i=0;i<N;i+=2)
142    {
143       float a, b;
144       a = X[i];
145       b = X[i+1];
146       X[i] = .707107f*(a+b);
147       X[i+1] = .707107f*(a-b);
148    }
149 }
150
151 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B)
152 {
153    int i;
154    for (i=0;i<B;i++)
155    {
156       int j;
157       float x[2*N];
158       float tmp[N];
159       for (j=0;j<2*N;j++)
160          x[j] = window[j]*in[i*N+j];
161       mdct_forward(mdct_lookup, x, tmp);
162       /* Interleaving the sub-frames */
163       for (j=0;j<N;j++)
164          out[B*j+i] = tmp[j];
165    }
166
167 }
168
169 int celt_encode(CELTEncoder *st, short *pcm)
170 {
171    int i, N, B;
172    N = st->block_size;
173    B = st->nb_blocks;
174    float in[(B+1)*N];
175    
176    float X[B*N];         /**< Interleaved signal MDCTs */
177    float P[B*N];         /**< Interleaved pitch MDCTs*/
178    float bandE[st->mode->nbEBands];
179    float gains[st->mode->nbPBands];
180    int pitch_index;
181    
182    for (i=0;i<N;i++)
183       in[i] = st->in_mem[i];
184    for (;i<(B+1)*N;i++)
185    {
186       float tmp = pcm[i-N];
187       in[i] = tmp - st->preemph*st->preemph_memE;
188       st->preemph_memE = tmp;
189    }
190    for (i=0;i<N;i++)
191       st->in_mem[i] = in[B*N+i];
192
193    /* Compute MDCTs */
194    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B);
195    
196    /* Pitch analysis */
197    for (i=0;i<N;i++)
198    {
199       in[i] *= st->window[i];
200       in[B*N+i] *= st->window[N+i];
201    }
202    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, &pitch_index);
203    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
204    
205    /* Compute MDCTs of the pitch part */
206    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
207    
208    /*int j;
209    for (j=0;j<B*N;j++)
210       printf ("%f ", X[j]);
211    for (j=0;j<B*N;j++)
212       printf ("%f ", P[j]);
213    printf ("\n");*/
214    //haar1(X, B*N);
215    //haar1(P, B*N);
216    
217    /* Band normalisation */
218    compute_band_energies(st->mode, X, bandE);
219    normalise_bands(st->mode, X, bandE);
220    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
221    
222    {
223       float bandEp[st->mode->nbEBands];
224       compute_band_energies(st->mode, P, bandEp);
225       normalise_bands(st->mode, P, bandEp);
226    }
227    
228    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
229    
230    /* Pitch prediction */
231    compute_pitch_gain(st->mode, X, P, gains, bandE);
232    quant_pitch(gains, st->mode->nbPBands, &st->enc);
233    pitch_quant_bands(st->mode, X, P, gains);
234
235    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
236    /* Subtract the pitch prediction from the signal to encode */
237    for (i=0;i<B*N;i++)
238       X[i] -= P[i];
239
240    /*float sum=0;
241    for (i=0;i<B*N;i++)
242       sum += X[i]*X[i];
243    printf ("%f\n", sum);*/
244    /* Residual quantisation */
245    quant_bands(st->mode, X, P, &st->enc);
246    
247    if (0) {//This is just for debugging
248       ec_enc_done(&st->enc);
249       ec_dec dec;
250       ec_byte_readinit(&st->buf,ec_byte_get_buffer(&st->buf),ec_byte_bytes(&st->buf));
251       ec_dec_init(&dec,&st->buf);
252
253       unquant_bands(st->mode, X, P, &dec);
254       //printf ("\n");
255    }
256    
257    /* Synthesis */
258    denormalise_bands(st->mode, X, bandE);
259
260    //inv_haar1(X, B*N);
261
262    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
263    /* Compute inverse MDCTs */
264    for (i=0;i<B;i++)
265    {
266       int j;
267       float x[2*N];
268       float tmp[N];
269       /* De-interleaving the sub-frames */
270       for (j=0;j<N;j++)
271          tmp[j] = X[B*j+i];
272       mdct_backward(&st->mdct_lookup, tmp, x);
273       for (j=0;j<2*N;j++)
274          x[j] = st->window[j]*x[j];
275       for (j=0;j<N;j++)
276          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
277       for (j=0;j<N;j++)
278          st->mdct_overlap[j] = x[N+j];
279       
280       for (j=0;j<N;j++)
281       {
282          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
283          st->preemph_memD = tmp;
284          pcm[i*N+j] = (short)floor(.5+tmp);
285       }
286    }
287    return 0;
288 }
289
290 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
291 {
292    char *data;
293    ec_enc_done(&st->enc);
294    *nbBytes = ec_byte_bytes(&st->buf);
295    data = ec_byte_get_buffer(&st->buf);
296    //printf ("%d\n", *nbBytes);
297    
298    /* Reset the packing for the next encoding */
299    ec_byte_reset(&st->buf);
300    ec_enc_init(&st->enc,&st->buf);
301
302    return data;
303 }
304
305
306 /****************************************************************************/
307 /*                                                                          */
308 /*                                DECODER                                   */
309 /*                                                                          */
310 /****************************************************************************/
311
312
313
314 struct CELTDecoder {
315    const CELTMode *mode;
316    int frame_size;
317    int block_size;
318    int nb_blocks;
319    
320    ec_byte_buffer buf;
321    ec_enc         enc;
322
323    float preemph;
324    float preemph_memD;
325    
326    mdct_lookup mdct_lookup;
327    
328    float *window;
329    float *mdct_overlap;
330    float *out_mem;
331
332    float *oldBandE;
333    
334    int last_pitch_index;
335 };
336
337 CELTDecoder *celt_decoder_new(const CELTMode *mode)
338 {
339    int i, N, B;
340    N = mode->mdctSize;
341    B = mode->nbMdctBlocks;
342    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
343    
344    st->mode = mode;
345    st->frame_size = B*N;
346    st->block_size = N;
347    st->nb_blocks  = B;
348    
349    mdct_init(&st->mdct_lookup, 2*N);
350    
351    st->window = celt_alloc(2*N*sizeof(float));
352    st->mdct_overlap = celt_alloc(N*sizeof(float));
353    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
354    for (i=0;i<N;i++)
355       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
356    
357    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
358
359    st->preemph = 0.8;
360    
361    st->last_pitch_index = 0;
362    return st;
363 }
364
365 void celt_decoder_destroy(CELTDecoder *st)
366 {
367    if (st == NULL)
368    {
369       celt_warning("NULL passed to celt_encoder_destroy");
370       return;
371    }
372
373    mdct_clear(&st->mdct_lookup);
374
375    celt_free(st->window);
376    celt_free(st->mdct_overlap);
377    celt_free(st->out_mem);
378    
379    celt_free(st->oldBandE);
380    celt_free(st);
381 }
382
383 int celt_decode_lost(CELTDecoder *st, short *pcm)
384 {
385    int i, N, B;
386    N = st->block_size;
387    B = st->nb_blocks;
388    
389    float X[B*N];         /**< Interleaved signal MDCTs */
390    float P[B*N];         /**< Interleaved pitch MDCTs*/
391    float bandE[st->mode->nbEBands];
392    float gains[st->mode->nbPBands];
393    int pitch_index;
394    
395    pitch_index = st->last_pitch_index;
396    
397    /* Pitch MDCT */
398    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, X, N, B);
399
400    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
401    /* Compute inverse MDCTs */
402    for (i=0;i<B;i++)
403    {
404       int j;
405       float x[2*N];
406       float tmp[N];
407       /* De-interleaving the sub-frames */
408       for (j=0;j<N;j++)
409          tmp[j] = X[B*j+i];
410       mdct_backward(&st->mdct_lookup, tmp, x);
411       for (j=0;j<2*N;j++)
412          x[j] = st->window[j]*x[j];
413       for (j=0;j<N;j++)
414          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
415       for (j=0;j<N;j++)
416          st->mdct_overlap[j] = x[N+j];
417       
418       for (j=0;j<N;j++)
419       {
420          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
421          st->preemph_memD = tmp;
422          pcm[i*N+j] = (short)floor(.5+tmp);
423       }
424    }
425
426 }
427
428 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
429 {
430    int i, N, B;
431    N = st->block_size;
432    B = st->nb_blocks;
433    
434    float X[B*N];         /**< Interleaved signal MDCTs */
435    float P[B*N];         /**< Interleaved pitch MDCTs*/
436    float bandE[st->mode->nbEBands];
437    float gains[st->mode->nbPBands];
438    int pitch_index;
439    ec_dec dec;
440    ec_byte_buffer buf;
441    
442    if (data == NULL)
443    {
444       celt_decode_lost(st, pcm);
445       return 0;
446    }
447    
448    ec_byte_readinit(&buf,data,len);
449    ec_dec_init(&dec,&buf);
450    
451    /* Get the pitch index */
452    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);;
453    st->last_pitch_index = pitch_index;
454    
455    /* Get band energies */
456    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
457    
458    /* Pitch MDCT */
459    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
460
461    //haar1(P, B*N);
462
463    {
464       float bandEp[st->mode->nbEBands];
465       compute_band_energies(st->mode, P, bandEp);
466       normalise_bands(st->mode, P, bandEp);
467    }
468
469    /* Get the pitch gains */
470    unquant_pitch(gains, st->mode->nbPBands, &dec);
471
472    /* Apply pitch gains */
473    pitch_quant_bands(st->mode, X, P, gains);
474
475    /* Decode fixed codebook and merge with pitch */
476    unquant_bands(st->mode, X, P, &dec);
477
478    /* Synthesis */
479    denormalise_bands(st->mode, X, bandE);
480
481    //inv_haar1(X, B*N);
482
483    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
484    /* Compute inverse MDCTs */
485    for (i=0;i<B;i++)
486    {
487       int j;
488       float x[2*N];
489       float tmp[N];
490       /* De-interleaving the sub-frames */
491       for (j=0;j<N;j++)
492          tmp[j] = X[B*j+i];
493       mdct_backward(&st->mdct_lookup, tmp, x);
494       for (j=0;j<2*N;j++)
495          x[j] = st->window[j]*x[j];
496       for (j=0;j<N;j++)
497          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
498       for (j=0;j<N;j++)
499          st->mdct_overlap[j] = x[N+j];
500       
501       for (j=0;j<N;j++)
502       {
503          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
504          st->preemph_memD = tmp;
505          pcm[i*N+j] = (short)floor(.5+tmp);
506       }
507    }
508    return 0;
509    //printf ("\n");
510 }
511