code cleanup: all inverse MDCTs in the same function
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43
44 #define MAX_PERIOD 1024
45
46
47 struct CELTEncoder {
48    const CELTMode *mode;
49    int frame_size;
50    int block_size;
51    int nb_blocks;
52       
53    ec_byte_buffer buf;
54    ec_enc         enc;
55
56    float preemph;
57    float preemph_memE;
58    float preemph_memD;
59    
60    mdct_lookup mdct_lookup;
61    void *fft;
62    
63    float *window;
64    float *in_mem;
65    float *mdct_overlap;
66    float *out_mem;
67
68    float *oldBandE;
69 };
70
71
72
73 CELTEncoder *celt_encoder_new(const CELTMode *mode)
74 {
75    int i, N, B;
76    N = mode->mdctSize;
77    B = mode->nbMdctBlocks;
78    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
79    
80    st->mode = mode;
81    st->frame_size = B*N;
82    st->block_size = N;
83    st->nb_blocks  = B;
84    
85    ec_byte_writeinit(&st->buf);
86    ec_enc_init(&st->enc,&st->buf);
87
88    mdct_init(&st->mdct_lookup, 2*N);
89    st->fft = spx_fft_init(MAX_PERIOD);
90    
91    st->window = celt_alloc(2*N*sizeof(float));
92    st->in_mem = celt_alloc(N*sizeof(float));
93    st->mdct_overlap = celt_alloc(N*sizeof(float));
94    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
95    for (i=0;i<N;i++)
96       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
97    
98    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
99
100    st->preemph = 0.8;
101    return st;
102 }
103
104 void celt_encoder_destroy(CELTEncoder *st)
105 {
106    if (st == NULL)
107    {
108       celt_warning("NULL passed to celt_encoder_destroy");
109       return;
110    }
111    ec_byte_writeclear(&st->buf);
112
113    mdct_clear(&st->mdct_lookup);
114    spx_fft_destroy(st->fft);
115
116    celt_free(st->window);
117    celt_free(st->in_mem);
118    celt_free(st->mdct_overlap);
119    celt_free(st->out_mem);
120    
121    celt_free(st->oldBandE);
122    celt_free(st);
123 }
124
125 static void haar1(float *X, int N)
126 {
127    int i;
128    for (i=0;i<N;i+=2)
129    {
130       float a, b;
131       a = X[i];
132       b = X[i+1];
133       X[i] = .707107f*(a+b);
134       X[i+1] = .707107f*(a-b);
135    }
136 }
137
138 static void inv_haar1(float *X, int N)
139 {
140    int i;
141    for (i=0;i<N;i+=2)
142    {
143       float a, b;
144       a = X[i];
145       b = X[i+1];
146       X[i] = .707107f*(a+b);
147       X[i+1] = .707107f*(a-b);
148    }
149 }
150
151 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B)
152 {
153    int i;
154    for (i=0;i<B;i++)
155    {
156       int j;
157       float x[2*N];
158       float tmp[N];
159       for (j=0;j<2*N;j++)
160          x[j] = window[j]*in[i*N+j];
161       mdct_forward(mdct_lookup, x, tmp);
162       /* Interleaving the sub-frames */
163       for (j=0;j<N;j++)
164          out[B*j+i] = tmp[j];
165    }
166
167 }
168
169 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int B)
170 {
171    int i;
172    for (i=0;i<B;i++)
173    {
174       int j;
175       float x[2*N];
176       float tmp[N];
177       /* De-interleaving the sub-frames */
178       for (j=0;j<N;j++)
179          tmp[j] = X[B*j+i];
180       mdct_backward(mdct_lookup, tmp, x);
181       for (j=0;j<2*N;j++)
182          x[j] = window[j]*x[j];
183       for (j=0;j<N;j++)
184          out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+mdct_overlap[j];
185       for (j=0;j<N;j++)
186          mdct_overlap[j] = x[N+j];
187    }
188 }
189
190 int celt_encode(CELTEncoder *st, short *pcm)
191 {
192    int i, N, B;
193    N = st->block_size;
194    B = st->nb_blocks;
195    float in[(B+1)*N];
196    
197    float X[B*N];         /**< Interleaved signal MDCTs */
198    float P[B*N];         /**< Interleaved pitch MDCTs*/
199    float bandE[st->mode->nbEBands];
200    float gains[st->mode->nbPBands];
201    int pitch_index;
202    
203    for (i=0;i<N;i++)
204       in[i] = st->in_mem[i];
205    for (;i<(B+1)*N;i++)
206    {
207       float tmp = pcm[i-N];
208       in[i] = tmp - st->preemph*st->preemph_memE;
209       st->preemph_memE = tmp;
210    }
211    for (i=0;i<N;i++)
212       st->in_mem[i] = in[B*N+i];
213
214    /* Compute MDCTs */
215    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B);
216    
217    /* Pitch analysis */
218    for (i=0;i<N;i++)
219    {
220       in[i] *= st->window[i];
221       in[B*N+i] *= st->window[N+i];
222    }
223    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, &pitch_index);
224    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
225    
226    /* Compute MDCTs of the pitch part */
227    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
228    
229    /*int j;
230    for (j=0;j<B*N;j++)
231       printf ("%f ", X[j]);
232    for (j=0;j<B*N;j++)
233       printf ("%f ", P[j]);
234    printf ("\n");*/
235    //haar1(X, B*N);
236    //haar1(P, B*N);
237    
238    /* Band normalisation */
239    compute_band_energies(st->mode, X, bandE);
240    normalise_bands(st->mode, X, bandE);
241    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
242    
243    {
244       float bandEp[st->mode->nbEBands];
245       compute_band_energies(st->mode, P, bandEp);
246       normalise_bands(st->mode, P, bandEp);
247    }
248    
249    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
250    
251    /* Pitch prediction */
252    compute_pitch_gain(st->mode, X, P, gains, bandE);
253    quant_pitch(gains, st->mode->nbPBands, &st->enc);
254    pitch_quant_bands(st->mode, X, P, gains);
255
256    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
257    /* Subtract the pitch prediction from the signal to encode */
258    for (i=0;i<B*N;i++)
259       X[i] -= P[i];
260
261    /*float sum=0;
262    for (i=0;i<B*N;i++)
263       sum += X[i]*X[i];
264    printf ("%f\n", sum);*/
265    /* Residual quantisation */
266    quant_bands(st->mode, X, P, &st->enc);
267    
268    if (0) {//This is just for debugging
269       ec_enc_done(&st->enc);
270       ec_dec dec;
271       ec_byte_readinit(&st->buf,ec_byte_get_buffer(&st->buf),ec_byte_bytes(&st->buf));
272       ec_dec_init(&dec,&st->buf);
273
274       unquant_bands(st->mode, X, P, &dec);
275       //printf ("\n");
276    }
277    
278    /* Synthesis */
279    denormalise_bands(st->mode, X, bandE);
280
281    //inv_haar1(X, B*N);
282
283    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
284    /* Compute inverse MDCTs */
285    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B);
286
287    for (i=0;i<B;i++)
288    {
289       int j;
290       for (j=0;j<N;j++)
291       {
292          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
293          st->preemph_memD = tmp;
294          pcm[i*N+j] = (short)floor(.5+tmp);
295       }
296    }
297    return 0;
298 }
299
300 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
301 {
302    char *data;
303    ec_enc_done(&st->enc);
304    *nbBytes = ec_byte_bytes(&st->buf);
305    data = ec_byte_get_buffer(&st->buf);
306    //printf ("%d\n", *nbBytes);
307    
308    /* Reset the packing for the next encoding */
309    ec_byte_reset(&st->buf);
310    ec_enc_init(&st->enc,&st->buf);
311
312    return data;
313 }
314
315
316 /****************************************************************************/
317 /*                                                                          */
318 /*                                DECODER                                   */
319 /*                                                                          */
320 /****************************************************************************/
321
322
323
324 struct CELTDecoder {
325    const CELTMode *mode;
326    int frame_size;
327    int block_size;
328    int nb_blocks;
329    
330    ec_byte_buffer buf;
331    ec_enc         enc;
332
333    float preemph;
334    float preemph_memD;
335    
336    mdct_lookup mdct_lookup;
337    
338    float *window;
339    float *mdct_overlap;
340    float *out_mem;
341
342    float *oldBandE;
343    
344    int last_pitch_index;
345 };
346
347 CELTDecoder *celt_decoder_new(const CELTMode *mode)
348 {
349    int i, N, B;
350    N = mode->mdctSize;
351    B = mode->nbMdctBlocks;
352    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
353    
354    st->mode = mode;
355    st->frame_size = B*N;
356    st->block_size = N;
357    st->nb_blocks  = B;
358    
359    mdct_init(&st->mdct_lookup, 2*N);
360    
361    st->window = celt_alloc(2*N*sizeof(float));
362    st->mdct_overlap = celt_alloc(N*sizeof(float));
363    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
364    for (i=0;i<N;i++)
365       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
366    
367    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
368
369    st->preemph = 0.8;
370    
371    st->last_pitch_index = 0;
372    return st;
373 }
374
375 void celt_decoder_destroy(CELTDecoder *st)
376 {
377    if (st == NULL)
378    {
379       celt_warning("NULL passed to celt_encoder_destroy");
380       return;
381    }
382
383    mdct_clear(&st->mdct_lookup);
384
385    celt_free(st->window);
386    celt_free(st->mdct_overlap);
387    celt_free(st->out_mem);
388    
389    celt_free(st->oldBandE);
390    celt_free(st);
391 }
392
393 int celt_decode_lost(CELTDecoder *st, short *pcm)
394 {
395    int i, N, B;
396    N = st->block_size;
397    B = st->nb_blocks;
398    
399    float X[B*N];         /**< Interleaved signal MDCTs */
400    float P[B*N];         /**< Interleaved pitch MDCTs*/
401    float bandE[st->mode->nbEBands];
402    float gains[st->mode->nbPBands];
403    int pitch_index;
404    
405    pitch_index = st->last_pitch_index;
406    
407    /* Use the pitch MDCT as the "guessed" signal */
408    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, X, N, B);
409
410    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
411    /* Compute inverse MDCTs */
412    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B);
413
414    for (i=0;i<B;i++)
415    {
416       int j;
417       for (j=0;j<N;j++)
418       {
419          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
420          st->preemph_memD = tmp;
421          pcm[i*N+j] = (short)floor(.5+tmp);
422       }
423    }
424
425 }
426
427 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
428 {
429    int i, N, B;
430    N = st->block_size;
431    B = st->nb_blocks;
432    
433    float X[B*N];         /**< Interleaved signal MDCTs */
434    float P[B*N];         /**< Interleaved pitch MDCTs*/
435    float bandE[st->mode->nbEBands];
436    float gains[st->mode->nbPBands];
437    int pitch_index;
438    ec_dec dec;
439    ec_byte_buffer buf;
440    
441    if (data == NULL)
442    {
443       celt_decode_lost(st, pcm);
444       return 0;
445    }
446    
447    ec_byte_readinit(&buf,data,len);
448    ec_dec_init(&dec,&buf);
449    
450    /* Get the pitch index */
451    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);;
452    st->last_pitch_index = pitch_index;
453    
454    /* Get band energies */
455    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
456    
457    /* Pitch MDCT */
458    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
459
460    //haar1(P, B*N);
461
462    {
463       float bandEp[st->mode->nbEBands];
464       compute_band_energies(st->mode, P, bandEp);
465       normalise_bands(st->mode, P, bandEp);
466    }
467
468    /* Get the pitch gains */
469    unquant_pitch(gains, st->mode->nbPBands, &dec);
470
471    /* Apply pitch gains */
472    pitch_quant_bands(st->mode, X, P, gains);
473
474    /* Decode fixed codebook and merge with pitch */
475    unquant_bands(st->mode, X, P, &dec);
476
477    /* Synthesis */
478    denormalise_bands(st->mode, X, bandE);
479
480    //inv_haar1(X, B*N);
481
482    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
483    /* Compute inverse MDCTs */
484    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B);
485    for (i=0;i<B;i++)
486    {
487       int j;
488       for (j=0;j<N;j++)
489       {
490          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
491          st->preemph_memD = tmp;
492          pcm[i*N+j] = (short)floor(.5+tmp);
493       }
494    }
495    return 0;
496    //printf ("\n");
497 }
498