Energy is now Laplace-encoded (very poorly for now)
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41
42 #define MAX_PERIOD 1024
43
44 /* This is only for cheating until the decoder is complete */
45 float cheating_pitch_gains[100];
46
47
48 struct CELTEncoder {
49    const CELTMode *mode;
50    int frame_size;
51    int block_size;
52    int nb_blocks;
53       
54    ec_byte_buffer buf;
55    ec_enc         enc;
56
57    float preemph;
58    float preemph_memE;
59    float preemph_memD;
60    
61    mdct_lookup mdct_lookup;
62    void *fft;
63    
64    float *window;
65    float *in_mem;
66    float *mdct_overlap;
67    float *out_mem;
68
69    float *oldBandE;
70 };
71
72
73
74 CELTEncoder *celt_encoder_new(const CELTMode *mode)
75 {
76    int i, N, B;
77    N = mode->mdctSize;
78    B = mode->nbMdctBlocks;
79    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
80    
81    st->mode = mode;
82    st->frame_size = B*N;
83    st->block_size = N;
84    st->nb_blocks  = B;
85    
86    ec_byte_writeinit(&st->buf);
87    ec_enc_init(&st->enc,&st->buf);
88
89    mdct_init(&st->mdct_lookup, 2*N);
90    st->fft = spx_fft_init(MAX_PERIOD);
91    
92    st->window = celt_alloc(2*N*sizeof(float));
93    st->in_mem = celt_alloc(N*sizeof(float));
94    st->mdct_overlap = celt_alloc(N*sizeof(float));
95    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
96    for (i=0;i<N;i++)
97       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
98    
99    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
100
101    st->preemph = 0.8;
102    return st;
103 }
104
105 void celt_encoder_destroy(CELTEncoder *st)
106 {
107    if (st == NULL)
108    {
109       celt_warning("NULL passed to celt_encoder_destroy");
110       return;
111    }
112    ec_byte_writeclear(&st->buf);
113
114    mdct_clear(&st->mdct_lookup);
115    spx_fft_destroy(st->fft);
116
117    celt_free(st->window);
118    celt_free(st->in_mem);
119    celt_free(st->mdct_overlap);
120    celt_free(st->out_mem);
121    
122    celt_free(st->oldBandE);
123    celt_free(st);
124 }
125
126 static void haar1(float *X, int N)
127 {
128    int i;
129    for (i=0;i<N;i+=2)
130    {
131       float a, b;
132       a = X[i];
133       b = X[i+1];
134       X[i] = .707107f*(a+b);
135       X[i+1] = .707107f*(a-b);
136    }
137 }
138
139 static void inv_haar1(float *X, int N)
140 {
141    int i;
142    for (i=0;i<N;i+=2)
143    {
144       float a, b;
145       a = X[i];
146       b = X[i+1];
147       X[i] = .707107f*(a+b);
148       X[i+1] = .707107f*(a-b);
149    }
150 }
151
152 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B)
153 {
154    int i;
155    for (i=0;i<B;i++)
156    {
157       int j;
158       float x[2*N];
159       float tmp[N];
160       for (j=0;j<2*N;j++)
161          x[j] = window[j]*in[i*N+j];
162       mdct_forward(mdct_lookup, x, tmp);
163       /* Interleaving the sub-frames */
164       for (j=0;j<N;j++)
165          out[B*j+i] = tmp[j];
166    }
167
168 }
169
170 int celt_encode(CELTEncoder *st, short *pcm)
171 {
172    int i, N, B;
173    N = st->block_size;
174    B = st->nb_blocks;
175    float in[(B+1)*N];
176    
177    float X[B*N];         /**< Interleaved signal MDCTs */
178    float P[B*N];         /**< Interleaved pitch MDCTs*/
179    float bandE[st->mode->nbEBands];
180    float gains[st->mode->nbPBands];
181    int pitch_index;
182    
183    for (i=0;i<N;i++)
184       in[i] = st->in_mem[i];
185    for (;i<(B+1)*N;i++)
186    {
187       float tmp = pcm[i-N];
188       in[i] = tmp - st->preemph*st->preemph_memE;
189       st->preemph_memE = tmp;
190    }
191    for (i=0;i<N;i++)
192       st->in_mem[i] = in[B*N+i];
193
194    /* Compute MDCTs */
195    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B);
196    
197    /* Pitch analysis */
198    for (i=0;i<N;i++)
199    {
200       in[i] *= st->window[i];
201       in[B*N+i] *= st->window[N+i];
202    }
203    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, &pitch_index);
204    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
205    
206    /* Compute MDCTs of the pitch part */
207    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
208    
209    /*int j;
210    for (j=0;j<B*N;j++)
211       printf ("%f ", X[j]);
212    for (j=0;j<B*N;j++)
213       printf ("%f ", P[j]);
214    printf ("\n");*/
215    //haar1(X, B*N);
216    //haar1(P, B*N);
217    
218    /* Band normalisation */
219    compute_band_energies(st->mode, X, bandE);
220    normalise_bands(st->mode, X, bandE);
221    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
222    
223    {
224       float bandEp[st->mode->nbEBands];
225       compute_band_energies(st->mode, P, bandEp);
226       normalise_bands(st->mode, P, bandEp);
227    }
228    
229    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
230    
231    /* Pitch prediction */
232    compute_pitch_gain(st->mode, X, P, gains, bandE);
233    //quantise_pitch(gains, PBANDS);
234    pitch_quant_bands(st->mode, X, P, gains);
235    
236    for (i=0;i<st->mode->nbPBands;i++)
237       cheating_pitch_gains[i] = gains[i];
238
239    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
240    /* Subtract the pitch prediction from the signal to encode */
241    for (i=0;i<B*N;i++)
242       X[i] -= P[i];
243
244    /*float sum=0;
245    for (i=0;i<B*N;i++)
246       sum += X[i]*X[i];
247    printf ("%f\n", sum);*/
248    /* Residual quantisation */
249    quant_bands(st->mode, X, P, &st->enc);
250    
251    if (0) {//This is just for debugging
252       ec_enc_done(&st->enc);
253       ec_dec dec;
254       ec_byte_readinit(&st->buf,ec_byte_get_buffer(&st->buf),ec_byte_bytes(&st->buf));
255       ec_dec_init(&dec,&st->buf);
256
257       unquant_bands(st->mode, X, P, &dec);
258       //printf ("\n");
259    }
260    
261    /* Synthesis */
262    denormalise_bands(st->mode, X, bandE);
263
264    //inv_haar1(X, B*N);
265
266    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
267    /* Compute inverse MDCTs */
268    for (i=0;i<B;i++)
269    {
270       int j;
271       float x[2*N];
272       float tmp[N];
273       /* De-interleaving the sub-frames */
274       for (j=0;j<N;j++)
275          tmp[j] = X[B*j+i];
276       mdct_backward(&st->mdct_lookup, tmp, x);
277       for (j=0;j<2*N;j++)
278          x[j] = st->window[j]*x[j];
279       for (j=0;j<N;j++)
280          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
281       for (j=0;j<N;j++)
282          st->mdct_overlap[j] = x[N+j];
283       
284       for (j=0;j<N;j++)
285       {
286          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
287          st->preemph_memD = tmp;
288          pcm[i*N+j] = (short)floor(.5+tmp);
289       }
290    }
291    return 0;
292 }
293
294 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
295 {
296    char *data;
297    ec_enc_done(&st->enc);
298    *nbBytes = ec_byte_bytes(&st->buf);
299    data = ec_byte_get_buffer(&st->buf);
300    //printf ("%d\n", *nbBytes);
301    
302    /* Reset the packing for the next encoding */
303    ec_byte_reset(&st->buf);
304    ec_enc_init(&st->enc,&st->buf);
305
306    return data;
307 }
308
309
310 /****************************************************************************/
311 /*                                                                          */
312 /*                                DECODER                                   */
313 /*                                                                          */
314 /****************************************************************************/
315
316
317
318 struct CELTDecoder {
319    const CELTMode *mode;
320    int frame_size;
321    int block_size;
322    int nb_blocks;
323    
324    ec_byte_buffer buf;
325    ec_enc         enc;
326
327    float preemph;
328    float preemph_memD;
329    
330    mdct_lookup mdct_lookup;
331    
332    float *window;
333    float *mdct_overlap;
334    float *out_mem;
335
336    float *oldBandE;
337 };
338
339 CELTDecoder *celt_decoder_new(const CELTMode *mode)
340 {
341    int i, N, B;
342    N = mode->mdctSize;
343    B = mode->nbMdctBlocks;
344    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
345    
346    st->mode = mode;
347    st->frame_size = B*N;
348    st->block_size = N;
349    st->nb_blocks  = B;
350    
351    mdct_init(&st->mdct_lookup, 2*N);
352    
353    st->window = celt_alloc(2*N*sizeof(float));
354    st->mdct_overlap = celt_alloc(N*sizeof(float));
355    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
356    for (i=0;i<N;i++)
357       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
358    
359    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
360
361    st->preemph = 0.8;
362    return st;
363 }
364
365 void celt_decoder_destroy(CELTDecoder *st)
366 {
367    if (st == NULL)
368    {
369       celt_warning("NULL passed to celt_encoder_destroy");
370       return;
371    }
372
373    mdct_clear(&st->mdct_lookup);
374
375    celt_free(st->window);
376    celt_free(st->mdct_overlap);
377    celt_free(st->out_mem);
378    
379    celt_free(st->oldBandE);
380    celt_free(st);
381 }
382
383 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
384 {
385    int i, N, B;
386    N = st->block_size;
387    B = st->nb_blocks;
388    
389    float X[B*N];         /**< Interleaved signal MDCTs */
390    float P[B*N];         /**< Interleaved pitch MDCTs*/
391    float bandE[st->mode->nbEBands];
392    float gains[st->mode->nbPBands];
393    int pitch_index;
394    ec_dec dec;
395    ec_byte_buffer buf;
396    
397    ec_byte_readinit(&buf,data,len);
398    ec_dec_init(&dec,&buf);
399    
400    /* Get the pitch index */
401    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);;
402    
403    /* Get band energies */
404    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
405    
406    /* Pitch MDCT */
407    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
408
409    //haar1(P, B*N);
410
411    {
412       float bandEp[st->mode->nbEBands];
413       compute_band_energies(st->mode, P, bandEp);
414       normalise_bands(st->mode, P, bandEp);
415    }
416
417    /* Get the pitch gains */
418    for (i=0;i<st->mode->nbPBands;i++)
419       gains[i] = cheating_pitch_gains[i];
420
421    /* Apply pitch gains */
422    pitch_quant_bands(st->mode, X, P, gains);
423
424    /* Decode fixed codebook and merge with pitch */
425    unquant_bands(st->mode, X, P, &dec);
426
427    /* Synthesis */
428    denormalise_bands(st->mode, X, bandE);
429
430    //inv_haar1(X, B*N);
431
432    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
433    /* Compute inverse MDCTs */
434    for (i=0;i<B;i++)
435    {
436       int j;
437       float x[2*N];
438       float tmp[N];
439       /* De-interleaving the sub-frames */
440       for (j=0;j<N;j++)
441          tmp[j] = X[B*j+i];
442       mdct_backward(&st->mdct_lookup, tmp, x);
443       for (j=0;j<2*N;j++)
444          x[j] = st->window[j]*x[j];
445       for (j=0;j<N;j++)
446          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
447       for (j=0;j<N;j++)
448          st->mdct_overlap[j] = x[N+j];
449       
450       for (j=0;j<N;j++)
451       {
452          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
453          st->preemph_memD = tmp;
454          pcm[i*N+j] = (short)floor(.5+tmp);
455       }
456    }
457    //printf ("\n");
458 }
459