Fixed incorrect assumption about the number of bytes returned by the
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    void *fft;
67    
68    float *window;
69    float *in_mem;
70    float *mdct_overlap;
71    float *out_mem;
72
73    float *oldBandE;
74    
75    struct alloc_data alloc;
76 };
77
78
79
80 CELTEncoder *celt_encoder_new(const CELTMode *mode)
81 {
82    int i, N, B, C, N4;
83    N = mode->mdctSize;
84    B = mode->nbMdctBlocks;
85    C = mode->nbChannels;
86    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
87    
88    st->mode = mode;
89    st->frame_size = B*N;
90    st->block_size = N;
91    st->nb_blocks  = B;
92    st->overlap = mode->overlap;
93    st->Fs = 44100;
94
95    N4 = (N-st->overlap)/2;
96    ec_byte_writeinit(&st->buf);
97    ec_enc_init(&st->enc,&st->buf);
98
99    mdct_init(&st->mdct_lookup, 2*N);
100    st->fft = spx_fft_init(MAX_PERIOD*C);
101    
102    st->window = celt_alloc(2*N*sizeof(float));
103    st->in_mem = celt_alloc(N*C*sizeof(float));
104    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
105    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
106    for (i=0;i<2*N;i++)
107       st->window[i] = 0;
108    for (i=0;i<st->overlap;i++)
109       st->window[N4+i] = st->window[2*N-N4-i-1] 
110             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
111    for (i=0;i<2*N4;i++)
112       st->window[N-N4+i] = 1;
113    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
114
115    st->preemph = 0.8;
116    st->preemph_memE = celt_alloc(C*sizeof(float));;
117    st->preemph_memD = celt_alloc(C*sizeof(float));;
118
119    alloc_init(&st->alloc, st->mode);
120    return st;
121 }
122
123 void celt_encoder_destroy(CELTEncoder *st)
124 {
125    if (st == NULL)
126    {
127       celt_warning("NULL passed to celt_encoder_destroy");
128       return;
129    }
130    ec_byte_writeclear(&st->buf);
131
132    mdct_clear(&st->mdct_lookup);
133    spx_fft_destroy(st->fft);
134
135    celt_free(st->window);
136    celt_free(st->in_mem);
137    celt_free(st->mdct_overlap);
138    celt_free(st->out_mem);
139    
140    celt_free(st->oldBandE);
141    alloc_clear(&st->alloc);
142
143    celt_free(st);
144 }
145
146
147 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
148 {
149    int i, c;
150    for (c=0;c<C;c++)
151    {
152       for (i=0;i<B;i++)
153       {
154          int j;
155          float x[2*N];
156          float tmp[N];
157          for (j=0;j<2*N;j++)
158             x[j] = window[j]*in[C*i*N+C*j+c];
159          mdct_forward(mdct_lookup, x, tmp);
160          /* Interleaving the sub-frames */
161          for (j=0;j<N;j++)
162             out[C*B*j+C*i+c] = tmp[j];
163       }
164    }
165 }
166
167 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
168 {
169    int i, c, N4;
170    N4 = (N-overlap)/2;
171    for (c=0;c<C;c++)
172    {
173       for (i=0;i<B;i++)
174       {
175          int j;
176          float x[2*N];
177          float tmp[N];
178          /* De-interleaving the sub-frames */
179          for (j=0;j<N;j++)
180             tmp[j] = X[C*B*j+C*i+c];
181          mdct_backward(mdct_lookup, tmp, x);
182          for (j=0;j<2*N;j++)
183             x[j] = window[j]*x[j];
184          for (j=0;j<overlap;j++)
185             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
186          for (j=0;j<2*N4;j++)
187             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
188          for (j=0;j<overlap;j++)
189             mdct_overlap[C*j+c] = x[N+N4+j];
190       }
191    }
192 }
193
194 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
195 {
196    int i, c, N, B, C, N4;
197    N = st->block_size;
198    B = st->nb_blocks;
199    C = st->mode->nbChannels;
200    float in[(B+1)*C*N];
201
202    float X[B*C*N];         /**< Interleaved signal MDCTs */
203    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
204    float mask[B*C*N];      /**< Masking curve */
205    float bandE[st->mode->nbEBands*C];
206    float gains[st->mode->nbPBands];
207    int pitch_index;
208
209    N4 = (N-st->overlap)/2;
210
211    for (c=0;c<C;c++)
212    {
213       for (i=0;i<N4;i++)
214          in[C*i+c] = 0;
215       for (i=0;i<st->overlap;i++)
216          in[C*(i+N4)+c] = st->in_mem[C*i+c];
217       for (i=0;i<B*N;i++)
218       {
219          float tmp = pcm[C*i+c];
220          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
221          st->preemph_memE[c] = tmp;
222       }
223       for (i=N*(B+1)-N4;i<N*(B+1);i++)
224          in[C*i+c] = 0;
225       for (i=0;i<st->overlap;i++)
226          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
227    }
228    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
229    /* Compute MDCTs */
230    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
231
232    compute_mdct_masking(X, mask, B*C*N, st->Fs);
233
234    /* Invert and stretch the mask to length of X 
235       For some reason, I get better results by using the sqrt instead,
236       although there's no valid reason to. Must investigate further */
237    for (i=0;i<B*C*N;i++)
238       mask[i] = 1/(.1+mask[i]);
239
240    /* Pitch analysis */
241    for (c=0;c<C;c++)
242    {
243       for (i=0;i<N;i++)
244       {
245          in[C*i+c] *= st->window[i];
246          in[C*(B*N+i)+c] *= st->window[N+i];
247       }
248    }
249    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
250    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
251    
252    /* Compute MDCTs of the pitch part */
253    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
254    
255    /*int j;
256    for (j=0;j<B*N;j++)
257       printf ("%f ", X[j]);
258    for (j=0;j<B*N;j++)
259       printf ("%f ", P[j]);
260    printf ("\n");*/
261
262    /* Band normalisation */
263    compute_band_energies(st->mode, X, bandE);
264    normalise_bands(st->mode, X, bandE);
265    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
266    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
267
268    /* Normalise the pitch vector as well (discard the energies) */
269    {
270       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
271       compute_band_energies(st->mode, P, bandEp);
272       normalise_bands(st->mode, P, bandEp);
273    }
274
275    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
276
277    if (C==2)
278    {
279       stereo_mix(st->mode, X, bandE, 1);
280       stereo_mix(st->mode, P, bandE, 1);
281    }
282    /* Simulates intensity stereo */
283    //for (i=30;i<N*B;i++)
284    //   X[i*C+1] = P[i*C+1] = 0;
285    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
286
287    /* Pitch prediction */
288    compute_pitch_gain(st->mode, X, P, gains, bandE);
289    quant_pitch(gains, st->mode->nbPBands, &st->enc);
290    pitch_quant_bands(st->mode, X, P, gains);
291
292    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
293    /* Compute residual that we're going to encode */
294    for (i=0;i<B*C*N;i++)
295       X[i] -= P[i];
296
297    /*float sum=0;
298    for (i=0;i<B*N;i++)
299       sum += X[i]*X[i];
300    printf ("%f\n", sum);*/
301    /* Residual quantisation */
302    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
303    
304    if (C==2)
305       stereo_mix(st->mode, X, bandE, -1);
306
307    renormalise_bands(st->mode, X);
308    /* Synthesis */
309    denormalise_bands(st->mode, X, bandE);
310
311
312    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
313
314    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
315    /* De-emphasis and put everything back at the right place in the synthesis history */
316    for (c=0;c<C;c++)
317    {
318       for (i=0;i<B;i++)
319       {
320          int j;
321          for (j=0;j<N;j++)
322          {
323             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
324             st->preemph_memD[c] = tmp;
325             if (tmp > 32767) tmp = 32767;
326             if (tmp < -32767) tmp = -32767;
327             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
328          }
329       }
330    }
331    
332    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
333    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
334    {
335       int val = 0;
336       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
337       {
338          ec_enc_uint(&st->enc, val, 2);
339          val = 1-val;
340       }
341    }
342    ec_enc_done(&st->enc);
343    {
344       unsigned char *data;
345       int nbBytes = ec_byte_bytes(&st->buf);
346       if (nbBytes > nbCompressedBytes)
347       {
348          celt_warning_int ("got too many bytes:", nbBytes);
349          return CELT_INTERNAL_ERROR;
350       }
351       //printf ("%d\n", *nbBytes);
352       data = ec_byte_get_buffer(&st->buf);
353       for (i=0;i<nbBytes;i++)
354          compressed[i] = data[i];
355       for (;i<nbCompressedBytes;i++)
356          compressed[i] = 0;
357    }
358    /* Reset the packing for the next encoding */
359    ec_byte_reset(&st->buf);
360    ec_enc_init(&st->enc,&st->buf);
361
362    return nbCompressedBytes;
363 }
364
365
366 /****************************************************************************/
367 /*                                                                          */
368 /*                                DECODER                                   */
369 /*                                                                          */
370 /****************************************************************************/
371
372
373
374 struct CELTDecoder {
375    const CELTMode *mode;
376    int frame_size;
377    int block_size;
378    int nb_blocks;
379    int overlap;
380
381    ec_byte_buffer buf;
382    ec_enc         enc;
383
384    float preemph;
385    float *preemph_memD;
386    
387    mdct_lookup mdct_lookup;
388    
389    float *window;
390    float *mdct_overlap;
391    float *out_mem;
392
393    float *oldBandE;
394    
395    int last_pitch_index;
396    
397    struct alloc_data alloc;
398 };
399
400 CELTDecoder *celt_decoder_new(const CELTMode *mode)
401 {
402    int i, N, B, C, N4;
403    N = mode->mdctSize;
404    B = mode->nbMdctBlocks;
405    C = mode->nbChannels;
406    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
407    
408    st->mode = mode;
409    st->frame_size = B*N;
410    st->block_size = N;
411    st->nb_blocks  = B;
412    st->overlap = mode->overlap;
413
414    N4 = (N-st->overlap)/2;
415    
416    mdct_init(&st->mdct_lookup, 2*N);
417    
418    st->window = celt_alloc(2*N*sizeof(float));
419    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
420    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
421
422    for (i=0;i<2*N;i++)
423       st->window[i] = 0;
424    for (i=0;i<st->overlap;i++)
425       st->window[N4+i] = st->window[2*N-N4-i-1] 
426             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
427    for (i=0;i<2*N4;i++)
428       st->window[N-N4+i] = 1;
429    
430    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
431
432    st->preemph = 0.8;
433    st->preemph_memD = celt_alloc(C*sizeof(float));;
434
435    st->last_pitch_index = 0;
436    alloc_init(&st->alloc, st->mode);
437
438    return st;
439 }
440
441 void celt_decoder_destroy(CELTDecoder *st)
442 {
443    if (st == NULL)
444    {
445       celt_warning("NULL passed to celt_encoder_destroy");
446       return;
447    }
448
449    mdct_clear(&st->mdct_lookup);
450
451    celt_free(st->window);
452    celt_free(st->mdct_overlap);
453    celt_free(st->out_mem);
454    
455    celt_free(st->oldBandE);
456    alloc_clear(&st->alloc);
457
458    celt_free(st);
459 }
460
461 static void celt_decode_lost(CELTDecoder *st, short *pcm)
462 {
463    int i, c, N, B, C;
464    N = st->block_size;
465    B = st->nb_blocks;
466    C = st->mode->nbChannels;
467    float X[C*B*N];         /**< Interleaved signal MDCTs */
468    int pitch_index;
469    
470    pitch_index = st->last_pitch_index;
471    
472    /* Use the pitch MDCT as the "guessed" signal */
473    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
474
475    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
476    /* Compute inverse MDCTs */
477    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
478
479    for (c=0;c<C;c++)
480    {
481       for (i=0;i<B;i++)
482       {
483          int j;
484          for (j=0;j<N;j++)
485          {
486             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
487             st->preemph_memD[c] = tmp;
488             if (tmp > 32767) tmp = 32767;
489             if (tmp < -32767) tmp = -32767;
490             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
491          }
492       }
493    }
494 }
495
496 int celt_decode(CELTDecoder *st, char *data, int len, celt_int16_t *pcm)
497 {
498    int i, c, N, B, C;
499    N = st->block_size;
500    B = st->nb_blocks;
501    C = st->mode->nbChannels;
502    
503    float X[C*B*N];         /**< Interleaved signal MDCTs */
504    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
505    float bandE[st->mode->nbEBands*C];
506    float gains[st->mode->nbPBands];
507    int pitch_index;
508    ec_dec dec;
509    ec_byte_buffer buf;
510    
511    if (data == NULL)
512    {
513       celt_decode_lost(st, pcm);
514       return 0;
515    }
516    
517    ec_byte_readinit(&buf,data,len);
518    ec_dec_init(&dec,&buf);
519    
520    /* Get the pitch index */
521    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
522    st->last_pitch_index = pitch_index;
523    
524    /* Get band energies */
525    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
526    
527    /* Pitch MDCT */
528    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
529
530    {
531       float bandEp[st->mode->nbEBands];
532       compute_band_energies(st->mode, P, bandEp);
533       normalise_bands(st->mode, P, bandEp);
534    }
535
536    if (C==2)
537       stereo_mix(st->mode, P, bandE, 1);
538
539    /* Get the pitch gains */
540    unquant_pitch(gains, st->mode->nbPBands, &dec);
541
542    /* Apply pitch gains */
543    pitch_quant_bands(st->mode, X, P, gains);
544
545    /* Decode fixed codebook and merge with pitch */
546    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
547
548    if (C==2)
549       stereo_mix(st->mode, X, bandE, -1);
550
551    renormalise_bands(st->mode, X);
552    
553    /* Synthesis */
554    denormalise_bands(st->mode, X, bandE);
555
556
557    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
558    /* Compute inverse MDCTs */
559    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
560
561    for (c=0;c<C;c++)
562    {
563       for (i=0;i<B;i++)
564       {
565          int j;
566          for (j=0;j<N;j++)
567          {
568             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
569             st->preemph_memD[c] = tmp;
570             if (tmp > 32767) tmp = 32767;
571             if (tmp < -32767) tmp = -32767;
572             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
573          }
574       }
575    }
576
577    {
578       int val = 0;
579       while (ec_dec_tell(&dec, 0) < len*8)
580       {
581          if (ec_dec_uint(&dec, 2) != val)
582          {
583             celt_warning("decode error");
584             return CELT_CORRUPTED_DATA;
585          }
586          val = 1-val;
587       }
588    }
589
590    return 0;
591    //printf ("\n");
592 }
593