Re-ordered the parameters in the stream: [energy, pitch index, pitch gains]
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "kiss_fftr.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    kiss_fftr_cfg fft;
67    struct PsyDecay psy;
68    
69    float *window;
70    float *in_mem;
71    float *mdct_overlap;
72    float *out_mem;
73
74    float *oldBandE;
75    
76    struct alloc_data alloc;
77 };
78
79
80
81 CELTEncoder *celt_encoder_new(const CELTMode *mode)
82 {
83    int i, N, B, C, N4;
84    N = mode->mdctSize;
85    B = mode->nbMdctBlocks;
86    C = mode->nbChannels;
87    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
88    
89    st->mode = mode;
90    st->frame_size = B*N;
91    st->block_size = N;
92    st->nb_blocks  = B;
93    st->overlap = mode->overlap;
94    st->Fs = 44100;
95
96    N4 = (N-st->overlap)/2;
97    ec_byte_writeinit(&st->buf);
98    ec_enc_init(&st->enc,&st->buf);
99
100    mdct_init(&st->mdct_lookup, 2*N);
101    st->fft = kiss_fftr_alloc(MAX_PERIOD*C, 0, 0);
102    psydecay_init(&st->psy, MAX_PERIOD*C/2, st->Fs);
103    
104    st->window = celt_alloc(2*N*sizeof(float));
105    st->in_mem = celt_alloc(N*C*sizeof(float));
106    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
107    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
108    for (i=0;i<2*N;i++)
109       st->window[i] = 0;
110    for (i=0;i<st->overlap;i++)
111       st->window[N4+i] = st->window[2*N-N4-i-1] 
112             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
113    for (i=0;i<2*N4;i++)
114       st->window[N-N4+i] = 1;
115    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
116
117    st->preemph = 0.8;
118    st->preemph_memE = celt_alloc(C*sizeof(float));;
119    st->preemph_memD = celt_alloc(C*sizeof(float));;
120
121    alloc_init(&st->alloc, st->mode);
122    return st;
123 }
124
125 void celt_encoder_destroy(CELTEncoder *st)
126 {
127    if (st == NULL)
128    {
129       celt_warning("NULL passed to celt_encoder_destroy");
130       return;
131    }
132    ec_byte_writeclear(&st->buf);
133
134    mdct_clear(&st->mdct_lookup);
135    kiss_fft_free(st->fft);
136    psydecay_clear(&st->psy);
137
138    celt_free(st->window);
139    celt_free(st->in_mem);
140    celt_free(st->mdct_overlap);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148    alloc_clear(&st->alloc);
149
150    celt_free(st);
151 }
152
153
154 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
155 {
156    int i, c;
157    for (c=0;c<C;c++)
158    {
159       for (i=0;i<B;i++)
160       {
161          int j;
162          float x[2*N];
163          float tmp[N];
164          for (j=0;j<2*N;j++)
165             x[j] = window[j]*in[C*i*N+C*j+c];
166          mdct_forward(mdct_lookup, x, tmp);
167          /* Interleaving the sub-frames */
168          for (j=0;j<N;j++)
169             out[C*B*j+C*i+c] = tmp[j];
170       }
171    }
172 }
173
174 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
175 {
176    int i, c, N4;
177    N4 = (N-overlap)/2;
178    for (c=0;c<C;c++)
179    {
180       for (i=0;i<B;i++)
181       {
182          int j;
183          float x[2*N];
184          float tmp[N];
185          /* De-interleaving the sub-frames */
186          for (j=0;j<N;j++)
187             tmp[j] = X[C*B*j+C*i+c];
188          mdct_backward(mdct_lookup, tmp, x);
189          for (j=0;j<2*N;j++)
190             x[j] = window[j]*x[j];
191          for (j=0;j<overlap;j++)
192             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
193          for (j=0;j<2*N4;j++)
194             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
195          for (j=0;j<overlap;j++)
196             mdct_overlap[C*j+c] = x[N+N4+j];
197       }
198    }
199 }
200
201 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
202 {
203    int i, c, N, B, C, N4;
204    N = st->block_size;
205    B = st->nb_blocks;
206    C = st->mode->nbChannels;
207    float in[(B+1)*C*N];
208
209    float X[B*C*N];         /**< Interleaved signal MDCTs */
210    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
211    float mask[B*C*N];      /**< Masking curve */
212    float bandE[st->mode->nbEBands*C];
213    float gains[st->mode->nbPBands];
214    int pitch_index;
215
216    N4 = (N-st->overlap)/2;
217
218    for (c=0;c<C;c++)
219    {
220       for (i=0;i<N4;i++)
221          in[C*i+c] = 0;
222       for (i=0;i<st->overlap;i++)
223          in[C*(i+N4)+c] = st->in_mem[C*i+c];
224       for (i=0;i<B*N;i++)
225       {
226          float tmp = pcm[C*i+c];
227          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
228          st->preemph_memE[c] = tmp;
229       }
230       for (i=N*(B+1)-N4;i<N*(B+1);i++)
231          in[C*i+c] = 0;
232       for (i=0;i<st->overlap;i++)
233          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
234    }
235    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
236    /* Compute MDCTs */
237    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
238
239 #if 0 /* Mask disabled until it can be made to do something useful */
240    compute_mdct_masking(X, mask, B*C*N, st->Fs);
241
242    /* Invert and stretch the mask to length of X 
243       For some reason, I get better results by using the sqrt instead,
244       although there's no valid reason to. Must investigate further */
245    for (i=0;i<B*C*N;i++)
246       mask[i] = 1/(.1+mask[i]);
247 #else
248    for (i=0;i<B*C*N;i++)
249       mask[i] = 1;
250 #endif
251    /* Pitch analysis */
252    for (c=0;c<C;c++)
253    {
254       for (i=0;i<N;i++)
255       {
256          in[C*i+c] *= st->window[i];
257          in[C*(B*N+i)+c] *= st->window[N+i];
258       }
259    }
260    find_spectral_pitch(st->fft, &st->psy, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
261    
262    /* Compute MDCTs of the pitch part */
263    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
264    
265    /*int j;
266    for (j=0;j<B*N;j++)
267       printf ("%f ", X[j]);
268    for (j=0;j<B*N;j++)
269       printf ("%f ", P[j]);
270    printf ("\n");*/
271
272    /* Band normalisation */
273    compute_band_energies(st->mode, X, bandE);
274    normalise_bands(st->mode, X, bandE);
275    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
276    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
277
278    /* Normalise the pitch vector as well (discard the energies) */
279    {
280       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
281       compute_band_energies(st->mode, P, bandEp);
282       normalise_bands(st->mode, P, bandEp);
283    }
284
285    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
286
287    if (C==2)
288    {
289       stereo_mix(st->mode, X, bandE, 1);
290       stereo_mix(st->mode, P, bandE, 1);
291    }
292    /* Simulates intensity stereo */
293    //for (i=30;i<N*B;i++)
294    //   X[i*C+1] = P[i*C+1] = 0;
295    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
296
297    /* Pitch prediction */
298    compute_pitch_gain(st->mode, X, P, gains, bandE);
299    quant_pitch(gains, st->mode->nbPBands, &st->enc);
300    
301    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
302
303    pitch_quant_bands(st->mode, X, P, gains);
304
305    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
306    /* Compute residual that we're going to encode */
307    for (i=0;i<B*C*N;i++)
308       X[i] -= P[i];
309
310    /*float sum=0;
311    for (i=0;i<B*N;i++)
312       sum += X[i]*X[i];
313    printf ("%f\n", sum);*/
314    /* Residual quantisation */
315    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
316    
317    if (C==2)
318       stereo_mix(st->mode, X, bandE, -1);
319
320    renormalise_bands(st->mode, X);
321    /* Synthesis */
322    denormalise_bands(st->mode, X, bandE);
323
324
325    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
326
327    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
328    /* De-emphasis and put everything back at the right place in the synthesis history */
329    for (c=0;c<C;c++)
330    {
331       for (i=0;i<B;i++)
332       {
333          int j;
334          for (j=0;j<N;j++)
335          {
336             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
337             st->preemph_memD[c] = tmp;
338             if (tmp > 32767) tmp = 32767;
339             if (tmp < -32767) tmp = -32767;
340             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
341          }
342       }
343    }
344    
345    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 8)
346       celt_warning_int ("too make unused bits", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
347    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
348    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
349    {
350       int val = 0;
351       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
352       {
353          ec_enc_uint(&st->enc, val, 2);
354          val = 1-val;
355       }
356    }
357    ec_enc_done(&st->enc);
358    {
359       unsigned char *data;
360       int nbBytes = ec_byte_bytes(&st->buf);
361       if (nbBytes > nbCompressedBytes)
362       {
363          celt_warning_int ("got too many bytes:", nbBytes);
364          return CELT_INTERNAL_ERROR;
365       }
366       //printf ("%d\n", *nbBytes);
367       data = ec_byte_get_buffer(&st->buf);
368       for (i=0;i<nbBytes;i++)
369          compressed[i] = data[i];
370       for (;i<nbCompressedBytes;i++)
371          compressed[i] = 0;
372    }
373    /* Reset the packing for the next encoding */
374    ec_byte_reset(&st->buf);
375    ec_enc_init(&st->enc,&st->buf);
376
377    return nbCompressedBytes;
378 }
379
380
381 /****************************************************************************/
382 /*                                                                          */
383 /*                                DECODER                                   */
384 /*                                                                          */
385 /****************************************************************************/
386
387
388
389 struct CELTDecoder {
390    const CELTMode *mode;
391    int frame_size;
392    int block_size;
393    int nb_blocks;
394    int overlap;
395
396    ec_byte_buffer buf;
397    ec_enc         enc;
398
399    float preemph;
400    float *preemph_memD;
401    
402    mdct_lookup mdct_lookup;
403    
404    float *window;
405    float *mdct_overlap;
406    float *out_mem;
407
408    float *oldBandE;
409    
410    int last_pitch_index;
411    
412    struct alloc_data alloc;
413 };
414
415 CELTDecoder *celt_decoder_new(const CELTMode *mode)
416 {
417    int i, N, B, C, N4;
418    N = mode->mdctSize;
419    B = mode->nbMdctBlocks;
420    C = mode->nbChannels;
421    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
422    
423    st->mode = mode;
424    st->frame_size = B*N;
425    st->block_size = N;
426    st->nb_blocks  = B;
427    st->overlap = mode->overlap;
428
429    N4 = (N-st->overlap)/2;
430    
431    mdct_init(&st->mdct_lookup, 2*N);
432    
433    st->window = celt_alloc(2*N*sizeof(float));
434    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
435    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
436
437    for (i=0;i<2*N;i++)
438       st->window[i] = 0;
439    for (i=0;i<st->overlap;i++)
440       st->window[N4+i] = st->window[2*N-N4-i-1] 
441             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
442    for (i=0;i<2*N4;i++)
443       st->window[N-N4+i] = 1;
444    
445    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
446
447    st->preemph = 0.8;
448    st->preemph_memD = celt_alloc(C*sizeof(float));;
449
450    st->last_pitch_index = 0;
451    alloc_init(&st->alloc, st->mode);
452
453    return st;
454 }
455
456 void celt_decoder_destroy(CELTDecoder *st)
457 {
458    if (st == NULL)
459    {
460       celt_warning("NULL passed to celt_encoder_destroy");
461       return;
462    }
463
464    mdct_clear(&st->mdct_lookup);
465
466    celt_free(st->window);
467    celt_free(st->mdct_overlap);
468    celt_free(st->out_mem);
469    
470    celt_free(st->oldBandE);
471    
472    celt_free(st->preemph_memD);
473
474    alloc_clear(&st->alloc);
475
476    celt_free(st);
477 }
478
479 static void celt_decode_lost(CELTDecoder *st, short *pcm)
480 {
481    int i, c, N, B, C;
482    N = st->block_size;
483    B = st->nb_blocks;
484    C = st->mode->nbChannels;
485    float X[C*B*N];         /**< Interleaved signal MDCTs */
486    int pitch_index;
487    
488    pitch_index = st->last_pitch_index;
489    
490    /* Use the pitch MDCT as the "guessed" signal */
491    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
492
493    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
494    /* Compute inverse MDCTs */
495    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
496
497    for (c=0;c<C;c++)
498    {
499       for (i=0;i<B;i++)
500       {
501          int j;
502          for (j=0;j<N;j++)
503          {
504             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
505             st->preemph_memD[c] = tmp;
506             if (tmp > 32767) tmp = 32767;
507             if (tmp < -32767) tmp = -32767;
508             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
509          }
510       }
511    }
512 }
513
514 int celt_decode(CELTDecoder *st, unsigned char *data, int len, celt_int16_t *pcm)
515 {
516    int i, c, N, B, C;
517    N = st->block_size;
518    B = st->nb_blocks;
519    C = st->mode->nbChannels;
520    
521    float X[C*B*N];         /**< Interleaved signal MDCTs */
522    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
523    float bandE[st->mode->nbEBands*C];
524    float gains[st->mode->nbPBands];
525    int pitch_index;
526    ec_dec dec;
527    ec_byte_buffer buf;
528    
529    if (data == NULL)
530    {
531       celt_decode_lost(st, pcm);
532       return 0;
533    }
534    
535    ec_byte_readinit(&buf,data,len);
536    ec_dec_init(&dec,&buf);
537    
538    /* Get band energies */
539    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
540    
541    /* Get the pitch gains */
542    unquant_pitch(gains, st->mode->nbPBands, &dec);
543    
544    /* Get the pitch index */
545    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
546    st->last_pitch_index = pitch_index;
547    
548    /* Pitch MDCT */
549    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
550
551    {
552       float bandEp[st->mode->nbEBands*C];
553       compute_band_energies(st->mode, P, bandEp);
554       normalise_bands(st->mode, P, bandEp);
555    }
556
557    if (C==2)
558       stereo_mix(st->mode, P, bandE, 1);
559
560    /* Apply pitch gains */
561    pitch_quant_bands(st->mode, X, P, gains);
562
563    /* Decode fixed codebook and merge with pitch */
564    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
565
566    if (C==2)
567       stereo_mix(st->mode, X, bandE, -1);
568
569    renormalise_bands(st->mode, X);
570    
571    /* Synthesis */
572    denormalise_bands(st->mode, X, bandE);
573
574
575    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
576    /* Compute inverse MDCTs */
577    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
578
579    for (c=0;c<C;c++)
580    {
581       for (i=0;i<B;i++)
582       {
583          int j;
584          for (j=0;j<N;j++)
585          {
586             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
587             st->preemph_memD[c] = tmp;
588             if (tmp > 32767) tmp = 32767;
589             if (tmp < -32767) tmp = -32767;
590             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
591          }
592       }
593    }
594
595    {
596       int val = 0;
597       while (ec_dec_tell(&dec, 0) < len*8)
598       {
599          if (ec_dec_uint(&dec, 2) != val)
600          {
601             celt_warning("decode error");
602             return CELT_CORRUPTED_DATA;
603          }
604          val = 1-val;
605       }
606    }
607
608    return 0;
609    //printf ("\n");
610 }
611