Moved the content of libentcode into libcelt to reduce dependencies,
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    void *fft;
67    
68    float *window;
69    float *in_mem;
70    float *mdct_overlap;
71    float *out_mem;
72
73    float *oldBandE;
74    
75    struct alloc_data alloc;
76 };
77
78
79
80 CELTEncoder *celt_encoder_new(const CELTMode *mode)
81 {
82    int i, N, B, C, N4;
83    N = mode->mdctSize;
84    B = mode->nbMdctBlocks;
85    C = mode->nbChannels;
86    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
87    
88    st->mode = mode;
89    st->frame_size = B*N;
90    st->block_size = N;
91    st->nb_blocks  = B;
92    st->overlap = mode->overlap;
93    st->Fs = 44100;
94
95    N4 = (N-st->overlap)/2;
96    ec_byte_writeinit(&st->buf);
97    ec_enc_init(&st->enc,&st->buf);
98
99    mdct_init(&st->mdct_lookup, 2*N);
100    st->fft = spx_fft_init(MAX_PERIOD*C);
101    
102    st->window = celt_alloc(2*N*sizeof(float));
103    st->in_mem = celt_alloc(N*C*sizeof(float));
104    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
105    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
106    for (i=0;i<2*N;i++)
107       st->window[i] = 0;
108    for (i=0;i<st->overlap;i++)
109       st->window[N4+i] = st->window[2*N-N4-i-1] 
110             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
111    for (i=0;i<2*N4;i++)
112       st->window[N-N4+i] = 1;
113    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
114
115    st->preemph = 0.8;
116    st->preemph_memE = celt_alloc(C*sizeof(float));;
117    st->preemph_memD = celt_alloc(C*sizeof(float));;
118
119    alloc_init(&st->alloc, st->mode);
120    return st;
121 }
122
123 void celt_encoder_destroy(CELTEncoder *st)
124 {
125    if (st == NULL)
126    {
127       celt_warning("NULL passed to celt_encoder_destroy");
128       return;
129    }
130    ec_byte_writeclear(&st->buf);
131
132    mdct_clear(&st->mdct_lookup);
133    spx_fft_destroy(st->fft);
134
135    celt_free(st->window);
136    celt_free(st->in_mem);
137    celt_free(st->mdct_overlap);
138    celt_free(st->out_mem);
139    
140    celt_free(st->oldBandE);
141    alloc_clear(&st->alloc);
142
143    celt_free(st);
144 }
145
146 static void haar1(float *X, int N, int stride)
147 {
148    int i, k;
149    for (k=0;k<stride;k++)
150    {
151       for (i=k;i<N*stride;i+=2*stride)
152       {
153          float a, b;
154          a = X[i];
155          b = X[i+stride];
156          X[i] = .707107f*(a+b);
157          X[i+stride] = .707107f*(a-b);
158       }
159    }
160 }
161
162 static void time_dct(float *X, int N, int B, int stride)
163 {
164    switch (B)
165    {
166       case 1:
167          break;
168       case 2:
169          haar1(X, B*N, stride);
170          break;
171       default:
172          celt_warning("time_dct not defined for B > 2");
173    };
174 }
175
176 static void time_idct(float *X, int N, int B, int stride)
177 {
178    switch (B)
179    {
180       case 1:
181          break;
182       case 2:
183          haar1(X, B*N, stride);
184          break;
185       default:
186          celt_warning("time_dct not defined for B > 2");
187    };
188 }
189
190 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
191 {
192    int i, c;
193    for (c=0;c<C;c++)
194    {
195       for (i=0;i<B;i++)
196       {
197          int j;
198          float x[2*N];
199          float tmp[N];
200          for (j=0;j<2*N;j++)
201             x[j] = window[j]*in[C*i*N+C*j+c];
202          mdct_forward(mdct_lookup, x, tmp);
203          /* Interleaving the sub-frames */
204          for (j=0;j<N;j++)
205             out[C*B*j+C*i+c] = tmp[j];
206       }
207    }
208 }
209
210 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
211 {
212    int i, c, N4;
213    N4 = (N-overlap)/2;
214    for (c=0;c<C;c++)
215    {
216       for (i=0;i<B;i++)
217       {
218          int j;
219          float x[2*N];
220          float tmp[N];
221          /* De-interleaving the sub-frames */
222          for (j=0;j<N;j++)
223             tmp[j] = X[C*B*j+C*i+c];
224          mdct_backward(mdct_lookup, tmp, x);
225          for (j=0;j<2*N;j++)
226             x[j] = window[j]*x[j];
227          for (j=0;j<overlap;j++)
228             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
229          for (j=0;j<2*N4;j++)
230             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
231          for (j=0;j<overlap;j++)
232             mdct_overlap[C*j+c] = x[N+N4+j];
233       }
234    }
235 }
236
237 int celt_encode(CELTEncoder *st, short *pcm, unsigned char *compressed, int nbCompressedBytes)
238 {
239    int i, c, N, B, C, N4;
240    N = st->block_size;
241    B = st->nb_blocks;
242    C = st->mode->nbChannels;
243    float in[(B+1)*C*N];
244
245    float X[B*C*N];         /**< Interleaved signal MDCTs */
246    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
247    float mask[B*C*N];      /**< Masking curve */
248    float bandE[st->mode->nbEBands*C];
249    float gains[st->mode->nbPBands];
250    int pitch_index;
251
252    N4 = (N-st->overlap)/2;
253
254    for (c=0;c<C;c++)
255    {
256       for (i=0;i<N4;i++)
257          in[C*i+c] = 0;
258       for (i=0;i<st->overlap;i++)
259          in[C*(i+N4)+c] = st->in_mem[C*i+c];
260       for (i=0;i<B*N;i++)
261       {
262          float tmp = pcm[C*i+c];
263          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
264          st->preemph_memE[c] = tmp;
265       }
266       for (i=N*(B+1)-N4;i<N*(B+1);i++)
267          in[C*i+c] = 0;
268       for (i=0;i<st->overlap;i++)
269          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
270    }
271    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
272    /* Compute MDCTs */
273    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
274
275    compute_mdct_masking(X, mask, B*C*N, st->Fs);
276
277    /* Invert and stretch the mask to length of X 
278       For some reason, I get better results by using the sqrt instead,
279       although there's no valid reason to. Must investigate further */
280    for (i=0;i<B*C*N;i++)
281       mask[i] = 1/(.1+mask[i]);
282
283    /* Pitch analysis */
284    for (c=0;c<C;c++)
285    {
286       for (i=0;i<N;i++)
287       {
288          in[C*i+c] *= st->window[i];
289          in[C*(B*N+i)+c] *= st->window[N+i];
290       }
291    }
292    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
293    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
294    
295    /* Compute MDCTs of the pitch part */
296    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
297    
298    /*int j;
299    for (j=0;j<B*N;j++)
300       printf ("%f ", X[j]);
301    for (j=0;j<B*N;j++)
302       printf ("%f ", P[j]);
303    printf ("\n");*/
304
305    /* Band normalisation */
306    compute_band_energies(st->mode, X, bandE);
307    normalise_bands(st->mode, X, bandE);
308    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
309    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
310
311    /* Normalise the pitch vector as well (discard the energies) */
312    {
313       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
314       compute_band_energies(st->mode, P, bandEp);
315       normalise_bands(st->mode, P, bandEp);
316    }
317
318    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
319
320    if (C==2)
321    {
322       stereo_mix(st->mode, X, bandE, 1);
323       stereo_mix(st->mode, P, bandE, 1);
324       //haar1(X, B*N*C, 1);
325       //haar1(P, B*N*C, 1);
326    }
327    /* Simulates intensity stereo */
328    //for (i=30;i<N*B;i++)
329    //   X[i*C+1] = P[i*C+1] = 0;
330    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
331    time_dct(X, N, B, C);
332    time_dct(P, N, B, C);
333
334
335    /* Pitch prediction */
336    compute_pitch_gain(st->mode, X, P, gains, bandE);
337    quant_pitch(gains, st->mode->nbPBands, &st->enc);
338    pitch_quant_bands(st->mode, X, P, gains);
339
340    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
341    /* Compute residual that we're going to encode */
342    for (i=0;i<B*C*N;i++)
343       X[i] -= P[i];
344
345    /*float sum=0;
346    for (i=0;i<B*N;i++)
347       sum += X[i]*X[i];
348    printf ("%f\n", sum);*/
349    /* Residual quantisation */
350    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
351    
352    time_idct(X, N, B, C);
353    if (C==2)
354       //haar1(X, B*N*C, 1);
355       stereo_mix(st->mode, X, bandE, -1);
356
357    renormalise_bands(st->mode, X);
358    /* Synthesis */
359    denormalise_bands(st->mode, X, bandE);
360
361
362    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
363
364    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
365    /* De-emphasis and put everything back at the right place in the synthesis history */
366    for (c=0;c<C;c++)
367    {
368       for (i=0;i<B;i++)
369       {
370          int j;
371          for (j=0;j<N;j++)
372          {
373             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
374             st->preemph_memD[c] = tmp;
375             if (tmp > 32767) tmp = 32767;
376             if (tmp < -32767) tmp = -32767;
377             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
378          }
379       }
380    }
381    
382    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
383    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
384    {
385       int val = 0;
386       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
387       {
388          ec_enc_uint(&st->enc, val, 2);
389          val = 1-val;
390       }
391    }
392    ec_enc_done(&st->enc);
393    {
394       unsigned char *data;
395       int nbBytes = ec_byte_bytes(&st->buf);
396       if (nbBytes != nbCompressedBytes)
397       {
398          if (nbBytes > nbCompressedBytes)
399             celt_warning("got too many bytes");
400          else
401             celt_warning("not enough bytes");
402          return CELT_INTERNAL_ERROR;
403       }
404       //printf ("%d\n", *nbBytes);
405       data = ec_byte_get_buffer(&st->buf);
406       for (i=0;i<nbBytes;i++)
407          compressed[i] = data[i];
408    }
409    /* Reset the packing for the next encoding */
410    ec_byte_reset(&st->buf);
411    ec_enc_init(&st->enc,&st->buf);
412
413    return nbCompressedBytes;
414 }
415
416
417 /****************************************************************************/
418 /*                                                                          */
419 /*                                DECODER                                   */
420 /*                                                                          */
421 /****************************************************************************/
422
423
424
425 struct CELTDecoder {
426    const CELTMode *mode;
427    int frame_size;
428    int block_size;
429    int nb_blocks;
430    int overlap;
431
432    ec_byte_buffer buf;
433    ec_enc         enc;
434
435    float preemph;
436    float *preemph_memD;
437    
438    mdct_lookup mdct_lookup;
439    
440    float *window;
441    float *mdct_overlap;
442    float *out_mem;
443
444    float *oldBandE;
445    
446    int last_pitch_index;
447    
448    struct alloc_data alloc;
449 };
450
451 CELTDecoder *celt_decoder_new(const CELTMode *mode)
452 {
453    int i, N, B, C, N4;
454    N = mode->mdctSize;
455    B = mode->nbMdctBlocks;
456    C = mode->nbChannels;
457    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
458    
459    st->mode = mode;
460    st->frame_size = B*N;
461    st->block_size = N;
462    st->nb_blocks  = B;
463    st->overlap = mode->overlap;
464
465    N4 = (N-st->overlap)/2;
466    
467    mdct_init(&st->mdct_lookup, 2*N);
468    
469    st->window = celt_alloc(2*N*sizeof(float));
470    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
471    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
472
473    for (i=0;i<2*N;i++)
474       st->window[i] = 0;
475    for (i=0;i<st->overlap;i++)
476       st->window[N4+i] = st->window[2*N-N4-i-1] 
477             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
478    for (i=0;i<2*N4;i++)
479       st->window[N-N4+i] = 1;
480    
481    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
482
483    st->preemph = 0.8;
484    st->preemph_memD = celt_alloc(C*sizeof(float));;
485
486    st->last_pitch_index = 0;
487    alloc_init(&st->alloc, st->mode);
488
489    return st;
490 }
491
492 void celt_decoder_destroy(CELTDecoder *st)
493 {
494    if (st == NULL)
495    {
496       celt_warning("NULL passed to celt_encoder_destroy");
497       return;
498    }
499
500    mdct_clear(&st->mdct_lookup);
501
502    celt_free(st->window);
503    celt_free(st->mdct_overlap);
504    celt_free(st->out_mem);
505    
506    celt_free(st->oldBandE);
507    alloc_clear(&st->alloc);
508
509    celt_free(st);
510 }
511
512 static void celt_decode_lost(CELTDecoder *st, short *pcm)
513 {
514    int i, c, N, B, C;
515    N = st->block_size;
516    B = st->nb_blocks;
517    C = st->mode->nbChannels;
518    float X[C*B*N];         /**< Interleaved signal MDCTs */
519    int pitch_index;
520    
521    pitch_index = st->last_pitch_index;
522    
523    /* Use the pitch MDCT as the "guessed" signal */
524    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
525
526    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
527    /* Compute inverse MDCTs */
528    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
529
530    for (c=0;c<C;c++)
531    {
532       for (i=0;i<B;i++)
533       {
534          int j;
535          for (j=0;j<N;j++)
536          {
537             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
538             st->preemph_memD[c] = tmp;
539             if (tmp > 32767) tmp = 32767;
540             if (tmp < -32767) tmp = -32767;
541             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
542          }
543       }
544    }
545 }
546
547 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
548 {
549    int i, c, N, B, C;
550    N = st->block_size;
551    B = st->nb_blocks;
552    C = st->mode->nbChannels;
553    
554    float X[C*B*N];         /**< Interleaved signal MDCTs */
555    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
556    float bandE[st->mode->nbEBands*C];
557    float gains[st->mode->nbPBands];
558    int pitch_index;
559    ec_dec dec;
560    ec_byte_buffer buf;
561    
562    if (data == NULL)
563    {
564       celt_decode_lost(st, pcm);
565       return 0;
566    }
567    
568    ec_byte_readinit(&buf,data,len);
569    ec_dec_init(&dec,&buf);
570    
571    /* Get the pitch index */
572    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
573    st->last_pitch_index = pitch_index;
574    
575    /* Get band energies */
576    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
577    
578    /* Pitch MDCT */
579    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
580
581    {
582       float bandEp[st->mode->nbEBands];
583       compute_band_energies(st->mode, P, bandEp);
584       normalise_bands(st->mode, P, bandEp);
585    }
586
587    if (C==2)
588       //haar1(P, B*N*C, 1);
589       stereo_mix(st->mode, P, bandE, 1);
590    time_dct(P, N, B, C);
591
592    /* Get the pitch gains */
593    unquant_pitch(gains, st->mode->nbPBands, &dec);
594
595    /* Apply pitch gains */
596    pitch_quant_bands(st->mode, X, P, gains);
597
598    /* Decode fixed codebook and merge with pitch */
599    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
600
601    time_idct(X, N, B, C);
602    if (C==2)
603       //haar1(X, B*N*C, 1);
604       stereo_mix(st->mode, X, bandE, -1);
605
606    renormalise_bands(st->mode, X);
607    
608    /* Synthesis */
609    denormalise_bands(st->mode, X, bandE);
610
611
612    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
613    /* Compute inverse MDCTs */
614    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
615
616    for (c=0;c<C;c++)
617    {
618       for (i=0;i<B;i++)
619       {
620          int j;
621          for (j=0;j<N;j++)
622          {
623             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
624             st->preemph_memD[c] = tmp;
625             if (tmp > 32767) tmp = 32767;
626             if (tmp < -32767) tmp = -32767;
627             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
628          }
629       }
630    }
631
632    {
633       int val = 0;
634       while (ec_dec_tell(&dec, 0) < len*8)
635       {
636          if (ec_dec_uint(&dec, 2) != val)
637          {
638             celt_warning("decode error");
639             return CELT_CORRUPTED_DATA;
640          }
641          val = 1-val;
642       }
643    }
644
645    return 0;
646    //printf ("\n");
647 }
648