More work on decoding (still cheating)
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41
42 #define MAX_PERIOD 1024
43
44 /* This is only for cheating until the decoder is complete */
45 float cheating_ebands[100];
46 float cheating_pitch_gains[100];
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54       
55    ec_byte_buffer buf;
56    ec_enc         enc;
57
58    float preemph;
59    float preemph_memE;
60    float preemph_memD;
61    
62    mdct_lookup mdct_lookup;
63    void *fft;
64    
65    float *window;
66    float *in_mem;
67    float *mdct_overlap;
68    float *out_mem;
69
70    float *oldBandE;
71 };
72
73
74
75 CELTEncoder *celt_encoder_new(const CELTMode *mode)
76 {
77    int i, N, B;
78    N = mode->mdctSize;
79    B = mode->nbMdctBlocks;
80    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
81    
82    st->mode = mode;
83    st->frame_size = B*N;
84    st->block_size = N;
85    st->nb_blocks  = B;
86    
87    ec_byte_writeinit(&st->buf);
88    ec_enc_init(&st->enc,&st->buf);
89
90    mdct_init(&st->mdct_lookup, 2*N);
91    st->fft = spx_fft_init(MAX_PERIOD);
92    
93    st->window = celt_alloc(2*N*sizeof(float));
94    st->in_mem = celt_alloc(N*sizeof(float));
95    st->mdct_overlap = celt_alloc(N*sizeof(float));
96    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
97    for (i=0;i<N;i++)
98       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
99    
100    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
101
102    st->preemph = 0.8;
103    return st;
104 }
105
106 void celt_encoder_destroy(CELTEncoder *st)
107 {
108    if (st == NULL)
109    {
110       celt_warning("NULL passed to celt_encoder_destroy");
111       return;
112    }
113    ec_byte_writeclear(&st->buf);
114
115    mdct_clear(&st->mdct_lookup);
116    spx_fft_destroy(st->fft);
117
118    celt_free(st->window);
119    celt_free(st->in_mem);
120    celt_free(st->mdct_overlap);
121    celt_free(st->out_mem);
122    
123    celt_free(st->oldBandE);
124    celt_free(st);
125 }
126
127 static void haar1(float *X, int N)
128 {
129    int i;
130    for (i=0;i<N;i+=2)
131    {
132       float a, b;
133       a = X[i];
134       b = X[i+1];
135       X[i] = .707107f*(a+b);
136       X[i+1] = .707107f*(a-b);
137    }
138 }
139
140 static void inv_haar1(float *X, int N)
141 {
142    int i;
143    for (i=0;i<N;i+=2)
144    {
145       float a, b;
146       a = X[i];
147       b = X[i+1];
148       X[i] = .707107f*(a+b);
149       X[i+1] = .707107f*(a-b);
150    }
151 }
152
153 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B)
154 {
155    int i;
156    for (i=0;i<B;i++)
157    {
158       int j;
159       float x[2*N];
160       float tmp[N];
161       for (j=0;j<2*N;j++)
162          x[j] = window[j]*in[i*N+j];
163       mdct_forward(mdct_lookup, x, tmp);
164       /* Interleaving the sub-frames */
165       for (j=0;j<N;j++)
166          out[B*j+i] = tmp[j];
167    }
168
169 }
170
171 int celt_encode(CELTEncoder *st, short *pcm)
172 {
173    int i, N, B;
174    N = st->block_size;
175    B = st->nb_blocks;
176    float in[(B+1)*N];
177    
178    float X[B*N];         /**< Interleaved signal MDCTs */
179    float P[B*N];         /**< Interleaved pitch MDCTs*/
180    float bandE[st->mode->nbEBands];
181    float gains[st->mode->nbPBands];
182    int pitch_index;
183    
184    for (i=0;i<N;i++)
185       in[i] = st->in_mem[i];
186    for (;i<(B+1)*N;i++)
187    {
188       float tmp = pcm[i-N];
189       in[i] = tmp - st->preemph*st->preemph_memE;
190       st->preemph_memE = tmp;
191    }
192    for (i=0;i<N;i++)
193       st->in_mem[i] = in[B*N+i];
194
195    /* Compute MDCTs */
196    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B);
197    
198    /* Pitch analysis */
199    for (i=0;i<N;i++)
200    {
201       in[i] *= st->window[i];
202       in[B*N+i] *= st->window[N+i];
203    }
204    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, &pitch_index);
205    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
206    
207    /* Compute MDCTs of the pitch part */
208    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
209    
210    /*int j;
211    for (j=0;j<B*N;j++)
212       printf ("%f ", X[j]);
213    for (j=0;j<B*N;j++)
214       printf ("%f ", P[j]);
215    printf ("\n");*/
216    //haar1(X, B*N);
217    //haar1(P, B*N);
218    
219    /* Band normalisation */
220    compute_band_energies(st->mode, X, bandE);
221    normalise_bands(st->mode, X, bandE);
222    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
223    
224    {
225       float bandEp[st->mode->nbEBands];
226       compute_band_energies(st->mode, P, bandEp);
227       normalise_bands(st->mode, P, bandEp);
228    }
229    
230    quant_energy(st->mode, bandE, st->oldBandE);
231    
232    /* Pitch prediction */
233    compute_pitch_gain(st->mode, X, P, gains, bandE);
234    //quantise_pitch(gains, PBANDS);
235    pitch_quant_bands(st->mode, X, P, gains);
236    
237    for (i=0;i<st->mode->nbEBands;i++)
238       cheating_ebands[i] = bandE[i];
239    for (i=0;i<st->mode->nbPBands;i++)
240       cheating_pitch_gains[i] = gains[i];
241
242    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
243    /* Subtract the pitch prediction from the signal to encode */
244    for (i=0;i<B*N;i++)
245       X[i] -= P[i];
246
247    /*float sum=0;
248    for (i=0;i<B*N;i++)
249       sum += X[i]*X[i];
250    printf ("%f\n", sum);*/
251    /* Residual quantisation */
252    quant_bands(st->mode, X, P, &st->enc);
253    
254    if (0) {//This is just for debugging
255       ec_enc_done(&st->enc);
256       ec_dec dec;
257       ec_byte_readinit(&st->buf,ec_byte_get_buffer(&st->buf),ec_byte_bytes(&st->buf));
258       ec_dec_init(&dec,&st->buf);
259
260       unquant_bands(st->mode, X, P, &dec);
261       //printf ("\n");
262    }
263    
264    /* Synthesis */
265    denormalise_bands(st->mode, X, bandE);
266
267    //inv_haar1(X, B*N);
268
269    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
270    /* Compute inverse MDCTs */
271    for (i=0;i<B;i++)
272    {
273       int j;
274       float x[2*N];
275       float tmp[N];
276       /* De-interleaving the sub-frames */
277       for (j=0;j<N;j++)
278          tmp[j] = X[B*j+i];
279       mdct_backward(&st->mdct_lookup, tmp, x);
280       for (j=0;j<2*N;j++)
281          x[j] = st->window[j]*x[j];
282       for (j=0;j<N;j++)
283          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
284       for (j=0;j<N;j++)
285          st->mdct_overlap[j] = x[N+j];
286       
287       for (j=0;j<N;j++)
288       {
289          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
290          st->preemph_memD = tmp;
291          pcm[i*N+j] = (short)floor(.5+tmp);
292       }
293    }
294    return 0;
295 }
296
297 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
298 {
299    char *data;
300    ec_enc_done(&st->enc);
301    *nbBytes = ec_byte_bytes(&st->buf);
302    data = ec_byte_get_buffer(&st->buf);
303    //printf ("%d\n", *nbBytes);
304    
305    /* Reset the packing for the next encoding */
306    ec_byte_reset(&st->buf);
307    ec_enc_init(&st->enc,&st->buf);
308
309    return data;
310 }
311
312
313 /****************************************************************************/
314 /*                                                                          */
315 /*                                DECODER                                   */
316 /*                                                                          */
317 /****************************************************************************/
318
319
320
321 struct CELTDecoder {
322    const CELTMode *mode;
323    int frame_size;
324    int block_size;
325    int nb_blocks;
326    
327    ec_byte_buffer buf;
328    ec_enc         enc;
329
330    float preemph;
331    float preemph_memD;
332    
333    mdct_lookup mdct_lookup;
334    
335    float *window;
336    float *mdct_overlap;
337    float *out_mem;
338
339    float *oldBandE;
340 };
341
342 CELTDecoder *celt_decoder_new(const CELTMode *mode)
343 {
344    int i, N, B;
345    N = mode->mdctSize;
346    B = mode->nbMdctBlocks;
347    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
348    
349    st->mode = mode;
350    st->frame_size = B*N;
351    st->block_size = N;
352    st->nb_blocks  = B;
353    
354    mdct_init(&st->mdct_lookup, 2*N);
355    
356    st->window = celt_alloc(2*N*sizeof(float));
357    st->mdct_overlap = celt_alloc(N*sizeof(float));
358    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
359    for (i=0;i<N;i++)
360       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
361    
362    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
363
364    st->preemph = 0.8;
365    return st;
366 }
367
368 void celt_decoder_destroy(CELTDecoder *st)
369 {
370    if (st == NULL)
371    {
372       celt_warning("NULL passed to celt_encoder_destroy");
373       return;
374    }
375
376    mdct_clear(&st->mdct_lookup);
377
378    celt_free(st->window);
379    celt_free(st->mdct_overlap);
380    celt_free(st->out_mem);
381    
382    celt_free(st->oldBandE);
383    celt_free(st);
384 }
385
386 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
387 {
388    int i, N, B;
389    N = st->block_size;
390    B = st->nb_blocks;
391    
392    float X[B*N];         /**< Interleaved signal MDCTs */
393    float P[B*N];         /**< Interleaved pitch MDCTs*/
394    float bandE[st->mode->nbEBands];
395    float gains[st->mode->nbPBands];
396    int pitch_index;
397    ec_dec dec;
398    ec_byte_buffer buf;
399    
400    ec_byte_readinit(&buf,data,len);
401    ec_dec_init(&dec,&buf);
402    
403    /* Get band energies */
404    for (i=0;i<st->mode->nbEBands;i++)
405       bandE[i] = cheating_ebands[i];
406
407    /* Get the pitch index */
408    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);;
409    
410    /* Pitch MDCT */
411    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
412
413    //haar1(P, B*N);
414
415    {
416       float bandEp[st->mode->nbEBands];
417       compute_band_energies(st->mode, P, bandEp);
418       normalise_bands(st->mode, P, bandEp);
419    }
420
421    /* Get the pitch gains */
422    for (i=0;i<st->mode->nbPBands;i++)
423       gains[i] = cheating_pitch_gains[i];
424
425    /* Apply pitch gains */
426    pitch_quant_bands(st->mode, X, P, gains);
427
428    /* Decode fixed codebook and merge with pitch */
429    unquant_bands(st->mode, X, P, &dec);
430
431    /* Synthesis */
432    denormalise_bands(st->mode, X, bandE);
433
434    //inv_haar1(X, B*N);
435
436    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
437    /* Compute inverse MDCTs */
438    for (i=0;i<B;i++)
439    {
440       int j;
441       float x[2*N];
442       float tmp[N];
443       /* De-interleaving the sub-frames */
444       for (j=0;j<N;j++)
445          tmp[j] = X[B*j+i];
446       mdct_backward(&st->mdct_lookup, tmp, x);
447       for (j=0;j<2*N;j++)
448          x[j] = st->window[j]*x[j];
449       for (j=0;j<N;j++)
450          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
451       for (j=0;j<N;j++)
452          st->mdct_overlap[j] = x[N+j];
453       
454       for (j=0;j<N;j++)
455       {
456          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
457          st->preemph_memD = tmp;
458          pcm[i*N+j] = (short)floor(.5+tmp);
459       }
460    }
461    //printf ("\n");
462 }
463