Limiting intra-frame prediction codebook to 32 entries (plus sign)
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "kiss_fftr.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    kiss_fftr_cfg fft;
67    struct PsyDecay psy;
68    
69    float *window;
70    float *in_mem;
71    float *mdct_overlap;
72    float *out_mem;
73
74    float *oldBandE;
75    
76    struct alloc_data alloc;
77 };
78
79
80
81 CELTEncoder *celt_encoder_new(const CELTMode *mode)
82 {
83    int i, N, B, C, N4;
84    N = mode->mdctSize;
85    B = mode->nbMdctBlocks;
86    C = mode->nbChannels;
87    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
88    
89    st->mode = mode;
90    st->frame_size = B*N;
91    st->block_size = N;
92    st->nb_blocks  = B;
93    st->overlap = mode->overlap;
94    st->Fs = 44100;
95
96    N4 = (N-st->overlap)/2;
97    ec_byte_writeinit(&st->buf);
98    ec_enc_init(&st->enc,&st->buf);
99
100    mdct_init(&st->mdct_lookup, 2*N);
101    st->fft = kiss_fftr_alloc(MAX_PERIOD*C, 0, 0);
102    psydecay_init(&st->psy, MAX_PERIOD/2, st->Fs);
103    
104    st->window = celt_alloc(2*N*sizeof(float));
105    st->in_mem = celt_alloc(N*C*sizeof(float));
106    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
107    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
108    for (i=0;i<2*N;i++)
109       st->window[i] = 0;
110    for (i=0;i<st->overlap;i++)
111       st->window[N4+i] = st->window[2*N-N4-i-1] 
112             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
113    for (i=0;i<2*N4;i++)
114       st->window[N-N4+i] = 1;
115    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
116
117    st->preemph = 0.8;
118    st->preemph_memE = celt_alloc(C*sizeof(float));;
119    st->preemph_memD = celt_alloc(C*sizeof(float));;
120
121    alloc_init(&st->alloc, st->mode);
122    return st;
123 }
124
125 void celt_encoder_destroy(CELTEncoder *st)
126 {
127    if (st == NULL)
128    {
129       celt_warning("NULL passed to celt_encoder_destroy");
130       return;
131    }
132    ec_byte_writeclear(&st->buf);
133
134    mdct_clear(&st->mdct_lookup);
135    kiss_fft_free(st->fft);
136    psydecay_clear(&st->psy);
137
138    celt_free(st->window);
139    celt_free(st->in_mem);
140    celt_free(st->mdct_overlap);
141    celt_free(st->out_mem);
142    
143    celt_free(st->oldBandE);
144    
145    celt_free(st->preemph_memE);
146    celt_free(st->preemph_memD);
147    
148    alloc_clear(&st->alloc);
149
150    celt_free(st);
151 }
152
153
154 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
155 {
156    int i, c;
157    for (c=0;c<C;c++)
158    {
159       for (i=0;i<B;i++)
160       {
161          int j;
162          float x[2*N];
163          float tmp[N];
164          for (j=0;j<2*N;j++)
165             x[j] = window[j]*in[C*i*N+C*j+c];
166          mdct_forward(mdct_lookup, x, tmp);
167          /* Interleaving the sub-frames */
168          for (j=0;j<N;j++)
169             out[C*B*j+C*i+c] = tmp[j];
170       }
171    }
172 }
173
174 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
175 {
176    int i, c, N4;
177    N4 = (N-overlap)/2;
178    for (c=0;c<C;c++)
179    {
180       for (i=0;i<B;i++)
181       {
182          int j;
183          float x[2*N];
184          float tmp[N];
185          /* De-interleaving the sub-frames */
186          for (j=0;j<N;j++)
187             tmp[j] = X[C*B*j+C*i+c];
188          mdct_backward(mdct_lookup, tmp, x);
189          for (j=0;j<2*N;j++)
190             x[j] = window[j]*x[j];
191          for (j=0;j<overlap;j++)
192             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
193          for (j=0;j<2*N4;j++)
194             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
195          for (j=0;j<overlap;j++)
196             mdct_overlap[C*j+c] = x[N+N4+j];
197       }
198    }
199 }
200
201 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
202 {
203    int i, c, N, B, C, N4;
204    N = st->block_size;
205    B = st->nb_blocks;
206    C = st->mode->nbChannels;
207    float in[(B+1)*C*N];
208
209    float X[B*C*N];         /**< Interleaved signal MDCTs */
210    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
211    float mask[B*C*N];      /**< Masking curve */
212    float bandE[st->mode->nbEBands*C];
213    float gains[st->mode->nbPBands];
214    int pitch_index;
215
216    N4 = (N-st->overlap)/2;
217
218    for (c=0;c<C;c++)
219    {
220       for (i=0;i<N4;i++)
221          in[C*i+c] = 0;
222       for (i=0;i<st->overlap;i++)
223          in[C*(i+N4)+c] = st->in_mem[C*i+c];
224       for (i=0;i<B*N;i++)
225       {
226          float tmp = pcm[C*i+c];
227          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
228          st->preemph_memE[c] = tmp;
229       }
230       for (i=N*(B+1)-N4;i<N*(B+1);i++)
231          in[C*i+c] = 0;
232       for (i=0;i<st->overlap;i++)
233          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
234    }
235    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
236    /* Compute MDCTs */
237    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
238
239 #if 0 /* Mask disabled until it can be made to do something useful */
240    compute_mdct_masking(X, mask, B*C*N, st->Fs);
241
242    /* Invert and stretch the mask to length of X 
243       For some reason, I get better results by using the sqrt instead,
244       although there's no valid reason to. Must investigate further */
245    for (i=0;i<B*C*N;i++)
246       mask[i] = 1/(.1+mask[i]);
247 #else
248    for (i=0;i<B*C*N;i++)
249       mask[i] = 1;
250 #endif
251    /* Pitch analysis */
252    for (c=0;c<C;c++)
253    {
254       for (i=0;i<N;i++)
255       {
256          in[C*i+c] *= st->window[i];
257          in[C*(B*N+i)+c] *= st->window[N+i];
258       }
259    }
260    find_spectral_pitch(st->fft, &st->psy, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
261    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
262    
263    /* Compute MDCTs of the pitch part */
264    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
265    
266    /*int j;
267    for (j=0;j<B*N;j++)
268       printf ("%f ", X[j]);
269    for (j=0;j<B*N;j++)
270       printf ("%f ", P[j]);
271    printf ("\n");*/
272
273    /* Band normalisation */
274    compute_band_energies(st->mode, X, bandE);
275    normalise_bands(st->mode, X, bandE);
276    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
277    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
278
279    /* Normalise the pitch vector as well (discard the energies) */
280    {
281       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
282       compute_band_energies(st->mode, P, bandEp);
283       normalise_bands(st->mode, P, bandEp);
284    }
285
286    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
287
288    if (C==2)
289    {
290       stereo_mix(st->mode, X, bandE, 1);
291       stereo_mix(st->mode, P, bandE, 1);
292    }
293    /* Simulates intensity stereo */
294    //for (i=30;i<N*B;i++)
295    //   X[i*C+1] = P[i*C+1] = 0;
296    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
297
298    /* Pitch prediction */
299    compute_pitch_gain(st->mode, X, P, gains, bandE);
300    quant_pitch(gains, st->mode->nbPBands, &st->enc);
301    pitch_quant_bands(st->mode, X, P, gains);
302
303    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
304    /* Compute residual that we're going to encode */
305    for (i=0;i<B*C*N;i++)
306       X[i] -= P[i];
307
308    /*float sum=0;
309    for (i=0;i<B*N;i++)
310       sum += X[i]*X[i];
311    printf ("%f\n", sum);*/
312    /* Residual quantisation */
313    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
314    
315    if (C==2)
316       stereo_mix(st->mode, X, bandE, -1);
317
318    renormalise_bands(st->mode, X);
319    /* Synthesis */
320    denormalise_bands(st->mode, X, bandE);
321
322
323    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
324
325    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
326    /* De-emphasis and put everything back at the right place in the synthesis history */
327    for (c=0;c<C;c++)
328    {
329       for (i=0;i<B;i++)
330       {
331          int j;
332          for (j=0;j<N;j++)
333          {
334             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
335             st->preemph_memD[c] = tmp;
336             if (tmp > 32767) tmp = 32767;
337             if (tmp < -32767) tmp = -32767;
338             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
339          }
340       }
341    }
342    
343    if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 8)
344       celt_warning_int ("too make unused bits", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));
345    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
346    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
347    {
348       int val = 0;
349       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
350       {
351          ec_enc_uint(&st->enc, val, 2);
352          val = 1-val;
353       }
354    }
355    ec_enc_done(&st->enc);
356    {
357       unsigned char *data;
358       int nbBytes = ec_byte_bytes(&st->buf);
359       if (nbBytes > nbCompressedBytes)
360       {
361          celt_warning_int ("got too many bytes:", nbBytes);
362          return CELT_INTERNAL_ERROR;
363       }
364       //printf ("%d\n", *nbBytes);
365       data = ec_byte_get_buffer(&st->buf);
366       for (i=0;i<nbBytes;i++)
367          compressed[i] = data[i];
368       for (;i<nbCompressedBytes;i++)
369          compressed[i] = 0;
370    }
371    /* Reset the packing for the next encoding */
372    ec_byte_reset(&st->buf);
373    ec_enc_init(&st->enc,&st->buf);
374
375    return nbCompressedBytes;
376 }
377
378
379 /****************************************************************************/
380 /*                                                                          */
381 /*                                DECODER                                   */
382 /*                                                                          */
383 /****************************************************************************/
384
385
386
387 struct CELTDecoder {
388    const CELTMode *mode;
389    int frame_size;
390    int block_size;
391    int nb_blocks;
392    int overlap;
393
394    ec_byte_buffer buf;
395    ec_enc         enc;
396
397    float preemph;
398    float *preemph_memD;
399    
400    mdct_lookup mdct_lookup;
401    
402    float *window;
403    float *mdct_overlap;
404    float *out_mem;
405
406    float *oldBandE;
407    
408    int last_pitch_index;
409    
410    struct alloc_data alloc;
411 };
412
413 CELTDecoder *celt_decoder_new(const CELTMode *mode)
414 {
415    int i, N, B, C, N4;
416    N = mode->mdctSize;
417    B = mode->nbMdctBlocks;
418    C = mode->nbChannels;
419    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
420    
421    st->mode = mode;
422    st->frame_size = B*N;
423    st->block_size = N;
424    st->nb_blocks  = B;
425    st->overlap = mode->overlap;
426
427    N4 = (N-st->overlap)/2;
428    
429    mdct_init(&st->mdct_lookup, 2*N);
430    
431    st->window = celt_alloc(2*N*sizeof(float));
432    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
433    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
434
435    for (i=0;i<2*N;i++)
436       st->window[i] = 0;
437    for (i=0;i<st->overlap;i++)
438       st->window[N4+i] = st->window[2*N-N4-i-1] 
439             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
440    for (i=0;i<2*N4;i++)
441       st->window[N-N4+i] = 1;
442    
443    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
444
445    st->preemph = 0.8;
446    st->preemph_memD = celt_alloc(C*sizeof(float));;
447
448    st->last_pitch_index = 0;
449    alloc_init(&st->alloc, st->mode);
450
451    return st;
452 }
453
454 void celt_decoder_destroy(CELTDecoder *st)
455 {
456    if (st == NULL)
457    {
458       celt_warning("NULL passed to celt_encoder_destroy");
459       return;
460    }
461
462    mdct_clear(&st->mdct_lookup);
463
464    celt_free(st->window);
465    celt_free(st->mdct_overlap);
466    celt_free(st->out_mem);
467    
468    celt_free(st->oldBandE);
469    
470    celt_free(st->preemph_memD);
471
472    alloc_clear(&st->alloc);
473
474    celt_free(st);
475 }
476
477 static void celt_decode_lost(CELTDecoder *st, short *pcm)
478 {
479    int i, c, N, B, C;
480    N = st->block_size;
481    B = st->nb_blocks;
482    C = st->mode->nbChannels;
483    float X[C*B*N];         /**< Interleaved signal MDCTs */
484    int pitch_index;
485    
486    pitch_index = st->last_pitch_index;
487    
488    /* Use the pitch MDCT as the "guessed" signal */
489    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
490
491    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
492    /* Compute inverse MDCTs */
493    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
494
495    for (c=0;c<C;c++)
496    {
497       for (i=0;i<B;i++)
498       {
499          int j;
500          for (j=0;j<N;j++)
501          {
502             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
503             st->preemph_memD[c] = tmp;
504             if (tmp > 32767) tmp = 32767;
505             if (tmp < -32767) tmp = -32767;
506             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
507          }
508       }
509    }
510 }
511
512 int celt_decode(CELTDecoder *st, char *data, int len, celt_int16_t *pcm)
513 {
514    int i, c, N, B, C;
515    N = st->block_size;
516    B = st->nb_blocks;
517    C = st->mode->nbChannels;
518    
519    float X[C*B*N];         /**< Interleaved signal MDCTs */
520    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
521    float bandE[st->mode->nbEBands*C];
522    float gains[st->mode->nbPBands];
523    int pitch_index;
524    ec_dec dec;
525    ec_byte_buffer buf;
526    
527    if (data == NULL)
528    {
529       celt_decode_lost(st, pcm);
530       return 0;
531    }
532    
533    ec_byte_readinit(&buf,data,len);
534    ec_dec_init(&dec,&buf);
535    
536    /* Get the pitch index */
537    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
538    st->last_pitch_index = pitch_index;
539    
540    /* Get band energies */
541    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
542    
543    /* Pitch MDCT */
544    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
545
546    {
547       float bandEp[st->mode->nbEBands];
548       compute_band_energies(st->mode, P, bandEp);
549       normalise_bands(st->mode, P, bandEp);
550    }
551
552    if (C==2)
553       stereo_mix(st->mode, P, bandE, 1);
554
555    /* Get the pitch gains */
556    unquant_pitch(gains, st->mode->nbPBands, &dec);
557
558    /* Apply pitch gains */
559    pitch_quant_bands(st->mode, X, P, gains);
560
561    /* Decode fixed codebook and merge with pitch */
562    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
563
564    if (C==2)
565       stereo_mix(st->mode, X, bandE, -1);
566
567    renormalise_bands(st->mode, X);
568    
569    /* Synthesis */
570    denormalise_bands(st->mode, X, bandE);
571
572
573    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
574    /* Compute inverse MDCTs */
575    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
576
577    for (c=0;c<C;c++)
578    {
579       for (i=0;i<B;i++)
580       {
581          int j;
582          for (j=0;j<N;j++)
583          {
584             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
585             st->preemph_memD[c] = tmp;
586             if (tmp > 32767) tmp = 32767;
587             if (tmp < -32767) tmp = -32767;
588             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
589          }
590       }
591    }
592
593    {
594       int val = 0;
595       while (ec_dec_tell(&dec, 0) < len*8)
596       {
597          if (ec_dec_uint(&dec, 2) != val)
598          {
599             celt_warning("decode error");
600             return CELT_CORRUPTED_DATA;
601          }
602          val = 1-val;
603       }
604    }
605
606    return 0;
607    //printf ("\n");
608 }
609