No more cheating, everything fully quantised
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43
44 #define MAX_PERIOD 1024
45
46
47 struct CELTEncoder {
48    const CELTMode *mode;
49    int frame_size;
50    int block_size;
51    int nb_blocks;
52       
53    ec_byte_buffer buf;
54    ec_enc         enc;
55
56    float preemph;
57    float preemph_memE;
58    float preemph_memD;
59    
60    mdct_lookup mdct_lookup;
61    void *fft;
62    
63    float *window;
64    float *in_mem;
65    float *mdct_overlap;
66    float *out_mem;
67
68    float *oldBandE;
69 };
70
71
72
73 CELTEncoder *celt_encoder_new(const CELTMode *mode)
74 {
75    int i, N, B;
76    N = mode->mdctSize;
77    B = mode->nbMdctBlocks;
78    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
79    
80    st->mode = mode;
81    st->frame_size = B*N;
82    st->block_size = N;
83    st->nb_blocks  = B;
84    
85    ec_byte_writeinit(&st->buf);
86    ec_enc_init(&st->enc,&st->buf);
87
88    mdct_init(&st->mdct_lookup, 2*N);
89    st->fft = spx_fft_init(MAX_PERIOD);
90    
91    st->window = celt_alloc(2*N*sizeof(float));
92    st->in_mem = celt_alloc(N*sizeof(float));
93    st->mdct_overlap = celt_alloc(N*sizeof(float));
94    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
95    for (i=0;i<N;i++)
96       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
97    
98    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
99
100    st->preemph = 0.8;
101    return st;
102 }
103
104 void celt_encoder_destroy(CELTEncoder *st)
105 {
106    if (st == NULL)
107    {
108       celt_warning("NULL passed to celt_encoder_destroy");
109       return;
110    }
111    ec_byte_writeclear(&st->buf);
112
113    mdct_clear(&st->mdct_lookup);
114    spx_fft_destroy(st->fft);
115
116    celt_free(st->window);
117    celt_free(st->in_mem);
118    celt_free(st->mdct_overlap);
119    celt_free(st->out_mem);
120    
121    celt_free(st->oldBandE);
122    celt_free(st);
123 }
124
125 static void haar1(float *X, int N)
126 {
127    int i;
128    for (i=0;i<N;i+=2)
129    {
130       float a, b;
131       a = X[i];
132       b = X[i+1];
133       X[i] = .707107f*(a+b);
134       X[i+1] = .707107f*(a-b);
135    }
136 }
137
138 static void inv_haar1(float *X, int N)
139 {
140    int i;
141    for (i=0;i<N;i+=2)
142    {
143       float a, b;
144       a = X[i];
145       b = X[i+1];
146       X[i] = .707107f*(a+b);
147       X[i+1] = .707107f*(a-b);
148    }
149 }
150
151 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B)
152 {
153    int i;
154    for (i=0;i<B;i++)
155    {
156       int j;
157       float x[2*N];
158       float tmp[N];
159       for (j=0;j<2*N;j++)
160          x[j] = window[j]*in[i*N+j];
161       mdct_forward(mdct_lookup, x, tmp);
162       /* Interleaving the sub-frames */
163       for (j=0;j<N;j++)
164          out[B*j+i] = tmp[j];
165    }
166
167 }
168
169 int celt_encode(CELTEncoder *st, short *pcm)
170 {
171    int i, N, B;
172    N = st->block_size;
173    B = st->nb_blocks;
174    float in[(B+1)*N];
175    
176    float X[B*N];         /**< Interleaved signal MDCTs */
177    float P[B*N];         /**< Interleaved pitch MDCTs*/
178    float bandE[st->mode->nbEBands];
179    float gains[st->mode->nbPBands];
180    int pitch_index;
181    
182    for (i=0;i<N;i++)
183       in[i] = st->in_mem[i];
184    for (;i<(B+1)*N;i++)
185    {
186       float tmp = pcm[i-N];
187       in[i] = tmp - st->preemph*st->preemph_memE;
188       st->preemph_memE = tmp;
189    }
190    for (i=0;i<N;i++)
191       st->in_mem[i] = in[B*N+i];
192
193    /* Compute MDCTs */
194    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B);
195    
196    /* Pitch analysis */
197    for (i=0;i<N;i++)
198    {
199       in[i] *= st->window[i];
200       in[B*N+i] *= st->window[N+i];
201    }
202    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, &pitch_index);
203    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
204    
205    /* Compute MDCTs of the pitch part */
206    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
207    
208    /*int j;
209    for (j=0;j<B*N;j++)
210       printf ("%f ", X[j]);
211    for (j=0;j<B*N;j++)
212       printf ("%f ", P[j]);
213    printf ("\n");*/
214    //haar1(X, B*N);
215    //haar1(P, B*N);
216    
217    /* Band normalisation */
218    compute_band_energies(st->mode, X, bandE);
219    normalise_bands(st->mode, X, bandE);
220    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
221    
222    {
223       float bandEp[st->mode->nbEBands];
224       compute_band_energies(st->mode, P, bandEp);
225       normalise_bands(st->mode, P, bandEp);
226    }
227    
228    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
229    
230    /* Pitch prediction */
231    compute_pitch_gain(st->mode, X, P, gains, bandE);
232    quant_pitch(gains, st->mode->nbPBands, &st->enc);
233    pitch_quant_bands(st->mode, X, P, gains);
234
235    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
236    /* Subtract the pitch prediction from the signal to encode */
237    for (i=0;i<B*N;i++)
238       X[i] -= P[i];
239
240    /*float sum=0;
241    for (i=0;i<B*N;i++)
242       sum += X[i]*X[i];
243    printf ("%f\n", sum);*/
244    /* Residual quantisation */
245    quant_bands(st->mode, X, P, &st->enc);
246    
247    if (0) {//This is just for debugging
248       ec_enc_done(&st->enc);
249       ec_dec dec;
250       ec_byte_readinit(&st->buf,ec_byte_get_buffer(&st->buf),ec_byte_bytes(&st->buf));
251       ec_dec_init(&dec,&st->buf);
252
253       unquant_bands(st->mode, X, P, &dec);
254       //printf ("\n");
255    }
256    
257    /* Synthesis */
258    denormalise_bands(st->mode, X, bandE);
259
260    //inv_haar1(X, B*N);
261
262    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
263    /* Compute inverse MDCTs */
264    for (i=0;i<B;i++)
265    {
266       int j;
267       float x[2*N];
268       float tmp[N];
269       /* De-interleaving the sub-frames */
270       for (j=0;j<N;j++)
271          tmp[j] = X[B*j+i];
272       mdct_backward(&st->mdct_lookup, tmp, x);
273       for (j=0;j<2*N;j++)
274          x[j] = st->window[j]*x[j];
275       for (j=0;j<N;j++)
276          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
277       for (j=0;j<N;j++)
278          st->mdct_overlap[j] = x[N+j];
279       
280       for (j=0;j<N;j++)
281       {
282          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
283          st->preemph_memD = tmp;
284          pcm[i*N+j] = (short)floor(.5+tmp);
285       }
286    }
287    return 0;
288 }
289
290 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
291 {
292    char *data;
293    ec_enc_done(&st->enc);
294    *nbBytes = ec_byte_bytes(&st->buf);
295    data = ec_byte_get_buffer(&st->buf);
296    //printf ("%d\n", *nbBytes);
297    
298    /* Reset the packing for the next encoding */
299    ec_byte_reset(&st->buf);
300    ec_enc_init(&st->enc,&st->buf);
301
302    return data;
303 }
304
305
306 /****************************************************************************/
307 /*                                                                          */
308 /*                                DECODER                                   */
309 /*                                                                          */
310 /****************************************************************************/
311
312
313
314 struct CELTDecoder {
315    const CELTMode *mode;
316    int frame_size;
317    int block_size;
318    int nb_blocks;
319    
320    ec_byte_buffer buf;
321    ec_enc         enc;
322
323    float preemph;
324    float preemph_memD;
325    
326    mdct_lookup mdct_lookup;
327    
328    float *window;
329    float *mdct_overlap;
330    float *out_mem;
331
332    float *oldBandE;
333 };
334
335 CELTDecoder *celt_decoder_new(const CELTMode *mode)
336 {
337    int i, N, B;
338    N = mode->mdctSize;
339    B = mode->nbMdctBlocks;
340    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
341    
342    st->mode = mode;
343    st->frame_size = B*N;
344    st->block_size = N;
345    st->nb_blocks  = B;
346    
347    mdct_init(&st->mdct_lookup, 2*N);
348    
349    st->window = celt_alloc(2*N*sizeof(float));
350    st->mdct_overlap = celt_alloc(N*sizeof(float));
351    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
352    for (i=0;i<N;i++)
353       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
354    
355    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
356
357    st->preemph = 0.8;
358    return st;
359 }
360
361 void celt_decoder_destroy(CELTDecoder *st)
362 {
363    if (st == NULL)
364    {
365       celt_warning("NULL passed to celt_encoder_destroy");
366       return;
367    }
368
369    mdct_clear(&st->mdct_lookup);
370
371    celt_free(st->window);
372    celt_free(st->mdct_overlap);
373    celt_free(st->out_mem);
374    
375    celt_free(st->oldBandE);
376    celt_free(st);
377 }
378
379 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
380 {
381    int i, N, B;
382    N = st->block_size;
383    B = st->nb_blocks;
384    
385    float X[B*N];         /**< Interleaved signal MDCTs */
386    float P[B*N];         /**< Interleaved pitch MDCTs*/
387    float bandE[st->mode->nbEBands];
388    float gains[st->mode->nbPBands];
389    int pitch_index;
390    ec_dec dec;
391    ec_byte_buffer buf;
392    
393    ec_byte_readinit(&buf,data,len);
394    ec_dec_init(&dec,&buf);
395    
396    /* Get the pitch index */
397    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);;
398    
399    /* Get band energies */
400    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
401    
402    /* Pitch MDCT */
403    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
404
405    //haar1(P, B*N);
406
407    {
408       float bandEp[st->mode->nbEBands];
409       compute_band_energies(st->mode, P, bandEp);
410       normalise_bands(st->mode, P, bandEp);
411    }
412
413    /* Get the pitch gains */
414    unquant_pitch(gains, st->mode->nbPBands, &dec);
415
416    /* Apply pitch gains */
417    pitch_quant_bands(st->mode, X, P, gains);
418
419    /* Decode fixed codebook and merge with pitch */
420    unquant_bands(st->mode, X, P, &dec);
421
422    /* Synthesis */
423    denormalise_bands(st->mode, X, bandE);
424
425    //inv_haar1(X, B*N);
426
427    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
428    /* Compute inverse MDCTs */
429    for (i=0;i<B;i++)
430    {
431       int j;
432       float x[2*N];
433       float tmp[N];
434       /* De-interleaving the sub-frames */
435       for (j=0;j<N;j++)
436          tmp[j] = X[B*j+i];
437       mdct_backward(&st->mdct_lookup, tmp, x);
438       for (j=0;j<2*N;j++)
439          x[j] = st->window[j]*x[j];
440       for (j=0;j<N;j++)
441          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
442       for (j=0;j<N;j++)
443          st->mdct_overlap[j] = x[N+j];
444       
445       for (j=0;j<N;j++)
446       {
447          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
448          st->preemph_memD = tmp;
449          pcm[i*N+j] = (short)floor(.5+tmp);
450       }
451    }
452    //printf ("\n");
453 }
454