More decoding work
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41
42 #define MAX_PERIOD 1024
43
44 struct CELTEncoder {
45    const CELTMode *mode;
46    int frame_size;
47    int block_size;
48    int nb_blocks;
49       
50    ec_byte_buffer buf;
51    ec_enc         enc;
52
53    float preemph;
54    float preemph_memE;
55    float preemph_memD;
56    
57    mdct_lookup mdct_lookup;
58    void *fft;
59    
60    float *window;
61    float *in_mem;
62    float *mdct_overlap;
63    float *out_mem;
64
65    float *oldBandE;
66 };
67
68
69
70 CELTEncoder *celt_encoder_new(const CELTMode *mode)
71 {
72    int i, N, B;
73    N = mode->mdctSize;
74    B = mode->nbMdctBlocks;
75    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
76    
77    st->mode = mode;
78    st->frame_size = B*N;
79    st->block_size = N;
80    st->nb_blocks  = B;
81    
82    ec_byte_writeinit(&st->buf);
83
84    mdct_init(&st->mdct_lookup, 2*N);
85    st->fft = spx_fft_init(MAX_PERIOD);
86    
87    st->window = celt_alloc(2*N*sizeof(float));
88    st->in_mem = celt_alloc(N*sizeof(float));
89    st->mdct_overlap = celt_alloc(N*sizeof(float));
90    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
91    for (i=0;i<N;i++)
92       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
93    
94    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
95
96    st->preemph = 0.8;
97    return st;
98 }
99
100 void celt_encoder_destroy(CELTEncoder *st)
101 {
102    if (st == NULL)
103    {
104       celt_warning("NULL passed to celt_encoder_destroy");
105       return;
106    }
107    ec_byte_writeclear(&st->buf);
108
109    mdct_clear(&st->mdct_lookup);
110    spx_fft_destroy(st->fft);
111
112    celt_free(st->window);
113    celt_free(st->in_mem);
114    celt_free(st->mdct_overlap);
115    celt_free(st->out_mem);
116    
117    celt_free(st->oldBandE);
118    celt_free(st);
119 }
120
121 static void haar1(float *X, int N)
122 {
123    int i;
124    for (i=0;i<N;i+=2)
125    {
126       float a, b;
127       a = X[i];
128       b = X[i+1];
129       X[i] = .707107f*(a+b);
130       X[i+1] = .707107f*(a-b);
131    }
132 }
133
134 static void inv_haar1(float *X, int N)
135 {
136    int i;
137    for (i=0;i<N;i+=2)
138    {
139       float a, b;
140       a = X[i];
141       b = X[i+1];
142       X[i] = .707107f*(a+b);
143       X[i+1] = .707107f*(a-b);
144    }
145 }
146
147 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B)
148 {
149    int i;
150    for (i=0;i<B;i++)
151    {
152       int j;
153       float x[2*N];
154       float tmp[N];
155       for (j=0;j<2*N;j++)
156          x[j] = window[j]*in[i*N+j];
157       mdct_forward(mdct_lookup, x, tmp);
158       /* Interleaving the sub-frames */
159       for (j=0;j<N;j++)
160          out[B*j+i] = tmp[j];
161    }
162
163 }
164
165 int celt_encode(CELTEncoder *st, short *pcm)
166 {
167    int i, N, B;
168    N = st->block_size;
169    B = st->nb_blocks;
170    float in[(B+1)*N];
171    
172    float X[B*N];         /**< Interleaved signal MDCTs */
173    float P[B*N];         /**< Interleaved pitch MDCTs*/
174    float bandE[st->mode->nbEBands];
175    float gains[st->mode->nbPBands];
176    int pitch_index;
177    
178
179    ec_byte_reset(&st->buf);
180    ec_enc_init(&st->enc,&st->buf);
181
182    for (i=0;i<N;i++)
183       in[i] = st->in_mem[i];
184    for (;i<(B+1)*N;i++)
185    {
186       float tmp = pcm[i-N];
187       in[i] = tmp - st->preemph*st->preemph_memE;
188       st->preemph_memE = tmp;
189    }
190    for (i=0;i<N;i++)
191       st->in_mem[i] = in[B*N+i];
192
193    /* Compute MDCTs */
194    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B);
195    
196    /* Pitch analysis */
197    for (i=0;i<N;i++)
198    {
199       in[i] *= st->window[i];
200       in[B*N+i] *= st->window[N+i];
201    }
202    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, &pitch_index);
203    
204    /* Compute MDCTs of the pitch part */
205    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
206    
207    /*int j;
208    for (j=0;j<B*N;j++)
209       printf ("%f ", X[j]);
210    for (j=0;j<B*N;j++)
211       printf ("%f ", P[j]);
212    printf ("\n");*/
213    //haar1(X, B*N);
214    //haar1(P, B*N);
215    
216    /* Band normalisation */
217    compute_band_energies(st->mode, X, bandE);
218    normalise_bands(st->mode, X, bandE);
219    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
220    
221    {
222       float bandEp[st->mode->nbEBands];
223       compute_band_energies(st->mode, P, bandEp);
224       normalise_bands(st->mode, P, bandEp);
225    }
226    
227    quant_energy(st->mode, bandE, st->oldBandE);
228    
229    /* Pitch prediction */
230    compute_pitch_gain(st->mode, X, P, gains, bandE);
231    //quantise_pitch(gains, PBANDS);
232    pitch_quant_bands(st->mode, X, P, gains);
233    
234    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
235    /* Subtract the pitch prediction from the signal to encode */
236    for (i=0;i<B*N;i++)
237       X[i] -= P[i];
238
239    /*float sum=0;
240    for (i=0;i<B*N;i++)
241       sum += X[i]*X[i];
242    printf ("%f\n", sum);*/
243    /* Residual quantisation */
244    quant_bands(st->mode, X, P, &st->enc);
245    
246    if (0) {//This is just for debugging
247       ec_enc_done(&st->enc);
248       ec_dec dec;
249       ec_byte_readinit(&st->buf,ec_byte_get_buffer(&st->buf),ec_byte_bytes(&st->buf));
250       ec_dec_init(&dec,&st->buf);
251
252       unquant_bands(st->mode, X, P, &dec);
253       //printf ("\n");
254    }
255    
256    /* Synthesis */
257    denormalise_bands(st->mode, X, bandE);
258
259    //inv_haar1(X, B*N);
260
261    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
262    /* Compute inverse MDCTs */
263    for (i=0;i<B;i++)
264    {
265       int j;
266       float x[2*N];
267       float tmp[N];
268       /* De-interleaving the sub-frames */
269       for (j=0;j<N;j++)
270          tmp[j] = X[B*j+i];
271       mdct_backward(&st->mdct_lookup, tmp, x);
272       for (j=0;j<2*N;j++)
273          x[j] = st->window[j]*x[j];
274       for (j=0;j<N;j++)
275          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
276       for (j=0;j<N;j++)
277          st->mdct_overlap[j] = x[N+j];
278       
279       for (j=0;j<N;j++)
280       {
281          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
282          st->preemph_memD = tmp;
283          pcm[i*N+j] = (short)floor(.5+tmp);
284       }
285    }
286    ec_enc_done(&st->enc);
287    //printf ("%d\n", ec_byte_bytes(&st->buf));
288    return 0;
289 }
290
291 /****************************************************************************/
292 /*                                Decoder                                   */
293 /****************************************************************************/
294
295
296
297 struct CELTDecoder {
298    const CELTMode *mode;
299    int frame_size;
300    int block_size;
301    int nb_blocks;
302    
303    ec_byte_buffer buf;
304    ec_enc         enc;
305
306    float preemph;
307    float preemph_memD;
308    
309    mdct_lookup mdct_lookup;
310    
311    float *window;
312    float *mdct_overlap;
313    float *out_mem;
314
315    float *oldBandE;
316 };
317
318 CELTDecoder *celt_decoder_new(const CELTMode *mode)
319 {
320    int i, N, B;
321    N = mode->mdctSize;
322    B = mode->nbMdctBlocks;
323    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
324    
325    st->mode = mode;
326    st->frame_size = B*N;
327    st->block_size = N;
328    st->nb_blocks  = B;
329    
330    ec_byte_writeinit(&st->buf);
331
332    mdct_init(&st->mdct_lookup, 2*N);
333    
334    st->window = celt_alloc(2*N*sizeof(float));
335    st->mdct_overlap = celt_alloc(N*sizeof(float));
336    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
337    for (i=0;i<N;i++)
338       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
339    
340    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
341
342    st->preemph = 0.8;
343    return st;
344 }
345
346 void celt_decoder_destroy(CELTDecoder *st)
347 {
348    if (st == NULL)
349    {
350       celt_warning("NULL passed to celt_encoder_destroy");
351       return;
352    }
353
354    mdct_clear(&st->mdct_lookup);
355
356    celt_free(st->window);
357    celt_free(st->mdct_overlap);
358    celt_free(st->out_mem);
359    
360    celt_free(st->oldBandE);
361    celt_free(st);
362 }
363
364 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
365 {
366    int i, N, B;
367    N = st->block_size;
368    B = st->nb_blocks;
369    
370    float X[B*N];         /**< Interleaved signal MDCTs */
371    float P[B*N];         /**< Interleaved pitch MDCTs*/
372    float bandE[st->mode->nbEBands];
373    float gains[st->mode->nbPBands];
374    int pitch_index;
375    ec_dec dec;
376    ec_byte_buffer buf;
377    
378    ec_byte_readinit(&buf,data,len);
379    ec_dec_init(&dec,&buf);
380    
381    /* Get band energies */
382    
383    /* Get the pitch index */
384    
385    /* Pitch MDCT */
386    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
387
388    //haar1(P, B*N);
389
390    {
391       float bandEp[st->mode->nbEBands];
392       compute_band_energies(st->mode, P, bandEp);
393       normalise_bands(st->mode, P, bandEp);
394    }
395
396    /* Get the pitch gains */
397    
398    /* Apply pitch gains */
399    pitch_quant_bands(st->mode, X, P, gains);
400
401    /* Decode fixed codebook and merge with pitch */
402    unquant_bands(st->mode, X, P, &dec);
403
404    /* Synthesis */
405    denormalise_bands(st->mode, X, bandE);
406
407    //inv_haar1(X, B*N);
408
409    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
410    /* Compute inverse MDCTs */
411    for (i=0;i<B;i++)
412    {
413       int j;
414       float x[2*N];
415       float tmp[N];
416       /* De-interleaving the sub-frames */
417       for (j=0;j<N;j++)
418          tmp[j] = X[B*j+i];
419       mdct_backward(&st->mdct_lookup, tmp, x);
420       for (j=0;j<2*N;j++)
421          x[j] = st->window[j]*x[j];
422       for (j=0;j<N;j++)
423          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
424       for (j=0;j<N;j++)
425          st->mdct_overlap[j] = x[N+j];
426       
427       for (j=0;j<N;j++)
428       {
429          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
430          st->preemph_memD = tmp;
431          pcm[i*N+j] = (short)floor(.5+tmp);
432       }
433    }
434 }
435