Working on some stability issues (appears to be solved by making the pitch
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43
44 #define MAX_PERIOD 1024
45
46
47 struct CELTEncoder {
48    const CELTMode *mode;
49    int frame_size;
50    int block_size;
51    int nb_blocks;
52    int channels;
53    
54    ec_byte_buffer buf;
55    ec_enc         enc;
56
57    float preemph;
58    float *preemph_memE;
59    float *preemph_memD;
60    
61    mdct_lookup mdct_lookup;
62    void *fft;
63    
64    float *window;
65    float *in_mem;
66    float *mdct_overlap;
67    float *out_mem;
68
69    float *oldBandE;
70 };
71
72
73
74 CELTEncoder *celt_encoder_new(const CELTMode *mode)
75 {
76    int i, N, B, C;
77    N = mode->mdctSize;
78    B = mode->nbMdctBlocks;
79    C = mode->nbChannels;
80    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
81    
82    st->mode = mode;
83    st->frame_size = B*N;
84    st->block_size = N;
85    st->nb_blocks  = B;
86    
87    ec_byte_writeinit(&st->buf);
88    ec_enc_init(&st->enc,&st->buf);
89
90    mdct_init(&st->mdct_lookup, 2*N);
91    st->fft = spx_fft_init(MAX_PERIOD*C);
92    
93    st->window = celt_alloc(2*N*sizeof(float));
94    st->in_mem = celt_alloc(N*C*sizeof(float));
95    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
96    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
97    for (i=0;i<N;i++)
98       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
99    
100    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
101
102    st->preemph = 0.8;
103    st->preemph_memE = celt_alloc(C*sizeof(float));;
104    st->preemph_memD = celt_alloc(C*sizeof(float));;
105
106    return st;
107 }
108
109 void celt_encoder_destroy(CELTEncoder *st)
110 {
111    if (st == NULL)
112    {
113       celt_warning("NULL passed to celt_encoder_destroy");
114       return;
115    }
116    ec_byte_writeclear(&st->buf);
117
118    mdct_clear(&st->mdct_lookup);
119    spx_fft_destroy(st->fft);
120
121    celt_free(st->window);
122    celt_free(st->in_mem);
123    celt_free(st->mdct_overlap);
124    celt_free(st->out_mem);
125    
126    celt_free(st->oldBandE);
127    celt_free(st);
128 }
129
130 static void haar1(float *X, int N)
131 {
132    int i;
133    for (i=0;i<N;i+=2)
134    {
135       float a, b;
136       a = X[i];
137       b = X[i+1];
138       X[i] = .707107f*(a+b);
139       X[i+1] = .707107f*(a-b);
140    }
141 }
142
143 static void inv_haar1(float *X, int N)
144 {
145    int i;
146    for (i=0;i<N;i+=2)
147    {
148       float a, b;
149       a = X[i];
150       b = X[i+1];
151       X[i] = .707107f*(a+b);
152       X[i+1] = .707107f*(a-b);
153    }
154 }
155
156 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
157 {
158    int i, c;
159    for (c=0;c<C;c++)
160    {
161       for (i=0;i<B;i++)
162       {
163          int j;
164          float x[2*N];
165          float tmp[N];
166          for (j=0;j<2*N;j++)
167             x[j] = window[j]*in[C*i*N+C*j+c];
168          mdct_forward(mdct_lookup, x, tmp);
169          /* Interleaving the sub-frames */
170          for (j=0;j<N;j++)
171             out[C*B*j+C*i+c] = tmp[j];
172       }
173    }
174 }
175
176 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int B, int C)
177 {
178    int i, c;
179    for (c=0;c<C;c++)
180    {
181       for (i=0;i<B;i++)
182       {
183          int j;
184          float x[2*N];
185          float tmp[N];
186          /* De-interleaving the sub-frames */
187          for (j=0;j<N;j++)
188             tmp[j] = X[C*B*j+C*i+c];
189          mdct_backward(mdct_lookup, tmp, x);
190          for (j=0;j<2*N;j++)
191             x[j] = window[j]*x[j];
192          for (j=0;j<N;j++)
193             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[j]+mdct_overlap[C*j+c];
194          for (j=0;j<N;j++)
195             mdct_overlap[C*j+c] = x[N+j];
196       }
197    }
198 }
199
200 int celt_encode(CELTEncoder *st, short *pcm)
201 {
202    int i, c, N, B, C;
203    N = st->block_size;
204    B = st->nb_blocks;
205    C = st->mode->nbChannels;
206    float in[(B+1)*C*N];
207    
208    float X[B*C*N];         /**< Interleaved signal MDCTs */
209    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
210    float bandE[st->mode->nbEBands];
211    float gains[st->mode->nbPBands];
212    int pitch_index;
213    
214    for (c=0;c<C;c++)
215    {
216       for (i=0;i<N;i++)
217          in[C*i+c] = st->in_mem[C*i+c];
218       for (;i<(B+1)*N;i++)
219       {
220          float tmp = pcm[C*(i-N)+c];
221          in[C*i+c] = tmp - st->preemph*st->preemph_memE[c];
222          st->preemph_memE[c] = tmp;
223       }
224       for (i=0;i<N;i++)
225          st->in_mem[C*i+c] = in[C*(B*N+i)+c];
226    }
227    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
228    /* Compute MDCTs */
229    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
230    
231    /* Pitch analysis */
232    for (c=0;c<C;c++)
233    {
234       for (i=0;i<N;i++)
235       {
236          in[C*i+c] *= st->window[i];
237          in[C*(B*N+i)+c] *= st->window[N+i];
238       }
239    }
240    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
241    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
242    
243    /* Compute MDCTs of the pitch part */
244    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
245    
246    /*int j;
247    for (j=0;j<B*N;j++)
248       printf ("%f ", X[j]);
249    for (j=0;j<B*N;j++)
250       printf ("%f ", P[j]);
251    printf ("\n");*/
252    if (C==2)
253    {
254       haar1(X, B*N);
255       haar1(P, B*N);
256    }
257    
258    /* Band normalisation */
259    compute_band_energies(st->mode, X, bandE);
260    normalise_bands(st->mode, X, bandE);
261    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
262    
263    {
264       float bandEp[st->mode->nbEBands];
265       compute_band_energies(st->mode, P, bandEp);
266       normalise_bands(st->mode, P, bandEp);
267    }
268    
269    band_rotation(st->mode, X, -1);
270    band_rotation(st->mode, P, -1);
271    
272    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
273    
274    /* Pitch prediction */
275    compute_pitch_gain(st->mode, X, P, gains, bandE);
276    quant_pitch(gains, st->mode->nbPBands, &st->enc);
277    pitch_quant_bands(st->mode, X, P, gains);
278
279    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
280    /* Subtract the pitch prediction from the signal to encode */
281    for (i=0;i<B*C*N;i++)
282       X[i] -= P[i];
283
284    /*float sum=0;
285    for (i=0;i<B*N;i++)
286       sum += X[i]*X[i];
287    printf ("%f\n", sum);*/
288    /* Residual quantisation */
289    quant_bands(st->mode, X, P, &st->enc);
290    
291    if (0) {//This is just for debugging
292       ec_enc_done(&st->enc);
293       ec_dec dec;
294       ec_byte_readinit(&st->buf,ec_byte_get_buffer(&st->buf),ec_byte_bytes(&st->buf));
295       ec_dec_init(&dec,&st->buf);
296
297       unquant_bands(st->mode, X, P, &dec);
298       //printf ("\n");
299    }
300    
301    band_rotation(st->mode, X, 1);
302
303    /* Synthesis */
304    denormalise_bands(st->mode, X, bandE);
305
306    if (C==2)
307       inv_haar1(X, B*N);
308
309    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
310    /* Compute inverse MDCTs */
311    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
312
313    for (c=0;c<C;c++)
314    {
315       for (i=0;i<B;i++)
316       {
317          int j;
318          for (j=0;j<N;j++)
319          {
320             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
321             st->preemph_memD[c] = tmp;
322             if (tmp > 32767) tmp = 32767;
323             if (tmp < -32767) tmp = -32767;
324             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
325          }
326       }
327    }
328    return 0;
329 }
330
331 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
332 {
333    char *data;
334    ec_enc_done(&st->enc);
335    *nbBytes = ec_byte_bytes(&st->buf);
336    data = ec_byte_get_buffer(&st->buf);
337    //printf ("%d\n", *nbBytes);
338    
339    /* Reset the packing for the next encoding */
340    ec_byte_reset(&st->buf);
341    ec_enc_init(&st->enc,&st->buf);
342
343    return data;
344 }
345
346
347 /****************************************************************************/
348 /*                                                                          */
349 /*                                DECODER                                   */
350 /*                                                                          */
351 /****************************************************************************/
352
353
354
355 struct CELTDecoder {
356    const CELTMode *mode;
357    int frame_size;
358    int block_size;
359    int nb_blocks;
360    
361    ec_byte_buffer buf;
362    ec_enc         enc;
363
364    float preemph;
365    float *preemph_memD;
366    
367    mdct_lookup mdct_lookup;
368    
369    float *window;
370    float *mdct_overlap;
371    float *out_mem;
372
373    float *oldBandE;
374    
375    int last_pitch_index;
376 };
377
378 CELTDecoder *celt_decoder_new(const CELTMode *mode)
379 {
380    int i, N, B, C;
381    N = mode->mdctSize;
382    B = mode->nbMdctBlocks;
383    C = mode->nbChannels;
384    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
385    
386    st->mode = mode;
387    st->frame_size = B*N;
388    st->block_size = N;
389    st->nb_blocks  = B;
390    
391    mdct_init(&st->mdct_lookup, 2*N);
392    
393    st->window = celt_alloc(2*N*sizeof(float));
394    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
395    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
396    for (i=0;i<N;i++)
397       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
398    
399    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
400
401    st->preemph = 0.8;
402    st->preemph_memD = celt_alloc(C*sizeof(float));;
403
404    st->last_pitch_index = 0;
405    return st;
406 }
407
408 void celt_decoder_destroy(CELTDecoder *st)
409 {
410    if (st == NULL)
411    {
412       celt_warning("NULL passed to celt_encoder_destroy");
413       return;
414    }
415
416    mdct_clear(&st->mdct_lookup);
417
418    celt_free(st->window);
419    celt_free(st->mdct_overlap);
420    celt_free(st->out_mem);
421    
422    celt_free(st->oldBandE);
423    celt_free(st);
424 }
425
426 static void celt_decode_lost(CELTDecoder *st, short *pcm)
427 {
428    int i, c, N, B, C;
429    N = st->block_size;
430    B = st->nb_blocks;
431    C = st->mode->nbChannels;
432    float X[C*B*N];         /**< Interleaved signal MDCTs */
433    int pitch_index;
434    
435    pitch_index = st->last_pitch_index;
436    
437    /* Use the pitch MDCT as the "guessed" signal */
438    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
439
440    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
441    /* Compute inverse MDCTs */
442    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
443
444    for (c=0;c<C;c++)
445    {
446       for (i=0;i<B;i++)
447       {
448          int j;
449          for (j=0;j<N;j++)
450          {
451             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
452             st->preemph_memD[c] = tmp;
453             if (tmp > 32767) tmp = 32767;
454             if (tmp < -32767) tmp = -32767;
455             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
456          }
457       }
458    }
459 }
460
461 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
462 {
463    int i, c, N, B, C;
464    N = st->block_size;
465    B = st->nb_blocks;
466    C = st->mode->nbChannels;
467    
468    float X[C*B*N];         /**< Interleaved signal MDCTs */
469    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
470    float bandE[st->mode->nbEBands];
471    float gains[st->mode->nbPBands];
472    int pitch_index;
473    ec_dec dec;
474    ec_byte_buffer buf;
475    
476    if (data == NULL)
477    {
478       celt_decode_lost(st, pcm);
479       return 0;
480    }
481    
482    ec_byte_readinit(&buf,data,len);
483    ec_dec_init(&dec,&buf);
484    
485    /* Get the pitch index */
486    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
487    st->last_pitch_index = pitch_index;
488    
489    /* Get band energies */
490    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
491    
492    /* Pitch MDCT */
493    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
494
495    if (C==2)
496       haar1(P, B*N);
497
498    {
499       float bandEp[st->mode->nbEBands];
500       compute_band_energies(st->mode, P, bandEp);
501       normalise_bands(st->mode, P, bandEp);
502    }
503    band_rotation(st->mode, P, -1);
504
505    /* Get the pitch gains */
506    unquant_pitch(gains, st->mode->nbPBands, &dec);
507
508    /* Apply pitch gains */
509    pitch_quant_bands(st->mode, X, P, gains);
510
511    /* Decode fixed codebook and merge with pitch */
512    unquant_bands(st->mode, X, P, &dec);
513
514    band_rotation(st->mode, X, 1);
515    
516    /* Synthesis */
517    denormalise_bands(st->mode, X, bandE);
518
519    if (C==2)
520       inv_haar1(X, B*N);
521
522    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
523    /* Compute inverse MDCTs */
524    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
525
526    for (c=0;c<C;c++)
527    {
528       for (i=0;i<B;i++)
529       {
530          int j;
531          for (j=0;j<N;j++)
532          {
533             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
534             st->preemph_memD[c] = tmp;
535             if (tmp > 32767) tmp = 32767;
536             if (tmp < -32767) tmp = -32767;
537             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
538          }
539       }
540    }
541    return 0;
542    //printf ("\n");
543 }
544