Algebraic codebook decoding (not tested yet)
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41
42 #define MAX_PERIOD 1024
43
44 struct CELTEncoder {
45    const CELTMode *mode;
46    int frame_size;
47    int block_size;
48    int nb_blocks;
49       
50    ec_byte_buffer buf;
51    ec_enc         enc;
52
53    float preemph;
54    float preemph_memE;
55    float preemph_memD;
56    
57    mdct_lookup mdct_lookup;
58    void *fft;
59    
60    float *window;
61    float *in_mem;
62    float *mdct_overlap;
63    float *out_mem;
64
65    float *oldBandE;
66 };
67
68
69
70 CELTEncoder *celt_encoder_new(const CELTMode *mode)
71 {
72    int i, N, B;
73    N = mode->mdctSize;
74    B = mode->nbMdctBlocks;
75    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
76    
77    st->mode = mode;
78    st->frame_size = B*N;
79    st->block_size = N;
80    st->nb_blocks  = B;
81    
82    ec_byte_writeinit(&st->buf);
83
84    mdct_init(&st->mdct_lookup, 2*N);
85    st->fft = spx_fft_init(MAX_PERIOD);
86    
87    st->window = celt_alloc(2*N*sizeof(float));
88    st->in_mem = celt_alloc(N*sizeof(float));
89    st->mdct_overlap = celt_alloc(N*sizeof(float));
90    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
91    for (i=0;i<N;i++)
92       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
93    
94    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
95
96    st->preemph = 0.8;
97    return st;
98 }
99
100 void celt_encoder_destroy(CELTEncoder *st)
101 {
102    if (st == NULL)
103    {
104       celt_warning("NULL passed to celt_encoder_destroy");
105       return;
106    }
107    ec_byte_writeclear(&st->buf);
108
109    mdct_clear(&st->mdct_lookup);
110    spx_fft_destroy(st->fft);
111
112    celt_free(st->window);
113    celt_free(st->in_mem);
114    celt_free(st->mdct_overlap);
115    celt_free(st->out_mem);
116    
117    celt_free(st->oldBandE);
118    celt_free(st);
119 }
120
121 static void haar1(float *X, int N)
122 {
123    int i;
124    for (i=0;i<N;i+=2)
125    {
126       float a, b;
127       a = X[i];
128       b = X[i+1];
129       X[i] = .707107f*(a+b);
130       X[i+1] = .707107f*(a-b);
131    }
132 }
133
134 static void inv_haar1(float *X, int N)
135 {
136    int i;
137    for (i=0;i<N;i+=2)
138    {
139       float a, b;
140       a = X[i];
141       b = X[i+1];
142       X[i] = .707107f*(a+b);
143       X[i+1] = .707107f*(a-b);
144    }
145 }
146
147 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B)
148 {
149    int i;
150    for (i=0;i<B;i++)
151    {
152       int j;
153       float x[2*N];
154       float tmp[N];
155       for (j=0;j<2*N;j++)
156          x[j] = window[j]*in[i*N+j];
157       mdct_forward(mdct_lookup, x, tmp);
158       /* Interleaving the sub-frames */
159       for (j=0;j<N;j++)
160          out[B*j+i] = tmp[j];
161    }
162
163 }
164
165 int celt_encode(CELTEncoder *st, short *pcm)
166 {
167    int i, N, B;
168    N = st->block_size;
169    B = st->nb_blocks;
170    float in[(B+1)*N];
171    
172    float X[B*N];         /**< Interleaved signal MDCTs */
173    float P[B*N];         /**< Interleaved pitch MDCTs*/
174    float bandE[st->mode->nbEBands];
175    float gains[st->mode->nbPBands];
176    int pitch_index;
177    
178
179    ec_byte_reset(&st->buf);
180    ec_enc_init(&st->enc,&st->buf);
181
182    for (i=0;i<N;i++)
183       in[i] = st->in_mem[i];
184    for (;i<(B+1)*N;i++)
185    {
186       float tmp = pcm[i-N];
187       in[i] = tmp - st->preemph*st->preemph_memE;
188       st->preemph_memE = tmp;
189    }
190    for (i=0;i<N;i++)
191       st->in_mem[i] = in[B*N+i];
192
193    /* Compute MDCTs */
194    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B);
195    
196    /* Pitch analysis */
197    for (i=0;i<N;i++)
198    {
199       in[i] *= st->window[i];
200       in[B*N+i] *= st->window[N+i];
201    }
202    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, &pitch_index);
203    
204    /* Compute MDCTs of the pitch part */
205    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
206    
207    /*int j;
208    for (j=0;j<B*N;j++)
209       printf ("%f ", X[j]);
210    for (j=0;j<B*N;j++)
211       printf ("%f ", P[j]);
212    printf ("\n");*/
213    //haar1(X, B*N);
214    //haar1(P, B*N);
215    
216    /* Band normalisation */
217    compute_band_energies(st->mode, X, bandE);
218    normalise_bands(st->mode, X, bandE);
219    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
220    
221    {
222       float bandEp[st->mode->nbEBands];
223       compute_band_energies(st->mode, P, bandEp);
224       normalise_bands(st->mode, P, bandEp);
225    }
226    
227    quant_energy(st->mode, bandE, st->oldBandE);
228    
229    /* Pitch prediction */
230    compute_pitch_gain(st->mode, X, P, gains, bandE);
231    //quantise_pitch(gains, PBANDS);
232    pitch_quant_bands(st->mode, X, P, gains);
233    
234    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
235    /* Subtract the pitch prediction from the signal to encode */
236    for (i=0;i<B*N;i++)
237       X[i] -= P[i];
238
239    /*float sum=0;
240    for (i=0;i<B*N;i++)
241       sum += X[i]*X[i];
242    printf ("%f\n", sum);*/
243    /* Residual quantisation */
244    quant_bands(st->mode, X, P, &st->enc);
245    
246    /* Synthesis */
247    denormalise_bands(st->mode, X, bandE);
248
249    //inv_haar1(X, B*N);
250
251    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
252    /* Compute inverse MDCTs */
253    for (i=0;i<B;i++)
254    {
255       int j;
256       float x[2*N];
257       float tmp[N];
258       /* De-interleaving the sub-frames */
259       for (j=0;j<N;j++)
260          tmp[j] = X[B*j+i];
261       mdct_backward(&st->mdct_lookup, tmp, x);
262       for (j=0;j<2*N;j++)
263          x[j] = st->window[j]*x[j];
264       for (j=0;j<N;j++)
265          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
266       for (j=0;j<N;j++)
267          st->mdct_overlap[j] = x[N+j];
268       
269       for (j=0;j<N;j++)
270       {
271          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
272          st->preemph_memD = tmp;
273          pcm[i*N+j] = (short)floor(.5+tmp);
274       }
275    }
276    ec_enc_done(&st->enc);
277    //printf ("%d\n", ec_byte_bytes(&st->buf));
278    return 0;
279 }
280
281 /****************************************************************************/
282 /*                                Decoder                                   */
283 /****************************************************************************/
284
285
286
287 struct CELTDecoder {
288    const CELTMode *mode;
289    int frame_size;
290    int block_size;
291    int nb_blocks;
292    
293    ec_byte_buffer buf;
294    ec_enc         enc;
295
296    float preemph;
297    float preemph_memD;
298    
299    mdct_lookup mdct_lookup;
300    
301    float *window;
302    float *mdct_overlap;
303    float *out_mem;
304
305    float *oldBandE;
306 };
307
308 CELTDecoder *celt_decoder_new(const CELTMode *mode)
309 {
310    int i, N, B;
311    N = mode->mdctSize;
312    B = mode->nbMdctBlocks;
313    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
314    
315    st->mode = mode;
316    st->frame_size = B*N;
317    st->block_size = N;
318    st->nb_blocks  = B;
319    
320    ec_byte_writeinit(&st->buf);
321
322    mdct_init(&st->mdct_lookup, 2*N);
323    
324    st->window = celt_alloc(2*N*sizeof(float));
325    st->mdct_overlap = celt_alloc(N*sizeof(float));
326    st->out_mem = celt_alloc(MAX_PERIOD*sizeof(float));
327    for (i=0;i<N;i++)
328       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
329    
330    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
331
332    st->preemph = 0.8;
333    return st;
334 }
335
336 void celt_decoder_destroy(CELTDecoder *st)
337 {
338    if (st == NULL)
339    {
340       celt_warning("NULL passed to celt_encoder_destroy");
341       return;
342    }
343
344    mdct_clear(&st->mdct_lookup);
345
346    celt_free(st->window);
347    celt_free(st->mdct_overlap);
348    celt_free(st->out_mem);
349    
350    celt_free(st->oldBandE);
351    celt_free(st);
352 }
353
354 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
355 {
356    int i, N, B;
357    N = st->block_size;
358    B = st->nb_blocks;
359    
360    float X[B*N];         /**< Interleaved signal MDCTs */
361    float P[B*N];         /**< Interleaved pitch MDCTs*/
362    float bandE[st->mode->nbEBands];
363    float gains[st->mode->nbPBands];
364    int pitch_index;
365
366    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index, P, N, B);
367
368    //haar1(P, B*N);
369
370    {
371       float bandEp[st->mode->nbEBands];
372       compute_band_energies(st->mode, P, bandEp);
373       normalise_bands(st->mode, P, bandEp);
374    }
375
376    /* Apply pitch gains */
377    
378    /* Decode fixed codebook */
379    
380    /* Merge pitch and fixed codebook */
381    
382    /* Synthesis */
383    denormalise_bands(st->mode, X, bandE);
384
385    //inv_haar1(X, B*N);
386
387    CELT_MOVE(st->out_mem, st->out_mem+B*N, MAX_PERIOD-B*N);
388    /* Compute inverse MDCTs */
389    for (i=0;i<B;i++)
390    {
391       int j;
392       float x[2*N];
393       float tmp[N];
394       /* De-interleaving the sub-frames */
395       for (j=0;j<N;j++)
396          tmp[j] = X[B*j+i];
397       mdct_backward(&st->mdct_lookup, tmp, x);
398       for (j=0;j<2*N;j++)
399          x[j] = st->window[j]*x[j];
400       for (j=0;j<N;j++)
401          st->out_mem[MAX_PERIOD+(i-B)*N+j] = x[j]+st->mdct_overlap[j];
402       for (j=0;j<N;j++)
403          st->mdct_overlap[j] = x[N+j];
404       
405       for (j=0;j<N;j++)
406       {
407          float tmp = st->out_mem[MAX_PERIOD+(i-B)*N+j] + st->preemph*st->preemph_memD;
408          st->preemph_memD = tmp;
409          pcm[i*N+j] = (short)floor(.5+tmp);
410       }
411    }
412 }
413