Enabled pure CBR, though there's still some decoder issues.
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    int bytesPerFrame;
59    ec_byte_buffer buf;
60    ec_enc         enc;
61
62    float preemph;
63    float *preemph_memE;
64    float *preemph_memD;
65    
66    mdct_lookup mdct_lookup;
67    void *fft;
68    
69    float *window;
70    float *in_mem;
71    float *mdct_overlap;
72    float *out_mem;
73
74    float *oldBandE;
75    
76    struct alloc_data alloc;
77 };
78
79
80
81 CELTEncoder *celt_encoder_new(const CELTMode *mode)
82 {
83    int i, N, B, C, N4;
84    N = mode->mdctSize;
85    B = mode->nbMdctBlocks;
86    C = mode->nbChannels;
87    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
88    
89    st->mode = mode;
90    st->frame_size = B*N;
91    st->block_size = N;
92    st->nb_blocks  = B;
93    st->overlap = mode->overlap;
94    st->Fs = 44100;
95
96    N4 = (N-st->overlap)/2;
97    ec_byte_writeinit(&st->buf);
98    ec_enc_init(&st->enc,&st->buf);
99    st->bytesPerFrame = mode->defaultRate;
100    
101    mdct_init(&st->mdct_lookup, 2*N);
102    st->fft = spx_fft_init(MAX_PERIOD*C);
103    
104    st->window = celt_alloc(2*N*sizeof(float));
105    st->in_mem = celt_alloc(N*C*sizeof(float));
106    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
107    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
108    for (i=0;i<2*N;i++)
109       st->window[i] = 0;
110    for (i=0;i<st->overlap;i++)
111       st->window[N4+i] = st->window[2*N-N4-i-1] 
112             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
113    for (i=0;i<2*N4;i++)
114       st->window[N-N4+i] = 1;
115    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
116
117    st->preemph = 0.8;
118    st->preemph_memE = celt_alloc(C*sizeof(float));;
119    st->preemph_memD = celt_alloc(C*sizeof(float));;
120
121    alloc_init(&st->alloc, st->mode);
122    return st;
123 }
124
125 void celt_encoder_destroy(CELTEncoder *st)
126 {
127    if (st == NULL)
128    {
129       celt_warning("NULL passed to celt_encoder_destroy");
130       return;
131    }
132    ec_byte_writeclear(&st->buf);
133
134    mdct_clear(&st->mdct_lookup);
135    spx_fft_destroy(st->fft);
136
137    celt_free(st->window);
138    celt_free(st->in_mem);
139    celt_free(st->mdct_overlap);
140    celt_free(st->out_mem);
141    
142    celt_free(st->oldBandE);
143    alloc_clear(&st->alloc);
144
145    celt_free(st);
146 }
147
148 static void haar1(float *X, int N, int stride)
149 {
150    int i, k;
151    for (k=0;k<stride;k++)
152    {
153       for (i=k;i<N*stride;i+=2*stride)
154       {
155          float a, b;
156          a = X[i];
157          b = X[i+stride];
158          X[i] = .707107f*(a+b);
159          X[i+stride] = .707107f*(a-b);
160       }
161    }
162 }
163
164 static void time_dct(float *X, int N, int B, int stride)
165 {
166    switch (B)
167    {
168       case 1:
169          break;
170       case 2:
171          haar1(X, B*N, stride);
172          break;
173       default:
174          celt_warning("time_dct not defined for B > 2");
175    };
176 }
177
178 static void time_idct(float *X, int N, int B, int stride)
179 {
180    switch (B)
181    {
182       case 1:
183          break;
184       case 2:
185          haar1(X, B*N, stride);
186          break;
187       default:
188          celt_warning("time_dct not defined for B > 2");
189    };
190 }
191
192 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
193 {
194    int i, c;
195    for (c=0;c<C;c++)
196    {
197       for (i=0;i<B;i++)
198       {
199          int j;
200          float x[2*N];
201          float tmp[N];
202          for (j=0;j<2*N;j++)
203             x[j] = window[j]*in[C*i*N+C*j+c];
204          mdct_forward(mdct_lookup, x, tmp);
205          /* Interleaving the sub-frames */
206          for (j=0;j<N;j++)
207             out[C*B*j+C*i+c] = tmp[j];
208       }
209    }
210 }
211
212 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
213 {
214    int i, c, N4;
215    N4 = (N-overlap)/2;
216    for (c=0;c<C;c++)
217    {
218       for (i=0;i<B;i++)
219       {
220          int j;
221          float x[2*N];
222          float tmp[N];
223          /* De-interleaving the sub-frames */
224          for (j=0;j<N;j++)
225             tmp[j] = X[C*B*j+C*i+c];
226          mdct_backward(mdct_lookup, tmp, x);
227          for (j=0;j<2*N;j++)
228             x[j] = window[j]*x[j];
229          for (j=0;j<overlap;j++)
230             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
231          for (j=0;j<2*N4;j++)
232             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
233          for (j=0;j<overlap;j++)
234             mdct_overlap[C*j+c] = x[N+N4+j];
235       }
236    }
237 }
238
239 int celt_encoder_set_output_size(CELTEncoder *st, int bytesPerFrame)
240 {
241    if (bytesPerFrame<= 0)
242       return -1;
243    st->bytesPerFrame = bytesPerFrame;
244    return st->bytesPerFrame;
245 }
246
247 int celt_encode(CELTEncoder *st, short *pcm, char *compressed)
248 {
249    int i, c, N, B, C, N4;
250    N = st->block_size;
251    B = st->nb_blocks;
252    C = st->mode->nbChannels;
253    float in[(B+1)*C*N];
254
255    float X[B*C*N];         /**< Interleaved signal MDCTs */
256    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
257    float mask[B*C*N];      /**< Masking curve */
258    float bandE[st->mode->nbEBands*C];
259    float gains[st->mode->nbPBands];
260    int pitch_index;
261
262    N4 = (N-st->overlap)/2;
263
264    for (c=0;c<C;c++)
265    {
266       for (i=0;i<N4;i++)
267          in[C*i+c] = 0;
268       for (i=0;i<st->overlap;i++)
269          in[C*(i+N4)+c] = st->in_mem[C*i+c];
270       for (i=0;i<B*N;i++)
271       {
272          float tmp = pcm[C*i+c];
273          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
274          st->preemph_memE[c] = tmp;
275       }
276       for (i=N*(B+1)-N4;i<N*(B+1);i++)
277          in[C*i+c] = 0;
278       for (i=0;i<st->overlap;i++)
279          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
280    }
281    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
282    /* Compute MDCTs */
283    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
284
285    compute_mdct_masking(X, mask, B*C*N, st->Fs);
286
287    /* Invert and stretch the mask to length of X 
288       For some reason, I get better results by using the sqrt instead,
289       although there's no valid reason to. Must investigate further */
290    for (i=0;i<B*C*N;i++)
291       mask[i] = 1/(.1+mask[i]);
292
293    /* Pitch analysis */
294    for (c=0;c<C;c++)
295    {
296       for (i=0;i<N;i++)
297       {
298          in[C*i+c] *= st->window[i];
299          in[C*(B*N+i)+c] *= st->window[N+i];
300       }
301    }
302    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
303    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
304    
305    /* Compute MDCTs of the pitch part */
306    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
307    
308    /*int j;
309    for (j=0;j<B*N;j++)
310       printf ("%f ", X[j]);
311    for (j=0;j<B*N;j++)
312       printf ("%f ", P[j]);
313    printf ("\n");*/
314
315    /* Band normalisation */
316    compute_band_energies(st->mode, X, bandE);
317    normalise_bands(st->mode, X, bandE);
318    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
319    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
320
321    /* Normalise the pitch vector as well (discard the energies) */
322    {
323       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
324       compute_band_energies(st->mode, P, bandEp);
325       normalise_bands(st->mode, P, bandEp);
326    }
327
328    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
329
330    if (C==2)
331    {
332       stereo_mix(st->mode, X, bandE, 1);
333       stereo_mix(st->mode, P, bandE, 1);
334       //haar1(X, B*N*C, 1);
335       //haar1(P, B*N*C, 1);
336    }
337    /* Simulates intensity stereo */
338    //for (i=30;i<N*B;i++)
339    //   X[i*C+1] = P[i*C+1] = 0;
340    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
341    time_dct(X, N, B, C);
342    time_dct(P, N, B, C);
343
344
345    /* Pitch prediction */
346    compute_pitch_gain(st->mode, X, P, gains, bandE);
347    quant_pitch(gains, st->mode->nbPBands, &st->enc);
348    pitch_quant_bands(st->mode, X, P, gains);
349
350    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
351    /* Compute residual that we're going to encode */
352    for (i=0;i<B*C*N;i++)
353       X[i] -= P[i];
354
355    /*float sum=0;
356    for (i=0;i<B*N;i++)
357       sum += X[i]*X[i];
358    printf ("%f\n", sum);*/
359    /* Residual quantisation */
360    quant_bands(st->mode, X, P, mask, &st->alloc, st->bytesPerFrame*8, &st->enc);
361    
362    time_idct(X, N, B, C);
363    if (C==2)
364       //haar1(X, B*N*C, 1);
365       stereo_mix(st->mode, X, bandE, -1);
366
367    renormalise_bands(st->mode, X);
368    /* Synthesis */
369    denormalise_bands(st->mode, X, bandE);
370
371
372    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
373
374    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
375    /* De-emphasis and put everything back at the right place in the synthesis history */
376    for (c=0;c<C;c++)
377    {
378       for (i=0;i<B;i++)
379       {
380          int j;
381          for (j=0;j<N;j++)
382          {
383             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
384             st->preemph_memD[c] = tmp;
385             if (tmp > 32767) tmp = 32767;
386             if (tmp < -32767) tmp = -32767;
387             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
388          }
389       }
390    }
391    
392    while (ec_enc_tell(&st->enc, 0) < st->bytesPerFrame*8)
393       ec_enc_bits(&st->enc, 1, 1);
394    ec_enc_done(&st->enc);
395    int nbBytes = ec_byte_bytes(&st->buf);
396    char *data = ec_byte_get_buffer(&st->buf);
397    for (i=0;i<nbBytes;i++)
398       compressed[i] = data[i];
399    /* Fill the last byte with the right pattern so the decoder doesn't get confused
400       if the encoder didn't return enough bytes */
401    /* FIXME: This isn't quite what the decoder expects, but it's the best we can do for now */
402    for (;i<st->bytesPerFrame;i++)
403       compressed[i] = 0x00;
404    //printf ("%d\n", *nbBytes);
405    
406    /* Reset the packing for the next encoding */
407    ec_byte_reset(&st->buf);
408    ec_enc_init(&st->enc,&st->buf);
409
410    return st->bytesPerFrame;
411 }
412
413 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
414 {
415    char *data;
416    ec_enc_done(&st->enc);
417    *nbBytes = ec_byte_bytes(&st->buf);
418    data = ec_byte_get_buffer(&st->buf);
419    //printf ("%d\n", *nbBytes);
420    
421    /* Reset the packing for the next encoding */
422    ec_byte_reset(&st->buf);
423    ec_enc_init(&st->enc,&st->buf);
424
425    return data;
426 }
427
428
429 /****************************************************************************/
430 /*                                                                          */
431 /*                                DECODER                                   */
432 /*                                                                          */
433 /****************************************************************************/
434
435
436
437 struct CELTDecoder {
438    const CELTMode *mode;
439    int frame_size;
440    int block_size;
441    int nb_blocks;
442    int overlap;
443
444    ec_byte_buffer buf;
445    ec_enc         enc;
446
447    float preemph;
448    float *preemph_memD;
449    
450    mdct_lookup mdct_lookup;
451    
452    float *window;
453    float *mdct_overlap;
454    float *out_mem;
455
456    float *oldBandE;
457    
458    int last_pitch_index;
459    
460    struct alloc_data alloc;
461 };
462
463 CELTDecoder *celt_decoder_new(const CELTMode *mode)
464 {
465    int i, N, B, C, N4;
466    N = mode->mdctSize;
467    B = mode->nbMdctBlocks;
468    C = mode->nbChannels;
469    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
470    
471    st->mode = mode;
472    st->frame_size = B*N;
473    st->block_size = N;
474    st->nb_blocks  = B;
475    st->overlap = mode->overlap;
476
477    N4 = (N-st->overlap)/2;
478    
479    mdct_init(&st->mdct_lookup, 2*N);
480    
481    st->window = celt_alloc(2*N*sizeof(float));
482    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
483    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
484
485    for (i=0;i<2*N;i++)
486       st->window[i] = 0;
487    for (i=0;i<st->overlap;i++)
488       st->window[N4+i] = st->window[2*N-N4-i-1] 
489             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
490    for (i=0;i<2*N4;i++)
491       st->window[N-N4+i] = 1;
492    
493    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
494
495    st->preemph = 0.8;
496    st->preemph_memD = celt_alloc(C*sizeof(float));;
497
498    st->last_pitch_index = 0;
499    alloc_init(&st->alloc, st->mode);
500
501    return st;
502 }
503
504 void celt_decoder_destroy(CELTDecoder *st)
505 {
506    if (st == NULL)
507    {
508       celt_warning("NULL passed to celt_encoder_destroy");
509       return;
510    }
511
512    mdct_clear(&st->mdct_lookup);
513
514    celt_free(st->window);
515    celt_free(st->mdct_overlap);
516    celt_free(st->out_mem);
517    
518    celt_free(st->oldBandE);
519    alloc_clear(&st->alloc);
520
521    celt_free(st);
522 }
523
524 static void celt_decode_lost(CELTDecoder *st, short *pcm)
525 {
526    int i, c, N, B, C;
527    N = st->block_size;
528    B = st->nb_blocks;
529    C = st->mode->nbChannels;
530    float X[C*B*N];         /**< Interleaved signal MDCTs */
531    int pitch_index;
532    
533    pitch_index = st->last_pitch_index;
534    
535    /* Use the pitch MDCT as the "guessed" signal */
536    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
537
538    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
539    /* Compute inverse MDCTs */
540    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
541
542    for (c=0;c<C;c++)
543    {
544       for (i=0;i<B;i++)
545       {
546          int j;
547          for (j=0;j<N;j++)
548          {
549             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
550             st->preemph_memD[c] = tmp;
551             if (tmp > 32767) tmp = 32767;
552             if (tmp < -32767) tmp = -32767;
553             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
554          }
555       }
556    }
557 }
558
559 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
560 {
561    int i, c, N, B, C;
562    N = st->block_size;
563    B = st->nb_blocks;
564    C = st->mode->nbChannels;
565    
566    float X[C*B*N];         /**< Interleaved signal MDCTs */
567    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
568    float bandE[st->mode->nbEBands*C];
569    float gains[st->mode->nbPBands];
570    int pitch_index;
571    ec_dec dec;
572    ec_byte_buffer buf;
573    
574    if (data == NULL)
575    {
576       celt_decode_lost(st, pcm);
577       return 0;
578    }
579    
580    ec_byte_readinit(&buf,data,len);
581    ec_dec_init(&dec,&buf);
582    
583    /* Get the pitch index */
584    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
585    st->last_pitch_index = pitch_index;
586    
587    /* Get band energies */
588    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
589    
590    /* Pitch MDCT */
591    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
592
593    {
594       float bandEp[st->mode->nbEBands];
595       compute_band_energies(st->mode, P, bandEp);
596       normalise_bands(st->mode, P, bandEp);
597    }
598
599    if (C==2)
600       //haar1(P, B*N*C, 1);
601       stereo_mix(st->mode, P, bandE, 1);
602    time_dct(P, N, B, C);
603
604    /* Get the pitch gains */
605    unquant_pitch(gains, st->mode->nbPBands, &dec);
606
607    /* Apply pitch gains */
608    pitch_quant_bands(st->mode, X, P, gains);
609
610    /* Decode fixed codebook and merge with pitch */
611    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
612
613    time_idct(X, N, B, C);
614    if (C==2)
615       //haar1(X, B*N*C, 1);
616       stereo_mix(st->mode, X, bandE, -1);
617
618    renormalise_bands(st->mode, X);
619    
620    /* Synthesis */
621    denormalise_bands(st->mode, X, bandE);
622
623
624    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
625    /* Compute inverse MDCTs */
626    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
627
628    for (c=0;c<C;c++)
629    {
630       for (i=0;i<B;i++)
631       {
632          int j;
633          for (j=0;j<N;j++)
634          {
635             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
636             st->preemph_memD[c] = tmp;
637             if (tmp > 32767) tmp = 32767;
638             if (tmp < -32767) tmp = -32767;
639             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
640          }
641       }
642    }
643    return 0;
644    //printf ("\n");
645 }
646