Close to getting CBR working
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    void *fft;
67    
68    float *window;
69    float *in_mem;
70    float *mdct_overlap;
71    float *out_mem;
72
73    float *oldBandE;
74    
75    struct alloc_data alloc;
76 };
77
78
79
80 CELTEncoder *celt_encoder_new(const CELTMode *mode)
81 {
82    int i, N, B, C, N4;
83    N = mode->mdctSize;
84    B = mode->nbMdctBlocks;
85    C = mode->nbChannels;
86    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
87    
88    st->mode = mode;
89    st->frame_size = B*N;
90    st->block_size = N;
91    st->nb_blocks  = B;
92    st->overlap = mode->overlap;
93    st->Fs = 44100;
94
95    N4 = (N-st->overlap)/2;
96    ec_byte_writeinit(&st->buf);
97    ec_enc_init(&st->enc,&st->buf);
98
99    mdct_init(&st->mdct_lookup, 2*N);
100    st->fft = spx_fft_init(MAX_PERIOD*C);
101    
102    st->window = celt_alloc(2*N*sizeof(float));
103    st->in_mem = celt_alloc(N*C*sizeof(float));
104    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
105    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
106    for (i=0;i<2*N;i++)
107       st->window[i] = 0;
108    for (i=0;i<st->overlap;i++)
109       st->window[N4+i] = st->window[2*N-N4-i-1] 
110             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
111    for (i=0;i<2*N4;i++)
112       st->window[N-N4+i] = 1;
113    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
114
115    st->preemph = 0.8;
116    st->preemph_memE = celt_alloc(C*sizeof(float));;
117    st->preemph_memD = celt_alloc(C*sizeof(float));;
118
119    alloc_init(&st->alloc, st->mode);
120    return st;
121 }
122
123 void celt_encoder_destroy(CELTEncoder *st)
124 {
125    if (st == NULL)
126    {
127       celt_warning("NULL passed to celt_encoder_destroy");
128       return;
129    }
130    ec_byte_writeclear(&st->buf);
131
132    mdct_clear(&st->mdct_lookup);
133    spx_fft_destroy(st->fft);
134
135    celt_free(st->window);
136    celt_free(st->in_mem);
137    celt_free(st->mdct_overlap);
138    celt_free(st->out_mem);
139    
140    celt_free(st->oldBandE);
141    alloc_clear(&st->alloc);
142
143    celt_free(st);
144 }
145
146 static void haar1(float *X, int N, int stride)
147 {
148    int i, k;
149    for (k=0;k<stride;k++)
150    {
151       for (i=k;i<N*stride;i+=2*stride)
152       {
153          float a, b;
154          a = X[i];
155          b = X[i+stride];
156          X[i] = .707107f*(a+b);
157          X[i+stride] = .707107f*(a-b);
158       }
159    }
160 }
161
162 static void time_dct(float *X, int N, int B, int stride)
163 {
164    switch (B)
165    {
166       case 1:
167          break;
168       case 2:
169          haar1(X, B*N, stride);
170          break;
171       default:
172          celt_warning("time_dct not defined for B > 2");
173    };
174 }
175
176 static void time_idct(float *X, int N, int B, int stride)
177 {
178    switch (B)
179    {
180       case 1:
181          break;
182       case 2:
183          haar1(X, B*N, stride);
184          break;
185       default:
186          celt_warning("time_dct not defined for B > 2");
187    };
188 }
189
190 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
191 {
192    int i, c;
193    for (c=0;c<C;c++)
194    {
195       for (i=0;i<B;i++)
196       {
197          int j;
198          float x[2*N];
199          float tmp[N];
200          for (j=0;j<2*N;j++)
201             x[j] = window[j]*in[C*i*N+C*j+c];
202          mdct_forward(mdct_lookup, x, tmp);
203          /* Interleaving the sub-frames */
204          for (j=0;j<N;j++)
205             out[C*B*j+C*i+c] = tmp[j];
206       }
207    }
208 }
209
210 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
211 {
212    int i, c, N4;
213    N4 = (N-overlap)/2;
214    for (c=0;c<C;c++)
215    {
216       for (i=0;i<B;i++)
217       {
218          int j;
219          float x[2*N];
220          float tmp[N];
221          /* De-interleaving the sub-frames */
222          for (j=0;j<N;j++)
223             tmp[j] = X[C*B*j+C*i+c];
224          mdct_backward(mdct_lookup, tmp, x);
225          for (j=0;j<2*N;j++)
226             x[j] = window[j]*x[j];
227          for (j=0;j<overlap;j++)
228             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
229          for (j=0;j<2*N4;j++)
230             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
231          for (j=0;j<overlap;j++)
232             mdct_overlap[C*j+c] = x[N+N4+j];
233       }
234    }
235 }
236
237 int celt_encode(CELTEncoder *st, short *pcm)
238 {
239    int i, c, N, B, C, N4;
240    N = st->block_size;
241    B = st->nb_blocks;
242    C = st->mode->nbChannels;
243    float in[(B+1)*C*N];
244
245    float X[B*C*N];         /**< Interleaved signal MDCTs */
246    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
247    float mask[B*C*N];      /**< Masking curve */
248    float bandE[st->mode->nbEBands*C];
249    float gains[st->mode->nbPBands];
250    int pitch_index;
251
252    N4 = (N-st->overlap)/2;
253
254    for (c=0;c<C;c++)
255    {
256       for (i=0;i<N4;i++)
257          in[C*i+c] = 0;
258       for (i=0;i<st->overlap;i++)
259          in[C*(i+N4)+c] = st->in_mem[C*i+c];
260       for (i=0;i<B*N;i++)
261       {
262          float tmp = pcm[C*i+c];
263          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
264          st->preemph_memE[c] = tmp;
265       }
266       for (i=N*(B+1)-N4;i<N*(B+1);i++)
267          in[C*i+c] = 0;
268       for (i=0;i<st->overlap;i++)
269          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
270    }
271    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
272    /* Compute MDCTs */
273    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
274
275    compute_mdct_masking(X, mask, B*C*N, st->Fs);
276
277    /* Invert and stretch the mask to length of X 
278       For some reason, I get better results by using the sqrt instead,
279       although there's no valid reason to. Must investigate further */
280    for (i=0;i<B*C*N;i++)
281       mask[i] = 1/(.1+mask[i]);
282
283    /* Pitch analysis */
284    for (c=0;c<C;c++)
285    {
286       for (i=0;i<N;i++)
287       {
288          in[C*i+c] *= st->window[i];
289          in[C*(B*N+i)+c] *= st->window[N+i];
290       }
291    }
292    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
293    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
294    
295    /* Compute MDCTs of the pitch part */
296    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
297    
298    /*int j;
299    for (j=0;j<B*N;j++)
300       printf ("%f ", X[j]);
301    for (j=0;j<B*N;j++)
302       printf ("%f ", P[j]);
303    printf ("\n");*/
304
305    /* Band normalisation */
306    compute_band_energies(st->mode, X, bandE);
307    normalise_bands(st->mode, X, bandE);
308    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
309    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
310
311    /* Normalise the pitch vector as well (discard the energies) */
312    {
313       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
314       compute_band_energies(st->mode, P, bandEp);
315       normalise_bands(st->mode, P, bandEp);
316    }
317
318    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
319
320    if (C==2)
321    {
322       stereo_mix(st->mode, X, bandE, 1);
323       stereo_mix(st->mode, P, bandE, 1);
324       //haar1(X, B*N*C, 1);
325       //haar1(P, B*N*C, 1);
326    }
327    /* Simulates intensity stereo */
328    //for (i=30;i<N*B;i++)
329    //   X[i*C+1] = P[i*C+1] = 0;
330    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
331    time_dct(X, N, B, C);
332    time_dct(P, N, B, C);
333
334
335    /* Pitch prediction */
336    compute_pitch_gain(st->mode, X, P, gains, bandE);
337    quant_pitch(gains, st->mode->nbPBands, &st->enc);
338    pitch_quant_bands(st->mode, X, P, gains);
339
340    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
341    /* Compute residual that we're going to encode */
342    for (i=0;i<B*C*N;i++)
343       X[i] -= P[i];
344
345    /*float sum=0;
346    for (i=0;i<B*N;i++)
347       sum += X[i]*X[i];
348    printf ("%f\n", sum);*/
349    /* Residual quantisation */
350    quant_bands(st->mode, X, P, mask, &st->alloc, 770, &st->enc);
351    
352    time_idct(X, N, B, C);
353    if (C==2)
354       //haar1(X, B*N*C, 1);
355       stereo_mix(st->mode, X, bandE, -1);
356
357    renormalise_bands(st->mode, X);
358    /* Synthesis */
359    denormalise_bands(st->mode, X, bandE);
360
361
362    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
363
364    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
365    /* De-emphasis and put everything back at the right place in the synthesis history */
366    for (c=0;c<C;c++)
367    {
368       for (i=0;i<B;i++)
369       {
370          int j;
371          for (j=0;j<N;j++)
372          {
373             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
374             st->preemph_memD[c] = tmp;
375             if (tmp > 32767) tmp = 32767;
376             if (tmp < -32767) tmp = -32767;
377             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
378          }
379       }
380    }
381    return 0;
382 }
383
384 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
385 {
386    char *data;
387    ec_enc_done(&st->enc);
388    *nbBytes = ec_byte_bytes(&st->buf);
389    data = ec_byte_get_buffer(&st->buf);
390    //printf ("%d\n", *nbBytes);
391    
392    /* Reset the packing for the next encoding */
393    ec_byte_reset(&st->buf);
394    ec_enc_init(&st->enc,&st->buf);
395
396    return data;
397 }
398
399
400 /****************************************************************************/
401 /*                                                                          */
402 /*                                DECODER                                   */
403 /*                                                                          */
404 /****************************************************************************/
405
406
407
408 struct CELTDecoder {
409    const CELTMode *mode;
410    int frame_size;
411    int block_size;
412    int nb_blocks;
413    int overlap;
414
415    ec_byte_buffer buf;
416    ec_enc         enc;
417
418    float preemph;
419    float *preemph_memD;
420    
421    mdct_lookup mdct_lookup;
422    
423    float *window;
424    float *mdct_overlap;
425    float *out_mem;
426
427    float *oldBandE;
428    
429    int last_pitch_index;
430    
431    struct alloc_data alloc;
432 };
433
434 CELTDecoder *celt_decoder_new(const CELTMode *mode)
435 {
436    int i, N, B, C, N4;
437    N = mode->mdctSize;
438    B = mode->nbMdctBlocks;
439    C = mode->nbChannels;
440    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
441    
442    st->mode = mode;
443    st->frame_size = B*N;
444    st->block_size = N;
445    st->nb_blocks  = B;
446    st->overlap = mode->overlap;
447
448    N4 = (N-st->overlap)/2;
449    
450    mdct_init(&st->mdct_lookup, 2*N);
451    
452    st->window = celt_alloc(2*N*sizeof(float));
453    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
454    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
455
456    for (i=0;i<2*N;i++)
457       st->window[i] = 0;
458    for (i=0;i<st->overlap;i++)
459       st->window[N4+i] = st->window[2*N-N4-i-1] 
460             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
461    for (i=0;i<2*N4;i++)
462       st->window[N-N4+i] = 1;
463    
464    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
465
466    st->preemph = 0.8;
467    st->preemph_memD = celt_alloc(C*sizeof(float));;
468
469    st->last_pitch_index = 0;
470    alloc_init(&st->alloc, st->mode);
471
472    return st;
473 }
474
475 void celt_decoder_destroy(CELTDecoder *st)
476 {
477    if (st == NULL)
478    {
479       celt_warning("NULL passed to celt_encoder_destroy");
480       return;
481    }
482
483    mdct_clear(&st->mdct_lookup);
484
485    celt_free(st->window);
486    celt_free(st->mdct_overlap);
487    celt_free(st->out_mem);
488    
489    celt_free(st->oldBandE);
490    alloc_clear(&st->alloc);
491
492    celt_free(st);
493 }
494
495 static void celt_decode_lost(CELTDecoder *st, short *pcm)
496 {
497    int i, c, N, B, C;
498    N = st->block_size;
499    B = st->nb_blocks;
500    C = st->mode->nbChannels;
501    float X[C*B*N];         /**< Interleaved signal MDCTs */
502    int pitch_index;
503    
504    pitch_index = st->last_pitch_index;
505    
506    /* Use the pitch MDCT as the "guessed" signal */
507    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
508
509    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
510    /* Compute inverse MDCTs */
511    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
512
513    for (c=0;c<C;c++)
514    {
515       for (i=0;i<B;i++)
516       {
517          int j;
518          for (j=0;j<N;j++)
519          {
520             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
521             st->preemph_memD[c] = tmp;
522             if (tmp > 32767) tmp = 32767;
523             if (tmp < -32767) tmp = -32767;
524             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
525          }
526       }
527    }
528 }
529
530 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
531 {
532    int i, c, N, B, C;
533    N = st->block_size;
534    B = st->nb_blocks;
535    C = st->mode->nbChannels;
536    
537    float X[C*B*N];         /**< Interleaved signal MDCTs */
538    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
539    float bandE[st->mode->nbEBands*C];
540    float gains[st->mode->nbPBands];
541    int pitch_index;
542    ec_dec dec;
543    ec_byte_buffer buf;
544    
545    if (data == NULL)
546    {
547       celt_decode_lost(st, pcm);
548       return 0;
549    }
550    
551    ec_byte_readinit(&buf,data,len);
552    ec_dec_init(&dec,&buf);
553    
554    /* Get the pitch index */
555    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
556    st->last_pitch_index = pitch_index;
557    
558    /* Get band energies */
559    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
560    
561    /* Pitch MDCT */
562    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
563
564    {
565       float bandEp[st->mode->nbEBands];
566       compute_band_energies(st->mode, P, bandEp);
567       normalise_bands(st->mode, P, bandEp);
568    }
569
570    if (C==2)
571       //haar1(P, B*N*C, 1);
572       stereo_mix(st->mode, P, bandE, 1);
573    time_dct(P, N, B, C);
574
575    /* Get the pitch gains */
576    unquant_pitch(gains, st->mode->nbPBands, &dec);
577
578    /* Apply pitch gains */
579    pitch_quant_bands(st->mode, X, P, gains);
580
581    /* Decode fixed codebook and merge with pitch */
582    unquant_bands(st->mode, X, P, &st->alloc, 770, &dec);
583
584    time_idct(X, N, B, C);
585    if (C==2)
586       //haar1(X, B*N*C, 1);
587       stereo_mix(st->mode, X, bandE, -1);
588
589    renormalise_bands(st->mode, X);
590    
591    /* Synthesis */
592    denormalise_bands(st->mode, X, bandE);
593
594
595    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
596    /* Compute inverse MDCTs */
597    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
598
599    for (c=0;c<C;c++)
600    {
601       for (i=0;i<B;i++)
602       {
603          int j;
604          for (j=0;j<N;j++)
605          {
606             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
607             st->preemph_memD[c] = tmp;
608             if (tmp > 32767) tmp = 32767;
609             if (tmp < -32767) tmp = -32767;
610             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
611          }
612       }
613    }
614    return 0;
615    //printf ("\n");
616 }
617