bdea832504388cabee094cc75b3af4c7dc63613c
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41
42 #define MAX_PERIOD 1024
43
44 /* This is only for cheating until the decoder is complete */
45 float cheating_ebands[100];
46 float cheating_pitch_gains[100];
47 float cheating_period;
48
49
50 struct CELTEncoder {
51    const CELTMode *mode;
52    int frame_size;
53    int block_size;
54    int nb_blocks;
55       
56    ec_byte_buffer buf;
57    ec_enc         enc;
58
59    float preemph;
60    float preemph_memE;
61    float preemph_memD;
62    
63    mdct_lookup mdct_lookup;
64    void *fft;
65    
66    float *window;
67    float *in_mem;
68    float *mdct_overlap;
69    float *out_mem;
70
71    float *oldBandE;
72 };
73
74
75
76 CELTEncoder *celt_encoder_new(const CELTMode *mode)
77 {
78    int i, N, B;
79    N = mode->mdctSize;
80    B = mode->nbMdctBlocks;
81    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
82    
83    st->mode = mode;
84    st->frame_size = B*N;
85    st->block_size = N;
86    st->nb_blocks  = B;
87    
88    ec_byte_writeinit(&st->buf);
89
90    mdct_init(&st->mdct_lookup, 2*N);
91    st->fft = spx_fft_init(MAX_PERIOD);
92    
93    st->window = celt_alloc(2*N*sizeof(float));
94    st->in_mem = celt_alloc(N*sizeof(float));
95    st->mdct_overlap = celt_alloc(N*sizeof(float));
96    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
97    for (i=0;i<N;i++)
98       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
99    
100    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
101
102    st->preemph = 0.8;
103    return st;
104 }
105
106 void celt_encoder_destroy(CELTEncoder *st)
107 {
108    if (st == NULL)
109    {
110       celt_warning("NULL passed to celt_encoder_destroy");
111       return;
112    }
113    ec_byte_writeclear(&st->buf);
114
115    mdct_clear(&st->mdct_lookup);
116    spx_fft_destroy(st->fft);
117
118    celt_free(st->window);
119    celt_free(st->in_mem);
120    celt_free(st->mdct_overlap);
121    celt_free(st->out_mem);
122    
123    celt_free(st->oldBandE);
124    celt_free(st);
125 }
126
127 static void haar1(float *X, int N)
128 {
129    int i;
130    for (i=0;i<N;i+=2)
131    {
132       float a, b;
133       a = X[i];
134       b = X[i+1];
135       X[i] = .707107f*(a+b);
136       X[i+1] = .707107f*(a-b);
137    }
138 }
139
140 static void inv_haar1(float *X, int N)
141 {
142    int i;
143    for (i=0;i<N;i+=2)
144    {
145       float a, b;
146       a = X[i];
147       b = X[i+1];
148       X[i] = .707107f*(a+b);
149       X[i+1] = .707107f*(a-b);
150    }
151 }
152
153 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B)
154 {
155    int i;
156    for (i=0;i<B;i++)
157    {
158       int j;
159       float x[2*N];
160       float tmp[N];
161       for (j=0;j<2*N;j++)
162          x[j] = window[j]*in[i*N+j];
163       mdct_forward(mdct_lookup, x, tmp);
164       /* Interleaving the sub-frames */
165       for (j=0;j<N;j++)
166          out[B*j+i] = tmp[j];
167    }
168
169 }
170
171 int celt_encode(CELTEncoder *st, short *pcm)
172 {
173    int i, N, B;
174    N = st->block_size;
175    B = st->nb_blocks;
176    float in[(B+1)*N];
177    
178    float X[B*N];         /**< Interleaved signal MDCTs */
179    float P[B*N];         /**< Interleaved pitch MDCTs*/
180    float bandE[st->mode->nbEBands];
181    float gains[st->mode->nbPBands];
182    int pitch_index;
183    
184
185    ec_byte_reset(&st->buf);
186    ec_enc_init(&st->enc,&st->buf);
187
188    for (i=0;i<N;i++)
189       in[i] = st->in_mem[i];
190    for (;i<(B+1)*N;i++)
191    {
192       float tmp = pcm[i-N];
193       in[i] = tmp - st->preemph*st->preemph_memE;
194       st->preemph_memE = tmp;
195    }
196    for (i=0;i<N;i++)
197       st->in_mem[i] = in[B*N+i];
198
199    /* Compute MDCTs */
200    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B);
201    
202    /* Pitch analysis */
203    for (i=0;i<N;i++)
204    {
205       in[i] *= st->window[i];
206       in[B*N+i] *= st->window[N+i];
207    }
208    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, &pitch_index);
209    
210    /* Compute MDCTs of the pitch part */
211    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
212    
213    /*int j;
214    for (j=0;j<B*N;j++)
215       printf ("%f ", X[j]);
216    for (j=0;j<B*N;j++)
217       printf ("%f ", P[j]);
218    printf ("\n");*/
219    //haar1(X, B*N);
220    //haar1(P, B*N);
221    
222    /* Band normalisation */
223    compute_band_energies(st->mode, X, bandE);
224    normalise_bands(st->mode, X, bandE);
225    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
226    
227    {
228       float bandEp[st->mode->nbEBands];
229       compute_band_energies(st->mode, P, bandEp);
230       normalise_bands(st->mode, P, bandEp);
231    }
232    
233    quant_energy(st->mode, bandE, st->oldBandE);
234    
235    /* Pitch prediction */
236    compute_pitch_gain(st->mode, X, P, gains, bandE);
237    //quantise_pitch(gains, PBANDS);
238    pitch_quant_bands(st->mode, X, P, gains);
239    
240    cheating_period = pitch_index;
241    for (i=0;i<st->mode->nbEBands;i++)
242       cheating_ebands[i] = bandE[i];
243    for (i=0;i<st->mode->nbPBands;i++)
244       cheating_pitch_gains[i] = gains[i];
245
246    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
247    /* Subtract the pitch prediction from the signal to encode */
248    for (i=0;i<B*N;i++)
249       X[i] -= P[i];
250
251    /*float sum=0;
252    for (i=0;i<B*N;i++)
253       sum += X[i]*X[i];
254    printf ("%f\n", sum);*/
255    /* Residual quantisation */
256    quant_bands(st->mode, X, P, &st->enc);
257    
258    if (0) {//This is just for debugging
259       ec_enc_done(&st->enc);
260       ec_dec dec;
261       ec_byte_readinit(&st->buf,ec_byte_get_buffer(&st->buf),ec_byte_bytes(&st->buf));
262       ec_dec_init(&dec,&st->buf);
263
264       unquant_bands(st->mode, X, P, &dec);
265       //printf ("\n");
266    }
267    
268    /* Synthesis */
269    denormalise_bands(st->mode, X, bandE);
270
271    //inv_haar1(X, B*N);
272
273    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
274    /* Compute inverse MDCTs */
275    for (i=0;i<B;i++)
276    {
277       int j;
278       float x[2*N];
279       float tmp[N];
280       /* De-interleaving the sub-frames */
281       for (j=0;j<N;j++)
282          tmp[j] = X[B*j+i];
283       mdct_backward(&st->mdct_lookup, tmp, x);
284       for (j=0;j<2*N;j++)
285          x[j] = st->window[j]*x[j];
286       for (j=0;j<N;j++)
287          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
288       for (j=0;j<N;j++)
289          st->mdct_overlap[j] = x[N+j];
290       
291       for (j=0;j<N;j++)
292       {
293          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
294          st->preemph_memD = tmp;
295          pcm[i*N+j] = (short)floor(.5+tmp);
296       }
297    }
298    ec_enc_done(&st->enc);
299    //printf ("%d\n", ec_byte_bytes(&st->buf));
300    return 0;
301 }
302
303 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
304 {
305    *nbBytes = ec_byte_bytes(&st->buf);
306    return ec_byte_get_buffer(&st->buf);
307 }
308
309
310 /****************************************************************************/
311 /*                                Decoder                                   */
312 /****************************************************************************/
313
314
315
316 struct CELTDecoder {
317    const CELTMode *mode;
318    int frame_size;
319    int block_size;
320    int nb_blocks;
321    
322    ec_byte_buffer buf;
323    ec_enc         enc;
324
325    float preemph;
326    float preemph_memD;
327    
328    mdct_lookup mdct_lookup;
329    
330    float *window;
331    float *mdct_overlap;
332    float *out_mem;
333
334    float *oldBandE;
335 };
336
337 CELTDecoder *celt_decoder_new(const CELTMode *mode)
338 {
339    int i, N, B;
340    N = mode->mdctSize;
341    B = mode->nbMdctBlocks;
342    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
343    
344    st->mode = mode;
345    st->frame_size = B*N;
346    st->block_size = N;
347    st->nb_blocks  = B;
348    
349    ec_byte_writeinit(&st->buf);
350
351    mdct_init(&st->mdct_lookup, 2*N);
352    
353    st->window = celt_alloc(2*N*sizeof(float));
354    st->mdct_overlap = celt_alloc(N*sizeof(float));
355    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
356    for (i=0;i<N;i++)
357       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
358    
359    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
360
361    st->preemph = 0.8;
362    return st;
363 }
364
365 void celt_decoder_destroy(CELTDecoder *st)
366 {
367    if (st == NULL)
368    {
369       celt_warning("NULL passed to celt_encoder_destroy");
370       return;
371    }
372
373    mdct_clear(&st->mdct_lookup);
374
375    celt_free(st->window);
376    celt_free(st->mdct_overlap);
377    celt_free(st->out_mem);
378    
379    celt_free(st->oldBandE);
380    celt_free(st);
381 }
382
383 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
384 {
385    int i, N, B;
386    N = st->block_size;
387    B = st->nb_blocks;
388    
389    float X[B*N];         /**< Interleaved signal MDCTs */
390    float P[B*N];         /**< Interleaved pitch MDCTs*/
391    float bandE[st->mode->nbEBands];
392    float gains[st->mode->nbPBands];
393    int pitch_index;
394    ec_dec dec;
395    ec_byte_buffer buf;
396    
397    ec_byte_readinit(&buf,data,len);
398    ec_dec_init(&dec,&buf);
399    
400    /* Get band energies */
401    for (i=0;i<st->mode->nbEBands;i++)
402       bandE[i] = cheating_ebands[i];
403
404    /* Get the pitch index */
405    pitch_index = cheating_period;
406    
407    /* Pitch MDCT */
408    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
409
410    //haar1(P, B*N);
411
412    {
413       float bandEp[st->mode->nbEBands];
414       compute_band_energies(st->mode, P, bandEp);
415       normalise_bands(st->mode, P, bandEp);
416    }
417
418    /* Get the pitch gains */
419    for (i=0;i<st->mode->nbPBands;i++)
420       gains[i] = cheating_pitch_gains[i];
421
422    /* Apply pitch gains */
423    pitch_quant_bands(st->mode, X, P, gains);
424
425    /* Decode fixed codebook and merge with pitch */
426    unquant_bands(st->mode, X, P, &dec);
427
428    /* Synthesis */
429    denormalise_bands(st->mode, X, bandE);
430
431    //inv_haar1(X, B*N);
432
433    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
434    /* Compute inverse MDCTs */
435    for (i=0;i<B;i++)
436    {
437       int j;
438       float x[2*N];
439       float tmp[N];
440       /* De-interleaving the sub-frames */
441       for (j=0;j<N;j++)
442          tmp[j] = X[B*j+i];
443       mdct_backward(&st->mdct_lookup, tmp, x);
444       for (j=0;j<2*N;j++)
445          x[j] = st->window[j]*x[j];
446       for (j=0;j<N;j++)
447          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
448       for (j=0;j<N;j++)
449          st->mdct_overlap[j] = x[N+j];
450       
451       for (j=0;j<N;j++)
452       {
453          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
454          st->preemph_memD = tmp;
455          pcm[i*N+j] = (short)floor(.5+tmp);
456       }
457    }
458    //printf ("\n");
459 }
460