Re-enabled intra-frame prediction, which seems to have exposed a few issues
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    void *fft;
67    
68    float *window;
69    float *in_mem;
70    float *mdct_overlap;
71    float *out_mem;
72
73    float *oldBandE;
74    
75    struct alloc_data alloc;
76 };
77
78
79
80 CELTEncoder *celt_encoder_new(const CELTMode *mode)
81 {
82    int i, N, B, C, N4;
83    N = mode->mdctSize;
84    B = mode->nbMdctBlocks;
85    C = mode->nbChannels;
86    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
87    
88    st->mode = mode;
89    st->frame_size = B*N;
90    st->block_size = N;
91    st->nb_blocks  = B;
92    st->overlap = mode->overlap;
93    st->Fs = 44100;
94
95    N4 = (N-st->overlap)/2;
96    ec_byte_writeinit(&st->buf);
97    ec_enc_init(&st->enc,&st->buf);
98
99    mdct_init(&st->mdct_lookup, 2*N);
100    st->fft = spx_fft_init(MAX_PERIOD*C);
101    
102    st->window = celt_alloc(2*N*sizeof(float));
103    st->in_mem = celt_alloc(N*C*sizeof(float));
104    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
105    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
106    for (i=0;i<2*N;i++)
107       st->window[i] = 0;
108    for (i=0;i<st->overlap;i++)
109       st->window[N4+i] = st->window[2*N-N4-i-1] 
110             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
111    for (i=0;i<2*N4;i++)
112       st->window[N-N4+i] = 1;
113    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
114
115    st->preemph = 0.8;
116    st->preemph_memE = celt_alloc(C*sizeof(float));;
117    st->preemph_memD = celt_alloc(C*sizeof(float));;
118
119    alloc_init(&st->alloc, st->mode);
120    return st;
121 }
122
123 void celt_encoder_destroy(CELTEncoder *st)
124 {
125    if (st == NULL)
126    {
127       celt_warning("NULL passed to celt_encoder_destroy");
128       return;
129    }
130    ec_byte_writeclear(&st->buf);
131
132    mdct_clear(&st->mdct_lookup);
133    spx_fft_destroy(st->fft);
134
135    celt_free(st->window);
136    celt_free(st->in_mem);
137    celt_free(st->mdct_overlap);
138    celt_free(st->out_mem);
139    
140    celt_free(st->oldBandE);
141    alloc_clear(&st->alloc);
142
143    celt_free(st);
144 }
145
146
147 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
148 {
149    int i, c;
150    for (c=0;c<C;c++)
151    {
152       for (i=0;i<B;i++)
153       {
154          int j;
155          float x[2*N];
156          float tmp[N];
157          for (j=0;j<2*N;j++)
158             x[j] = window[j]*in[C*i*N+C*j+c];
159          mdct_forward(mdct_lookup, x, tmp);
160          /* Interleaving the sub-frames */
161          for (j=0;j<N;j++)
162             out[C*B*j+C*i+c] = tmp[j];
163       }
164    }
165 }
166
167 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
168 {
169    int i, c, N4;
170    N4 = (N-overlap)/2;
171    for (c=0;c<C;c++)
172    {
173       for (i=0;i<B;i++)
174       {
175          int j;
176          float x[2*N];
177          float tmp[N];
178          /* De-interleaving the sub-frames */
179          for (j=0;j<N;j++)
180             tmp[j] = X[C*B*j+C*i+c];
181          mdct_backward(mdct_lookup, tmp, x);
182          for (j=0;j<2*N;j++)
183             x[j] = window[j]*x[j];
184          for (j=0;j<overlap;j++)
185             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
186          for (j=0;j<2*N4;j++)
187             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
188          for (j=0;j<overlap;j++)
189             mdct_overlap[C*j+c] = x[N+N4+j];
190       }
191    }
192 }
193
194 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
195 {
196    int i, c, N, B, C, N4;
197    N = st->block_size;
198    B = st->nb_blocks;
199    C = st->mode->nbChannels;
200    float in[(B+1)*C*N];
201
202    float X[B*C*N];         /**< Interleaved signal MDCTs */
203    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
204    float mask[B*C*N];      /**< Masking curve */
205    float bandE[st->mode->nbEBands*C];
206    float gains[st->mode->nbPBands];
207    int pitch_index;
208
209    N4 = (N-st->overlap)/2;
210
211    for (c=0;c<C;c++)
212    {
213       for (i=0;i<N4;i++)
214          in[C*i+c] = 0;
215       for (i=0;i<st->overlap;i++)
216          in[C*(i+N4)+c] = st->in_mem[C*i+c];
217       for (i=0;i<B*N;i++)
218       {
219          float tmp = pcm[C*i+c];
220          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
221          st->preemph_memE[c] = tmp;
222       }
223       for (i=N*(B+1)-N4;i<N*(B+1);i++)
224          in[C*i+c] = 0;
225       for (i=0;i<st->overlap;i++)
226          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
227    }
228    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
229    /* Compute MDCTs */
230    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
231
232    compute_mdct_masking(X, mask, B*C*N, st->Fs);
233
234    /* Invert and stretch the mask to length of X 
235       For some reason, I get better results by using the sqrt instead,
236       although there's no valid reason to. Must investigate further */
237    for (i=0;i<B*C*N;i++)
238       mask[i] = 1/(.1+mask[i]);
239
240    /* Pitch analysis */
241    for (c=0;c<C;c++)
242    {
243       for (i=0;i<N;i++)
244       {
245          in[C*i+c] *= st->window[i];
246          in[C*(B*N+i)+c] *= st->window[N+i];
247       }
248    }
249    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
250    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
251    
252    /* Compute MDCTs of the pitch part */
253    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
254    
255    /*int j;
256    for (j=0;j<B*N;j++)
257       printf ("%f ", X[j]);
258    for (j=0;j<B*N;j++)
259       printf ("%f ", P[j]);
260    printf ("\n");*/
261
262    /* Band normalisation */
263    compute_band_energies(st->mode, X, bandE);
264    normalise_bands(st->mode, X, bandE);
265    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
266    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
267
268    /* Normalise the pitch vector as well (discard the energies) */
269    {
270       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
271       compute_band_energies(st->mode, P, bandEp);
272       normalise_bands(st->mode, P, bandEp);
273    }
274
275    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
276
277    if (C==2)
278    {
279       stereo_mix(st->mode, X, bandE, 1);
280       stereo_mix(st->mode, P, bandE, 1);
281    }
282    /* Simulates intensity stereo */
283    //for (i=30;i<N*B;i++)
284    //   X[i*C+1] = P[i*C+1] = 0;
285    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
286
287    /* Pitch prediction */
288    compute_pitch_gain(st->mode, X, P, gains, bandE);
289    quant_pitch(gains, st->mode->nbPBands, &st->enc);
290    pitch_quant_bands(st->mode, X, P, gains);
291
292    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
293    /* Compute residual that we're going to encode */
294    for (i=0;i<B*C*N;i++)
295       X[i] -= P[i];
296
297    /*float sum=0;
298    for (i=0;i<B*N;i++)
299       sum += X[i]*X[i];
300    printf ("%f\n", sum);*/
301    /* Residual quantisation */
302    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
303    
304    if (C==2)
305       stereo_mix(st->mode, X, bandE, -1);
306
307    renormalise_bands(st->mode, X);
308    /* Synthesis */
309    denormalise_bands(st->mode, X, bandE);
310
311
312    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
313
314    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
315    /* De-emphasis and put everything back at the right place in the synthesis history */
316    for (c=0;c<C;c++)
317    {
318       for (i=0;i<B;i++)
319       {
320          int j;
321          for (j=0;j<N;j++)
322          {
323             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
324             st->preemph_memD[c] = tmp;
325             if (tmp > 32767) tmp = 32767;
326             if (tmp < -32767) tmp = -32767;
327             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
328          }
329       }
330    }
331    
332    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
333    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
334    {
335       int val = 0;
336       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
337       {
338          ec_enc_uint(&st->enc, val, 2);
339          val = 1-val;
340       }
341    }
342    ec_enc_done(&st->enc);
343    {
344       unsigned char *data;
345       int nbBytes = ec_byte_bytes(&st->buf);
346       if (nbBytes != nbCompressedBytes)
347       {
348          if (nbBytes > nbCompressedBytes)
349             celt_warning("got too many bytes");
350          else
351             celt_warning("not enough bytes");
352          celt_warning_int ("output bytes:", nbBytes);
353          return CELT_INTERNAL_ERROR;
354       }
355       //printf ("%d\n", *nbBytes);
356       data = ec_byte_get_buffer(&st->buf);
357       for (i=0;i<nbBytes;i++)
358          compressed[i] = data[i];
359    }
360    /* Reset the packing for the next encoding */
361    ec_byte_reset(&st->buf);
362    ec_enc_init(&st->enc,&st->buf);
363
364    return nbCompressedBytes;
365 }
366
367
368 /****************************************************************************/
369 /*                                                                          */
370 /*                                DECODER                                   */
371 /*                                                                          */
372 /****************************************************************************/
373
374
375
376 struct CELTDecoder {
377    const CELTMode *mode;
378    int frame_size;
379    int block_size;
380    int nb_blocks;
381    int overlap;
382
383    ec_byte_buffer buf;
384    ec_enc         enc;
385
386    float preemph;
387    float *preemph_memD;
388    
389    mdct_lookup mdct_lookup;
390    
391    float *window;
392    float *mdct_overlap;
393    float *out_mem;
394
395    float *oldBandE;
396    
397    int last_pitch_index;
398    
399    struct alloc_data alloc;
400 };
401
402 CELTDecoder *celt_decoder_new(const CELTMode *mode)
403 {
404    int i, N, B, C, N4;
405    N = mode->mdctSize;
406    B = mode->nbMdctBlocks;
407    C = mode->nbChannels;
408    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
409    
410    st->mode = mode;
411    st->frame_size = B*N;
412    st->block_size = N;
413    st->nb_blocks  = B;
414    st->overlap = mode->overlap;
415
416    N4 = (N-st->overlap)/2;
417    
418    mdct_init(&st->mdct_lookup, 2*N);
419    
420    st->window = celt_alloc(2*N*sizeof(float));
421    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
422    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
423
424    for (i=0;i<2*N;i++)
425       st->window[i] = 0;
426    for (i=0;i<st->overlap;i++)
427       st->window[N4+i] = st->window[2*N-N4-i-1] 
428             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
429    for (i=0;i<2*N4;i++)
430       st->window[N-N4+i] = 1;
431    
432    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
433
434    st->preemph = 0.8;
435    st->preemph_memD = celt_alloc(C*sizeof(float));;
436
437    st->last_pitch_index = 0;
438    alloc_init(&st->alloc, st->mode);
439
440    return st;
441 }
442
443 void celt_decoder_destroy(CELTDecoder *st)
444 {
445    if (st == NULL)
446    {
447       celt_warning("NULL passed to celt_encoder_destroy");
448       return;
449    }
450
451    mdct_clear(&st->mdct_lookup);
452
453    celt_free(st->window);
454    celt_free(st->mdct_overlap);
455    celt_free(st->out_mem);
456    
457    celt_free(st->oldBandE);
458    alloc_clear(&st->alloc);
459
460    celt_free(st);
461 }
462
463 static void celt_decode_lost(CELTDecoder *st, short *pcm)
464 {
465    int i, c, N, B, C;
466    N = st->block_size;
467    B = st->nb_blocks;
468    C = st->mode->nbChannels;
469    float X[C*B*N];         /**< Interleaved signal MDCTs */
470    int pitch_index;
471    
472    pitch_index = st->last_pitch_index;
473    
474    /* Use the pitch MDCT as the "guessed" signal */
475    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
476
477    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
478    /* Compute inverse MDCTs */
479    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
480
481    for (c=0;c<C;c++)
482    {
483       for (i=0;i<B;i++)
484       {
485          int j;
486          for (j=0;j<N;j++)
487          {
488             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
489             st->preemph_memD[c] = tmp;
490             if (tmp > 32767) tmp = 32767;
491             if (tmp < -32767) tmp = -32767;
492             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
493          }
494       }
495    }
496 }
497
498 int celt_decode(CELTDecoder *st, char *data, int len, celt_int16_t *pcm)
499 {
500    int i, c, N, B, C;
501    N = st->block_size;
502    B = st->nb_blocks;
503    C = st->mode->nbChannels;
504    
505    float X[C*B*N];         /**< Interleaved signal MDCTs */
506    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
507    float bandE[st->mode->nbEBands*C];
508    float gains[st->mode->nbPBands];
509    int pitch_index;
510    ec_dec dec;
511    ec_byte_buffer buf;
512    
513    if (data == NULL)
514    {
515       celt_decode_lost(st, pcm);
516       return 0;
517    }
518    
519    ec_byte_readinit(&buf,data,len);
520    ec_dec_init(&dec,&buf);
521    
522    /* Get the pitch index */
523    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
524    st->last_pitch_index = pitch_index;
525    
526    /* Get band energies */
527    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
528    
529    /* Pitch MDCT */
530    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
531
532    {
533       float bandEp[st->mode->nbEBands];
534       compute_band_energies(st->mode, P, bandEp);
535       normalise_bands(st->mode, P, bandEp);
536    }
537
538    if (C==2)
539       stereo_mix(st->mode, P, bandE, 1);
540
541    /* Get the pitch gains */
542    unquant_pitch(gains, st->mode->nbPBands, &dec);
543
544    /* Apply pitch gains */
545    pitch_quant_bands(st->mode, X, P, gains);
546
547    /* Decode fixed codebook and merge with pitch */
548    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
549
550    if (C==2)
551       stereo_mix(st->mode, X, bandE, -1);
552
553    renormalise_bands(st->mode, X);
554    
555    /* Synthesis */
556    denormalise_bands(st->mode, X, bandE);
557
558
559    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
560    /* Compute inverse MDCTs */
561    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
562
563    for (c=0;c<C;c++)
564    {
565       for (i=0;i<B;i++)
566       {
567          int j;
568          for (j=0;j<N;j++)
569          {
570             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
571             st->preemph_memD[c] = tmp;
572             if (tmp > 32767) tmp = 32767;
573             if (tmp < -32767) tmp = -32767;
574             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
575          }
576       }
577    }
578
579    {
580       int val = 0;
581       while (ec_dec_tell(&dec, 0) < len*8)
582       {
583          if (ec_dec_uint(&dec, 2) != val)
584          {
585             celt_warning("decode error");
586             return CELT_CORRUPTED_DATA;
587          }
588          val = 1-val;
589       }
590    }
591
592    return 0;
593    //printf ("\n");
594 }
595