More C89 fixes, making sure to include config.h from all source files.
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "os_support.h"
37 #include "mdct.h"
38 #include <math.h>
39 #include "celt.h"
40 #include "pitch.h"
41 #include "kiss_fftr.h"
42 #include "bands.h"
43 #include "modes.h"
44 #include "entcode.h"
45 #include "quant_pitch.h"
46 #include "quant_bands.h"
47 #include "psy.h"
48 #include "rate.h"
49
50 #define MAX_PERIOD 1024
51
52 #ifndef M_PI
53 #define M_PI 3.14159263
54 #endif
55
56 struct CELTEncoder {
57    const CELTMode *mode;
58    int frame_size;
59    int block_size;
60    int nb_blocks;
61    int overlap;
62    int channels;
63    int Fs;
64    
65    ec_byte_buffer buf;
66    ec_enc         enc;
67
68    float preemph;
69    float *preemph_memE;
70    float *preemph_memD;
71    
72    mdct_lookup mdct_lookup;
73    kiss_fftr_cfg fft;
74    struct PsyDecay psy;
75    
76    float *window;
77    float *in_mem;
78    float *mdct_overlap;
79    float *out_mem;
80
81    float *oldBandE;
82 };
83
84
85
86 CELTEncoder *celt_encoder_new(const CELTMode *mode)
87 {
88    int i, N, B, C, N4;
89    CELTEncoder *st;
90    N = mode->mdctSize;
91    B = mode->nbMdctBlocks;
92    C = mode->nbChannels;
93    st = celt_alloc(sizeof(CELTEncoder));
94    
95    st->mode = mode;
96    st->frame_size = B*N;
97    st->block_size = N;
98    st->nb_blocks  = B;
99    st->overlap = mode->overlap;
100    st->Fs = 44100;
101
102    N4 = (N-st->overlap)/2;
103    ec_byte_writeinit(&st->buf);
104    ec_enc_init(&st->enc,&st->buf);
105
106    mdct_init(&st->mdct_lookup, 2*N);
107    st->fft = kiss_fftr_alloc(MAX_PERIOD*C, 0, 0);
108    psydecay_init(&st->psy, MAX_PERIOD*C/2, st->Fs);
109    
110    st->window = celt_alloc(2*N*sizeof(float));
111    st->in_mem = celt_alloc(N*C*sizeof(float));
112    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
113    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
114    for (i=0;i<2*N;i++)
115       st->window[i] = 0;
116    for (i=0;i<st->overlap;i++)
117       st->window[N4+i] = st->window[2*N-N4-i-1] 
118             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
119    for (i=0;i<2*N4;i++)
120       st->window[N-N4+i] = 1;
121    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
122
123    st->preemph = 0.8;
124    st->preemph_memE = celt_alloc(C*sizeof(float));;
125    st->preemph_memD = celt_alloc(C*sizeof(float));;
126
127    return st;
128 }
129
130 void celt_encoder_destroy(CELTEncoder *st)
131 {
132    if (st == NULL)
133    {
134       celt_warning("NULL passed to celt_encoder_destroy");
135       return;
136    }
137    ec_byte_writeclear(&st->buf);
138
139    mdct_clear(&st->mdct_lookup);
140    kiss_fft_free(st->fft);
141    psydecay_clear(&st->psy);
142
143    celt_free(st->window);
144    celt_free(st->in_mem);
145    celt_free(st->mdct_overlap);
146    celt_free(st->out_mem);
147    
148    celt_free(st->oldBandE);
149    
150    celt_free(st->preemph_memE);
151    celt_free(st->preemph_memD);
152    
153    celt_free(st);
154 }
155
156
157 static float compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
158 {
159    int i, c;
160    float E = 1e-15;
161    VARDECL(float *x);
162    VARDECL(float *tmp);
163    ALLOC(x, 2*N, float);
164    ALLOC(tmp, N, float);
165    for (c=0;c<C;c++)
166    {
167       for (i=0;i<B;i++)
168       {
169          int j;
170          for (j=0;j<2*N;j++)
171          {
172             x[j] = window[j]*in[C*i*N+C*j+c];
173             E += x[j]*x[j];
174          }
175          mdct_forward(mdct_lookup, x, tmp);
176          /* Interleaving the sub-frames */
177          for (j=0;j<N;j++)
178             out[C*B*j+C*i+c] = tmp[j];
179       }
180    }
181    return E;
182 }
183
184 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
185 {
186    int i, c, N4;
187    VARDECL(float *x);
188    VARDECL(float *tmp);
189    ALLOC(x, 2*N, float);
190    ALLOC(tmp, N, float);
191    N4 = (N-overlap)/2;
192    for (c=0;c<C;c++)
193    {
194       for (i=0;i<B;i++)
195       {
196          int j;
197          /* De-interleaving the sub-frames */
198          for (j=0;j<N;j++)
199             tmp[j] = X[C*B*j+C*i+c];
200          mdct_backward(mdct_lookup, tmp, x);
201          for (j=0;j<2*N;j++)
202             x[j] = window[j]*x[j];
203          for (j=0;j<overlap;j++)
204             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
205          for (j=0;j<2*N4;j++)
206             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
207          for (j=0;j<overlap;j++)
208             mdct_overlap[C*j+c] = x[N+N4+j];
209       }
210    }
211 }
212
213 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
214 {
215    int i, c, N, B, C, N4;
216    int has_pitch;
217    N = st->block_size;
218    B = st->nb_blocks;
219    C = st->mode->nbChannels;
220    float in[(B+1)*C*N];
221
222    float X[B*C*N];         /**< Interleaved signal MDCTs */
223    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
224    float mask[B*C*N];      /**< Masking curve */
225    float bandE[st->mode->nbEBands*C];
226    float gains[st->mode->nbPBands];
227    int pitch_index;
228    float curr_power, pitch_power;
229    
230    N4 = (N-st->overlap)/2;
231
232    for (c=0;c<C;c++)
233    {
234       for (i=0;i<N4;i++)
235          in[C*i+c] = 0;
236       for (i=0;i<st->overlap;i++)
237          in[C*(i+N4)+c] = st->in_mem[C*i+c];
238       for (i=0;i<B*N;i++)
239       {
240          float tmp = pcm[C*i+c];
241          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
242          st->preemph_memE[c] = tmp;
243       }
244       for (i=N*(B+1)-N4;i<N*(B+1);i++)
245          in[C*i+c] = 0;
246       for (i=0;i<st->overlap;i++)
247          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
248    }
249    /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
250    /* Compute MDCTs */
251    curr_power = compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
252
253 #if 0 /* Mask disabled until it can be made to do something useful */
254    compute_mdct_masking(X, mask, B*C*N, st->Fs);
255
256    /* Invert and stretch the mask to length of X 
257       For some reason, I get better results by using the sqrt instead,
258       although there's no valid reason to. Must investigate further */
259    for (i=0;i<B*C*N;i++)
260       mask[i] = 1/(.1+mask[i]);
261 #else
262    for (i=0;i<B*C*N;i++)
263       mask[i] = 1;
264 #endif
265    /* Pitch analysis */
266    for (c=0;c<C;c++)
267    {
268       for (i=0;i<N;i++)
269       {
270          in[C*i+c] *= st->window[i];
271          in[C*(B*N+i)+c] *= st->window[N+i];
272       }
273    }
274    find_spectral_pitch(st->fft, &st->psy, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
275    
276    /* Compute MDCTs of the pitch part */
277    pitch_power = compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
278    
279    /*printf ("%f %f\n", curr_power, pitch_power);*/
280    /*int j;
281    for (j=0;j<B*N;j++)
282       printf ("%f ", X[j]);
283    for (j=0;j<B*N;j++)
284       printf ("%f ", P[j]);
285    printf ("\n");*/
286
287    /* Band normalisation */
288    compute_band_energies(st->mode, X, bandE);
289    normalise_bands(st->mode, X, bandE);
290    /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
291    /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
292
293    quant_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, &st->enc);
294
295    if (C==2)
296    {
297       stereo_mix(st->mode, X, bandE, 1);
298    }
299
300    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
301    if (curr_power + 1e5f < 10.f*pitch_power)
302    {
303       /* Normalise the pitch vector as well (discard the energies) */
304       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
305       compute_band_energies(st->mode, P, bandEp);
306       normalise_bands(st->mode, P, bandEp);
307
308       if (C==2)
309          stereo_mix(st->mode, P, bandE, 1);
310       /* Simulates intensity stereo */
311       /*for (i=30;i<N*B;i++)
312          X[i*C+1] = P[i*C+1] = 0;*/
313
314       /* Pitch prediction */
315       compute_pitch_gain(st->mode, X, P, gains, bandE);
316       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
317       if (has_pitch)
318          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
319    } else {
320       /* No pitch, so we just pretend we found a gain of zero */
321       for (i=0;i<st->mode->nbPBands;i++)
322          gains[i] = 0;
323       ec_enc_uint(&st->enc, 0, 128);
324       for (i=0;i<B*C*N;i++)
325          P[i] = 0;
326    }
327    
328
329    pitch_quant_bands(st->mode, X, P, gains);
330
331    /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
332    /* Compute residual that we're going to encode */
333    for (i=0;i<B*C*N;i++)
334       X[i] -= P[i];
335
336    /*float sum=0;
337    for (i=0;i<B*N;i++)
338       sum += X[i]*X[i];
339    printf ("%f\n", sum);*/
340    /* Residual quantisation */
341    quant_bands(st->mode, X, P, mask, nbCompressedBytes*8, &st->enc);
342    
343    if (C==2)
344       stereo_mix(st->mode, X, bandE, -1);
345
346    renormalise_bands(st->mode, X);
347    /* Synthesis */
348    denormalise_bands(st->mode, X, bandE);
349
350
351    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
352
353    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
354    /* De-emphasis and put everything back at the right place in the synthesis history */
355    for (c=0;c<C;c++)
356    {
357       for (i=0;i<B;i++)
358       {
359          int j;
360          for (j=0;j<N;j++)
361          {
362             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
363             st->preemph_memD[c] = tmp;
364             if (tmp > 32767) tmp = 32767;
365             if (tmp < -32767) tmp = -32767;
366             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
367          }
368       }
369    }
370    
371    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
372       celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
373    /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
374    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
375    {
376       int val = 0;
377       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
378       {
379          ec_enc_uint(&st->enc, val, 2);
380          val = 1-val;
381       }
382    }
383    ec_enc_done(&st->enc);
384    {
385       unsigned char *data;
386       int nbBytes = ec_byte_bytes(&st->buf);
387       if (nbBytes > nbCompressedBytes)
388       {
389          celt_warning_int ("got too many bytes:", nbBytes);
390          return CELT_INTERNAL_ERROR;
391       }
392       /*printf ("%d\n", *nbBytes);*/
393       data = ec_byte_get_buffer(&st->buf);
394       for (i=0;i<nbBytes;i++)
395          compressed[i] = data[i];
396       for (;i<nbCompressedBytes;i++)
397          compressed[i] = 0;
398    }
399    /* Reset the packing for the next encoding */
400    ec_byte_reset(&st->buf);
401    ec_enc_init(&st->enc,&st->buf);
402
403    return nbCompressedBytes;
404 }
405
406
407 /****************************************************************************/
408 /*                                                                          */
409 /*                                DECODER                                   */
410 /*                                                                          */
411 /****************************************************************************/
412
413
414
415 struct CELTDecoder {
416    const CELTMode *mode;
417    int frame_size;
418    int block_size;
419    int nb_blocks;
420    int overlap;
421
422    ec_byte_buffer buf;
423    ec_enc         enc;
424
425    float preemph;
426    float *preemph_memD;
427    
428    mdct_lookup mdct_lookup;
429    
430    float *window;
431    float *mdct_overlap;
432    float *out_mem;
433
434    float *oldBandE;
435    
436    int last_pitch_index;
437 };
438
439 CELTDecoder *celt_decoder_new(const CELTMode *mode)
440 {
441    int i, N, B, C, N4;
442    N = mode->mdctSize;
443    B = mode->nbMdctBlocks;
444    C = mode->nbChannels;
445    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
446    
447    st->mode = mode;
448    st->frame_size = B*N;
449    st->block_size = N;
450    st->nb_blocks  = B;
451    st->overlap = mode->overlap;
452
453    N4 = (N-st->overlap)/2;
454    
455    mdct_init(&st->mdct_lookup, 2*N);
456    
457    st->window = celt_alloc(2*N*sizeof(float));
458    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
459    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
460
461    for (i=0;i<2*N;i++)
462       st->window[i] = 0;
463    for (i=0;i<st->overlap;i++)
464       st->window[N4+i] = st->window[2*N-N4-i-1] 
465             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
466    for (i=0;i<2*N4;i++)
467       st->window[N-N4+i] = 1;
468    
469    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
470
471    st->preemph = 0.8;
472    st->preemph_memD = celt_alloc(C*sizeof(float));;
473
474    st->last_pitch_index = 0;
475    return st;
476 }
477
478 void celt_decoder_destroy(CELTDecoder *st)
479 {
480    if (st == NULL)
481    {
482       celt_warning("NULL passed to celt_encoder_destroy");
483       return;
484    }
485
486    mdct_clear(&st->mdct_lookup);
487
488    celt_free(st->window);
489    celt_free(st->mdct_overlap);
490    celt_free(st->out_mem);
491    
492    celt_free(st->oldBandE);
493    
494    celt_free(st->preemph_memD);
495
496    celt_free(st);
497 }
498
499 static void celt_decode_lost(CELTDecoder *st, short *pcm)
500 {
501    int i, c, N, B, C;
502    N = st->block_size;
503    B = st->nb_blocks;
504    C = st->mode->nbChannels;
505    float X[C*B*N];         /**< Interleaved signal MDCTs */
506    int pitch_index;
507    
508    pitch_index = st->last_pitch_index;
509    
510    /* Use the pitch MDCT as the "guessed" signal */
511    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
512
513    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
514    /* Compute inverse MDCTs */
515    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
516
517    for (c=0;c<C;c++)
518    {
519       for (i=0;i<B;i++)
520       {
521          int j;
522          for (j=0;j<N;j++)
523          {
524             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
525             st->preemph_memD[c] = tmp;
526             if (tmp > 32767) tmp = 32767;
527             if (tmp < -32767) tmp = -32767;
528             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
529          }
530       }
531    }
532 }
533
534 int celt_decode(CELTDecoder *st, unsigned char *data, int len, celt_int16_t *pcm)
535 {
536    int i, c, N, B, C;
537    int has_pitch;
538    N = st->block_size;
539    B = st->nb_blocks;
540    C = st->mode->nbChannels;
541    
542    float X[C*B*N];         /**< Interleaved signal MDCTs */
543    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
544    float bandE[st->mode->nbEBands*C];
545    float gains[st->mode->nbPBands];
546    int pitch_index;
547    ec_dec dec;
548    ec_byte_buffer buf;
549    
550    if (data == NULL)
551    {
552       celt_decode_lost(st, pcm);
553       return 0;
554    }
555    
556    ec_byte_readinit(&buf,data,len);
557    ec_dec_init(&dec,&buf);
558    
559    /* Get band energies */
560    unquant_energy(st->mode, bandE, st->oldBandE, len*8/3, &dec);
561    
562    /* Get the pitch gains */
563    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
564    
565    /* Get the pitch index */
566    if (has_pitch)
567    {
568       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
569       st->last_pitch_index = pitch_index;
570    } else {
571       /* FIXME: We could be more intelligent here and just not compute the MDCT */
572       pitch_index = 0;
573    }
574    
575    /* Pitch MDCT */
576    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
577
578    {
579       float bandEp[st->mode->nbEBands*C];
580       compute_band_energies(st->mode, P, bandEp);
581       normalise_bands(st->mode, P, bandEp);
582    }
583
584    if (C==2)
585       stereo_mix(st->mode, P, bandE, 1);
586
587    /* Apply pitch gains */
588    pitch_quant_bands(st->mode, X, P, gains);
589
590    /* Decode fixed codebook and merge with pitch */
591    unquant_bands(st->mode, X, P, len*8, &dec);
592
593    if (C==2)
594       stereo_mix(st->mode, X, bandE, -1);
595
596    renormalise_bands(st->mode, X);
597    
598    /* Synthesis */
599    denormalise_bands(st->mode, X, bandE);
600
601
602    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
603    /* Compute inverse MDCTs */
604    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
605
606    for (c=0;c<C;c++)
607    {
608       for (i=0;i<B;i++)
609       {
610          int j;
611          for (j=0;j<N;j++)
612          {
613             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
614             st->preemph_memD[c] = tmp;
615             if (tmp > 32767) tmp = 32767;
616             if (tmp < -32767) tmp = -32767;
617             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
618          }
619       }
620    }
621
622    {
623       int val = 0;
624       while (ec_dec_tell(&dec, 0) < len*8)
625       {
626          if (ec_dec_uint(&dec, 2) != val)
627          {
628             celt_warning("decode error");
629             return CELT_CORRUPTED_DATA;
630          }
631          val = 1-val;
632       }
633    }
634
635    return 0;
636    /*printf ("\n");*/
637 }
638