As extra safety, make sure not to use pitch prediction when it would make
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "kiss_fftr.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    kiss_fftr_cfg fft;
67    struct PsyDecay psy;
68    
69    float *window;
70    float *in_mem;
71    float *mdct_overlap;
72    float *out_mem;
73
74    float *oldBandE;
75    
76    struct alloc_data alloc;
77 };
78
79
80
81 CELTEncoder *celt_encoder_new(const CELTMode *mode)
82 {
83    int i, N, B, C, N4;
84    N = mode->mdctSize;
85    B = mode->nbMdctBlocks;
86    C = mode->nbChannels;
87    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
88    
89    st->mode = mode;
90    st->frame_size = B*N;
91    st->block_size = N;
92    st->nb_blocks  = B;
93    st->overlap = mode->overlap;
94    st->Fs = 44100;
95
96    N4 = (N-st->overlap)/2;
97    ec_byte_writeinit(&st->buf);
98    ec_enc_init(&st->enc,&st->buf);
99
100    mdct_init(&st->mdct_lookup, 2*N);
101    st->fft = kiss_fftr_alloc(MAX_PERIOD*C, 0, 0);
102    psydecay_init(&st->psy, MAX_PERIOD*C/2, st->Fs);
103    
104    st->window = celt_alloc(2*N*sizeof(float));
105    st->in_mem = celt_alloc(N*C*sizeof(float));
106    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
107    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
108    for (i=0;i<2*N;i++)
109       st->window[i] = 0;
110    for (i=0;i<st->overlap;i++)
111       st->window[N4+i] = st->window[2*N-N4-i-1] 
112             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
113    for (i=0;i<2*N4;i++)
114       st->window[N-N4+i] = 1;
115    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
116
117    st->preemph = 0.8;
118    st->preemph_memE = celt_alloc(C*sizeof(float));;
119    st->preemph_memD = celt_alloc(C*sizeof(float));;
120
121    alloc_init(&st->alloc, st->mode);
122    return st;
123 }
124
125 void celt_encoder_destroy(CELTEncoder *st)
126 {
127    if (st == NULL)
128    {
129       celt_warning("NULL passed to celt_encoder_destroy");
130       return;
131    }
132    ec_byte_writeclear(&st->buf);
133
134    mdct_clear(&st->mdct_lookup);
135    kiss_fft_free(st->fft);
136    psydecay_clear(&st->psy);
137
138    celt_free(st->window);
139    celt_free(st->in_mem);
140    celt_free(st->mdct_overlap);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148    alloc_clear(&st->alloc);
149
150    celt_free(st);
151 }
152
153
154 static float compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
155 {
156    int i, c;
157    float E = 1e-15;
158    for (c=0;c<C;c++)
159    {
160       for (i=0;i<B;i++)
161       {
162          int j;
163          float x[2*N];
164          float tmp[N];
165          for (j=0;j<2*N;j++)
166          {
167             x[j] = window[j]*in[C*i*N+C*j+c];
168             E += x[j]*x[j];
169          }
170          mdct_forward(mdct_lookup, x, tmp);
171          /* Interleaving the sub-frames */
172          for (j=0;j<N;j++)
173             out[C*B*j+C*i+c] = tmp[j];
174       }
175    }
176    return E;
177 }
178
179 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
180 {
181    int i, c, N4;
182    N4 = (N-overlap)/2;
183    for (c=0;c<C;c++)
184    {
185       for (i=0;i<B;i++)
186       {
187          int j;
188          float x[2*N];
189          float tmp[N];
190          /* De-interleaving the sub-frames */
191          for (j=0;j<N;j++)
192             tmp[j] = X[C*B*j+C*i+c];
193          mdct_backward(mdct_lookup, tmp, x);
194          for (j=0;j<2*N;j++)
195             x[j] = window[j]*x[j];
196          for (j=0;j<overlap;j++)
197             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
198          for (j=0;j<2*N4;j++)
199             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
200          for (j=0;j<overlap;j++)
201             mdct_overlap[C*j+c] = x[N+N4+j];
202       }
203    }
204 }
205
206 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
207 {
208    int i, c, N, B, C, N4;
209    int has_pitch;
210    N = st->block_size;
211    B = st->nb_blocks;
212    C = st->mode->nbChannels;
213    float in[(B+1)*C*N];
214
215    float X[B*C*N];         /**< Interleaved signal MDCTs */
216    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
217    float mask[B*C*N];      /**< Masking curve */
218    float bandE[st->mode->nbEBands*C];
219    float gains[st->mode->nbPBands];
220    int pitch_index;
221    float curr_power, pitch_power;
222    
223    N4 = (N-st->overlap)/2;
224
225    for (c=0;c<C;c++)
226    {
227       for (i=0;i<N4;i++)
228          in[C*i+c] = 0;
229       for (i=0;i<st->overlap;i++)
230          in[C*(i+N4)+c] = st->in_mem[C*i+c];
231       for (i=0;i<B*N;i++)
232       {
233          float tmp = pcm[C*i+c];
234          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
235          st->preemph_memE[c] = tmp;
236       }
237       for (i=N*(B+1)-N4;i<N*(B+1);i++)
238          in[C*i+c] = 0;
239       for (i=0;i<st->overlap;i++)
240          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
241    }
242    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
243    /* Compute MDCTs */
244    curr_power = compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
245
246 #if 0 /* Mask disabled until it can be made to do something useful */
247    compute_mdct_masking(X, mask, B*C*N, st->Fs);
248
249    /* Invert and stretch the mask to length of X 
250       For some reason, I get better results by using the sqrt instead,
251       although there's no valid reason to. Must investigate further */
252    for (i=0;i<B*C*N;i++)
253       mask[i] = 1/(.1+mask[i]);
254 #else
255    for (i=0;i<B*C*N;i++)
256       mask[i] = 1;
257 #endif
258    /* Pitch analysis */
259    for (c=0;c<C;c++)
260    {
261       for (i=0;i<N;i++)
262       {
263          in[C*i+c] *= st->window[i];
264          in[C*(B*N+i)+c] *= st->window[N+i];
265       }
266    }
267    find_spectral_pitch(st->fft, &st->psy, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
268    
269    /* Compute MDCTs of the pitch part */
270    pitch_power = compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
271    
272    //printf ("%f %f\n", curr_power, pitch_power);
273    /*int j;
274    for (j=0;j<B*N;j++)
275       printf ("%f ", X[j]);
276    for (j=0;j<B*N;j++)
277       printf ("%f ", P[j]);
278    printf ("\n");*/
279
280    /* Band normalisation */
281    compute_band_energies(st->mode, X, bandE);
282    normalise_bands(st->mode, X, bandE);
283    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
284    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
285
286    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
287
288    /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
289    if (curr_power + 1e5f < 10.f*pitch_power)
290    {
291       /* Normalise the pitch vector as well (discard the energies) */
292       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
293       compute_band_energies(st->mode, P, bandEp);
294       normalise_bands(st->mode, P, bandEp);
295
296       if (C==2)
297       {
298          stereo_mix(st->mode, X, bandE, 1);
299          stereo_mix(st->mode, P, bandE, 1);
300       }
301       /* Simulates intensity stereo */
302       //for (i=30;i<N*B;i++)
303       //   X[i*C+1] = P[i*C+1] = 0;
304
305       /* Pitch prediction */
306       compute_pitch_gain(st->mode, X, P, gains, bandE);
307       has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
308       if (has_pitch)
309          ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
310    } else {
311       /* No pitch, so we just pretend we found a gain of zero */
312       for (i=0;i<st->mode->nbPBands;i++)
313          gains[i] = 0;
314       ec_enc_uint(&st->enc, 0, 128);
315       for (i=0;i<B*C*N;i++)
316          P[i] = 0;
317    }
318    
319
320    pitch_quant_bands(st->mode, X, P, gains);
321
322    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
323    /* Compute residual that we're going to encode */
324    for (i=0;i<B*C*N;i++)
325       X[i] -= P[i];
326
327    /*float sum=0;
328    for (i=0;i<B*N;i++)
329       sum += X[i]*X[i];
330    printf ("%f\n", sum);*/
331    /* Residual quantisation */
332    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
333    
334    if (C==2)
335       stereo_mix(st->mode, X, bandE, -1);
336
337    renormalise_bands(st->mode, X);
338    /* Synthesis */
339    denormalise_bands(st->mode, X, bandE);
340
341
342    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
343
344    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
345    /* De-emphasis and put everything back at the right place in the synthesis history */
346    for (c=0;c<C;c++)
347    {
348       for (i=0;i<B;i++)
349       {
350          int j;
351          for (j=0;j<N;j++)
352          {
353             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
354             st->preemph_memD[c] = tmp;
355             if (tmp > 32767) tmp = 32767;
356             if (tmp < -32767) tmp = -32767;
357             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
358          }
359       }
360    }
361    
362    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 8)
363       celt_warning_int ("too make unused bits", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
364    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
365    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
366    {
367       int val = 0;
368       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
369       {
370          ec_enc_uint(&st->enc, val, 2);
371          val = 1-val;
372       }
373    }
374    ec_enc_done(&st->enc);
375    {
376       unsigned char *data;
377       int nbBytes = ec_byte_bytes(&st->buf);
378       if (nbBytes > nbCompressedBytes)
379       {
380          celt_warning_int ("got too many bytes:", nbBytes);
381          return CELT_INTERNAL_ERROR;
382       }
383       //printf ("%d\n", *nbBytes);
384       data = ec_byte_get_buffer(&st->buf);
385       for (i=0;i<nbBytes;i++)
386          compressed[i] = data[i];
387       for (;i<nbCompressedBytes;i++)
388          compressed[i] = 0;
389    }
390    /* Reset the packing for the next encoding */
391    ec_byte_reset(&st->buf);
392    ec_enc_init(&st->enc,&st->buf);
393
394    return nbCompressedBytes;
395 }
396
397
398 /****************************************************************************/
399 /*                                                                          */
400 /*                                DECODER                                   */
401 /*                                                                          */
402 /****************************************************************************/
403
404
405
406 struct CELTDecoder {
407    const CELTMode *mode;
408    int frame_size;
409    int block_size;
410    int nb_blocks;
411    int overlap;
412
413    ec_byte_buffer buf;
414    ec_enc         enc;
415
416    float preemph;
417    float *preemph_memD;
418    
419    mdct_lookup mdct_lookup;
420    
421    float *window;
422    float *mdct_overlap;
423    float *out_mem;
424
425    float *oldBandE;
426    
427    int last_pitch_index;
428    
429    struct alloc_data alloc;
430 };
431
432 CELTDecoder *celt_decoder_new(const CELTMode *mode)
433 {
434    int i, N, B, C, N4;
435    N = mode->mdctSize;
436    B = mode->nbMdctBlocks;
437    C = mode->nbChannels;
438    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
439    
440    st->mode = mode;
441    st->frame_size = B*N;
442    st->block_size = N;
443    st->nb_blocks  = B;
444    st->overlap = mode->overlap;
445
446    N4 = (N-st->overlap)/2;
447    
448    mdct_init(&st->mdct_lookup, 2*N);
449    
450    st->window = celt_alloc(2*N*sizeof(float));
451    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
452    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
453
454    for (i=0;i<2*N;i++)
455       st->window[i] = 0;
456    for (i=0;i<st->overlap;i++)
457       st->window[N4+i] = st->window[2*N-N4-i-1] 
458             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
459    for (i=0;i<2*N4;i++)
460       st->window[N-N4+i] = 1;
461    
462    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
463
464    st->preemph = 0.8;
465    st->preemph_memD = celt_alloc(C*sizeof(float));;
466
467    st->last_pitch_index = 0;
468    alloc_init(&st->alloc, st->mode);
469
470    return st;
471 }
472
473 void celt_decoder_destroy(CELTDecoder *st)
474 {
475    if (st == NULL)
476    {
477       celt_warning("NULL passed to celt_encoder_destroy");
478       return;
479    }
480
481    mdct_clear(&st->mdct_lookup);
482
483    celt_free(st->window);
484    celt_free(st->mdct_overlap);
485    celt_free(st->out_mem);
486    
487    celt_free(st->oldBandE);
488    
489    celt_free(st->preemph_memD);
490
491    alloc_clear(&st->alloc);
492
493    celt_free(st);
494 }
495
496 static void celt_decode_lost(CELTDecoder *st, short *pcm)
497 {
498    int i, c, N, B, C;
499    N = st->block_size;
500    B = st->nb_blocks;
501    C = st->mode->nbChannels;
502    float X[C*B*N];         /**< Interleaved signal MDCTs */
503    int pitch_index;
504    
505    pitch_index = st->last_pitch_index;
506    
507    /* Use the pitch MDCT as the "guessed" signal */
508    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
509
510    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
511    /* Compute inverse MDCTs */
512    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
513
514    for (c=0;c<C;c++)
515    {
516       for (i=0;i<B;i++)
517       {
518          int j;
519          for (j=0;j<N;j++)
520          {
521             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
522             st->preemph_memD[c] = tmp;
523             if (tmp > 32767) tmp = 32767;
524             if (tmp < -32767) tmp = -32767;
525             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
526          }
527       }
528    }
529 }
530
531 int celt_decode(CELTDecoder *st, unsigned char *data, int len, celt_int16_t *pcm)
532 {
533    int i, c, N, B, C;
534    int has_pitch;
535    N = st->block_size;
536    B = st->nb_blocks;
537    C = st->mode->nbChannels;
538    
539    float X[C*B*N];         /**< Interleaved signal MDCTs */
540    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
541    float bandE[st->mode->nbEBands*C];
542    float gains[st->mode->nbPBands];
543    int pitch_index;
544    ec_dec dec;
545    ec_byte_buffer buf;
546    
547    if (data == NULL)
548    {
549       celt_decode_lost(st, pcm);
550       return 0;
551    }
552    
553    ec_byte_readinit(&buf,data,len);
554    ec_dec_init(&dec,&buf);
555    
556    /* Get band energies */
557    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
558    
559    /* Get the pitch gains */
560    has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
561    
562    /* Get the pitch index */
563    if (has_pitch)
564    {
565       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
566       st->last_pitch_index = pitch_index;
567    } else {
568       /* FIXME: We could be more intelligent here and just not compute the MDCT */
569       pitch_index = 0;
570    }
571    
572    /* Pitch MDCT */
573    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
574
575    {
576       float bandEp[st->mode->nbEBands*C];
577       compute_band_energies(st->mode, P, bandEp);
578       normalise_bands(st->mode, P, bandEp);
579    }
580
581    if (C==2)
582       stereo_mix(st->mode, P, bandE, 1);
583
584    /* Apply pitch gains */
585    pitch_quant_bands(st->mode, X, P, gains);
586
587    /* Decode fixed codebook and merge with pitch */
588    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
589
590    if (C==2)
591       stereo_mix(st->mode, X, bandE, -1);
592
593    renormalise_bands(st->mode, X);
594    
595    /* Synthesis */
596    denormalise_bands(st->mode, X, bandE);
597
598
599    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
600    /* Compute inverse MDCTs */
601    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
602
603    for (c=0;c<C;c++)
604    {
605       for (i=0;i<B;i++)
606       {
607          int j;
608          for (j=0;j<N;j++)
609          {
610             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
611             st->preemph_memD[c] = tmp;
612             if (tmp > 32767) tmp = 32767;
613             if (tmp < -32767) tmp = -32767;
614             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
615          }
616       }
617    }
618
619    {
620       int val = 0;
621       while (ec_dec_tell(&dec, 0) < len*8)
622       {
623          if (ec_dec_uint(&dec, 2) != val)
624          {
625             celt_warning("decode error");
626             return CELT_CORRUPTED_DATA;
627          }
628          val = 1-val;
629       }
630    }
631
632    return 0;
633    //printf ("\n");
634 }
635