cleaning up
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "entcode.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44 #include "rate.h"
45
46 #define MAX_PERIOD 1024
47
48
49 struct CELTEncoder {
50    const CELTMode *mode;
51    int frame_size;
52    int block_size;
53    int nb_blocks;
54    int overlap;
55    int channels;
56    int Fs;
57    
58    ec_byte_buffer buf;
59    ec_enc         enc;
60
61    float preemph;
62    float *preemph_memE;
63    float *preemph_memD;
64    
65    mdct_lookup mdct_lookup;
66    void *fft;
67    
68    float *window;
69    float *in_mem;
70    float *mdct_overlap;
71    float *out_mem;
72
73    float *oldBandE;
74    
75    struct alloc_data alloc;
76 };
77
78
79
80 CELTEncoder *celt_encoder_new(const CELTMode *mode)
81 {
82    int i, N, B, C, N4;
83    N = mode->mdctSize;
84    B = mode->nbMdctBlocks;
85    C = mode->nbChannels;
86    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
87    
88    st->mode = mode;
89    st->frame_size = B*N;
90    st->block_size = N;
91    st->nb_blocks  = B;
92    st->overlap = mode->overlap;
93    st->Fs = 44100;
94
95    N4 = (N-st->overlap)/2;
96    ec_byte_writeinit(&st->buf);
97    ec_enc_init(&st->enc,&st->buf);
98
99    mdct_init(&st->mdct_lookup, 2*N);
100    st->fft = spx_fft_init(MAX_PERIOD*C);
101    
102    st->window = celt_alloc(2*N*sizeof(float));
103    st->in_mem = celt_alloc(N*C*sizeof(float));
104    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
105    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
106    for (i=0;i<2*N;i++)
107       st->window[i] = 0;
108    for (i=0;i<st->overlap;i++)
109       st->window[N4+i] = st->window[2*N-N4-i-1] 
110             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
111    for (i=0;i<2*N4;i++)
112       st->window[N-N4+i] = 1;
113    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
114
115    st->preemph = 0.8;
116    st->preemph_memE = celt_alloc(C*sizeof(float));;
117    st->preemph_memD = celt_alloc(C*sizeof(float));;
118
119    alloc_init(&st->alloc, st->mode);
120    return st;
121 }
122
123 void celt_encoder_destroy(CELTEncoder *st)
124 {
125    if (st == NULL)
126    {
127       celt_warning("NULL passed to celt_encoder_destroy");
128       return;
129    }
130    ec_byte_writeclear(&st->buf);
131
132    mdct_clear(&st->mdct_lookup);
133    spx_fft_destroy(st->fft);
134
135    celt_free(st->window);
136    celt_free(st->in_mem);
137    celt_free(st->mdct_overlap);
138    celt_free(st->out_mem);
139    
140    celt_free(st->oldBandE);
141    alloc_clear(&st->alloc);
142
143    celt_free(st);
144 }
145
146
147 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
148 {
149    int i, c;
150    for (c=0;c<C;c++)
151    {
152       for (i=0;i<B;i++)
153       {
154          int j;
155          float x[2*N];
156          float tmp[N];
157          for (j=0;j<2*N;j++)
158             x[j] = window[j]*in[C*i*N+C*j+c];
159          mdct_forward(mdct_lookup, x, tmp);
160          /* Interleaving the sub-frames */
161          for (j=0;j<N;j++)
162             out[C*B*j+C*i+c] = tmp[j];
163       }
164    }
165 }
166
167 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int overlap, int B, int C)
168 {
169    int i, c, N4;
170    N4 = (N-overlap)/2;
171    for (c=0;c<C;c++)
172    {
173       for (i=0;i<B;i++)
174       {
175          int j;
176          float x[2*N];
177          float tmp[N];
178          /* De-interleaving the sub-frames */
179          for (j=0;j<N;j++)
180             tmp[j] = X[C*B*j+C*i+c];
181          mdct_backward(mdct_lookup, tmp, x);
182          for (j=0;j<2*N;j++)
183             x[j] = window[j]*x[j];
184          for (j=0;j<overlap;j++)
185             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[N4+j]+mdct_overlap[C*j+c];
186          for (j=0;j<2*N4;j++)
187             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*(j+overlap)+c] = x[j+N4+overlap];
188          for (j=0;j<overlap;j++)
189             mdct_overlap[C*j+c] = x[N+N4+j];
190       }
191    }
192 }
193
194 int celt_encode(CELTEncoder *st, celt_int16_t *pcm, unsigned char *compressed, int nbCompressedBytes)
195 {
196    int i, c, N, B, C, N4;
197    N = st->block_size;
198    B = st->nb_blocks;
199    C = st->mode->nbChannels;
200    float in[(B+1)*C*N];
201
202    float X[B*C*N];         /**< Interleaved signal MDCTs */
203    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
204    float mask[B*C*N];      /**< Masking curve */
205    float bandE[st->mode->nbEBands*C];
206    float gains[st->mode->nbPBands];
207    int pitch_index;
208
209    N4 = (N-st->overlap)/2;
210
211    for (c=0;c<C;c++)
212    {
213       for (i=0;i<N4;i++)
214          in[C*i+c] = 0;
215       for (i=0;i<st->overlap;i++)
216          in[C*(i+N4)+c] = st->in_mem[C*i+c];
217       for (i=0;i<B*N;i++)
218       {
219          float tmp = pcm[C*i+c];
220          in[C*(i+st->overlap+N4)+c] = tmp - st->preemph*st->preemph_memE[c];
221          st->preemph_memE[c] = tmp;
222       }
223       for (i=N*(B+1)-N4;i<N*(B+1);i++)
224          in[C*i+c] = 0;
225       for (i=0;i<st->overlap;i++)
226          st->in_mem[C*i+c] = in[C*(N*(B+1)-N4-st->overlap+i)+c];
227    }
228    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
229    /* Compute MDCTs */
230    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
231
232    compute_mdct_masking(X, mask, B*C*N, st->Fs);
233
234    /* Invert and stretch the mask to length of X 
235       For some reason, I get better results by using the sqrt instead,
236       although there's no valid reason to. Must investigate further */
237    for (i=0;i<B*C*N;i++)
238       mask[i] = 1/(.1+mask[i]);
239
240    /* Pitch analysis */
241    for (c=0;c<C;c++)
242    {
243       for (i=0;i<N;i++)
244       {
245          in[C*i+c] *= st->window[i];
246          in[C*(B*N+i)+c] *= st->window[N+i];
247       }
248    }
249    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
250    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
251    
252    /* Compute MDCTs of the pitch part */
253    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
254    
255    /*int j;
256    for (j=0;j<B*N;j++)
257       printf ("%f ", X[j]);
258    for (j=0;j<B*N;j++)
259       printf ("%f ", P[j]);
260    printf ("\n");*/
261
262    /* Band normalisation */
263    compute_band_energies(st->mode, X, bandE);
264    normalise_bands(st->mode, X, bandE);
265    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
266    //for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");
267
268    /* Normalise the pitch vector as well (discard the energies) */
269    {
270       float bandEp[st->mode->nbEBands*st->mode->nbChannels];
271       compute_band_energies(st->mode, P, bandEp);
272       normalise_bands(st->mode, P, bandEp);
273    }
274
275    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
276
277    if (C==2)
278    {
279       stereo_mix(st->mode, X, bandE, 1);
280       stereo_mix(st->mode, P, bandE, 1);
281       //haar1(X, B*N*C, 1);
282       //haar1(P, B*N*C, 1);
283    }
284    /* Simulates intensity stereo */
285    //for (i=30;i<N*B;i++)
286    //   X[i*C+1] = P[i*C+1] = 0;
287    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
288
289    /* Pitch prediction */
290    compute_pitch_gain(st->mode, X, P, gains, bandE);
291    quant_pitch(gains, st->mode->nbPBands, &st->enc);
292    pitch_quant_bands(st->mode, X, P, gains);
293
294    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
295    /* Compute residual that we're going to encode */
296    for (i=0;i<B*C*N;i++)
297       X[i] -= P[i];
298
299    /*float sum=0;
300    for (i=0;i<B*N;i++)
301       sum += X[i]*X[i];
302    printf ("%f\n", sum);*/
303    /* Residual quantisation */
304    quant_bands(st->mode, X, P, mask, &st->alloc, nbCompressedBytes*8, &st->enc);
305    
306    if (C==2)
307       stereo_mix(st->mode, X, bandE, -1);
308
309    renormalise_bands(st->mode, X);
310    /* Synthesis */
311    denormalise_bands(st->mode, X, bandE);
312
313
314    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
315
316    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
317    /* De-emphasis and put everything back at the right place in the synthesis history */
318    for (c=0;c<C;c++)
319    {
320       for (i=0;i<B;i++)
321       {
322          int j;
323          for (j=0;j<N;j++)
324          {
325             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
326             st->preemph_memD[c] = tmp;
327             if (tmp > 32767) tmp = 32767;
328             if (tmp < -32767) tmp = -32767;
329             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
330          }
331       }
332    }
333    
334    //printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);
335    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
336    {
337       int val = 0;
338       while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
339       {
340          ec_enc_uint(&st->enc, val, 2);
341          val = 1-val;
342       }
343    }
344    ec_enc_done(&st->enc);
345    {
346       unsigned char *data;
347       int nbBytes = ec_byte_bytes(&st->buf);
348       if (nbBytes != nbCompressedBytes)
349       {
350          if (nbBytes > nbCompressedBytes)
351             celt_warning("got too many bytes");
352          else
353             celt_warning("not enough bytes");
354          return CELT_INTERNAL_ERROR;
355       }
356       //printf ("%d\n", *nbBytes);
357       data = ec_byte_get_buffer(&st->buf);
358       for (i=0;i<nbBytes;i++)
359          compressed[i] = data[i];
360    }
361    /* Reset the packing for the next encoding */
362    ec_byte_reset(&st->buf);
363    ec_enc_init(&st->enc,&st->buf);
364
365    return nbCompressedBytes;
366 }
367
368
369 /****************************************************************************/
370 /*                                                                          */
371 /*                                DECODER                                   */
372 /*                                                                          */
373 /****************************************************************************/
374
375
376
377 struct CELTDecoder {
378    const CELTMode *mode;
379    int frame_size;
380    int block_size;
381    int nb_blocks;
382    int overlap;
383
384    ec_byte_buffer buf;
385    ec_enc         enc;
386
387    float preemph;
388    float *preemph_memD;
389    
390    mdct_lookup mdct_lookup;
391    
392    float *window;
393    float *mdct_overlap;
394    float *out_mem;
395
396    float *oldBandE;
397    
398    int last_pitch_index;
399    
400    struct alloc_data alloc;
401 };
402
403 CELTDecoder *celt_decoder_new(const CELTMode *mode)
404 {
405    int i, N, B, C, N4;
406    N = mode->mdctSize;
407    B = mode->nbMdctBlocks;
408    C = mode->nbChannels;
409    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
410    
411    st->mode = mode;
412    st->frame_size = B*N;
413    st->block_size = N;
414    st->nb_blocks  = B;
415    st->overlap = mode->overlap;
416
417    N4 = (N-st->overlap)/2;
418    
419    mdct_init(&st->mdct_lookup, 2*N);
420    
421    st->window = celt_alloc(2*N*sizeof(float));
422    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
423    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
424
425    for (i=0;i<2*N;i++)
426       st->window[i] = 0;
427    for (i=0;i<st->overlap;i++)
428       st->window[N4+i] = st->window[2*N-N4-i-1] 
429             = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/st->overlap) * sin(.5*M_PI*(i+.5)/st->overlap));
430    for (i=0;i<2*N4;i++)
431       st->window[N-N4+i] = 1;
432    
433    st->oldBandE = celt_alloc(C*mode->nbEBands*sizeof(float));
434
435    st->preemph = 0.8;
436    st->preemph_memD = celt_alloc(C*sizeof(float));;
437
438    st->last_pitch_index = 0;
439    alloc_init(&st->alloc, st->mode);
440
441    return st;
442 }
443
444 void celt_decoder_destroy(CELTDecoder *st)
445 {
446    if (st == NULL)
447    {
448       celt_warning("NULL passed to celt_encoder_destroy");
449       return;
450    }
451
452    mdct_clear(&st->mdct_lookup);
453
454    celt_free(st->window);
455    celt_free(st->mdct_overlap);
456    celt_free(st->out_mem);
457    
458    celt_free(st->oldBandE);
459    alloc_clear(&st->alloc);
460
461    celt_free(st);
462 }
463
464 static void celt_decode_lost(CELTDecoder *st, short *pcm)
465 {
466    int i, c, N, B, C;
467    N = st->block_size;
468    B = st->nb_blocks;
469    C = st->mode->nbChannels;
470    float X[C*B*N];         /**< Interleaved signal MDCTs */
471    int pitch_index;
472    
473    pitch_index = st->last_pitch_index;
474    
475    /* Use the pitch MDCT as the "guessed" signal */
476    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
477
478    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
479    /* Compute inverse MDCTs */
480    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
481
482    for (c=0;c<C;c++)
483    {
484       for (i=0;i<B;i++)
485       {
486          int j;
487          for (j=0;j<N;j++)
488          {
489             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
490             st->preemph_memD[c] = tmp;
491             if (tmp > 32767) tmp = 32767;
492             if (tmp < -32767) tmp = -32767;
493             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
494          }
495       }
496    }
497 }
498
499 int celt_decode(CELTDecoder *st, char *data, int len, celt_int16_t *pcm)
500 {
501    int i, c, N, B, C;
502    N = st->block_size;
503    B = st->nb_blocks;
504    C = st->mode->nbChannels;
505    
506    float X[C*B*N];         /**< Interleaved signal MDCTs */
507    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
508    float bandE[st->mode->nbEBands*C];
509    float gains[st->mode->nbPBands];
510    int pitch_index;
511    ec_dec dec;
512    ec_byte_buffer buf;
513    
514    if (data == NULL)
515    {
516       celt_decode_lost(st, pcm);
517       return 0;
518    }
519    
520    ec_byte_readinit(&buf,data,len);
521    ec_dec_init(&dec,&buf);
522    
523    /* Get the pitch index */
524    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
525    st->last_pitch_index = pitch_index;
526    
527    /* Get band energies */
528    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
529    
530    /* Pitch MDCT */
531    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
532
533    {
534       float bandEp[st->mode->nbEBands];
535       compute_band_energies(st->mode, P, bandEp);
536       normalise_bands(st->mode, P, bandEp);
537    }
538
539    if (C==2)
540       stereo_mix(st->mode, P, bandE, 1);
541
542    /* Get the pitch gains */
543    unquant_pitch(gains, st->mode->nbPBands, &dec);
544
545    /* Apply pitch gains */
546    pitch_quant_bands(st->mode, X, P, gains);
547
548    /* Decode fixed codebook and merge with pitch */
549    unquant_bands(st->mode, X, P, &st->alloc, len*8, &dec);
550
551    if (C==2)
552       stereo_mix(st->mode, X, bandE, -1);
553
554    renormalise_bands(st->mode, X);
555    
556    /* Synthesis */
557    denormalise_bands(st->mode, X, bandE);
558
559
560    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
561    /* Compute inverse MDCTs */
562    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, st->overlap, B, C);
563
564    for (c=0;c<C;c++)
565    {
566       for (i=0;i<B;i++)
567       {
568          int j;
569          for (j=0;j<N;j++)
570          {
571             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
572             st->preemph_memD[c] = tmp;
573             if (tmp > 32767) tmp = 32767;
574             if (tmp < -32767) tmp = -32767;
575             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
576          }
577       }
578    }
579
580    {
581       int val = 0;
582       while (ec_dec_tell(&dec, 0) < len*8)
583       {
584          if (ec_dec_uint(&dec, 2) != val)
585          {
586             celt_warning("decode error");
587             return CELT_CORRUPTED_DATA;
588          }
589          val = 1-val;
590       }
591    }
592
593    return 0;
594    //printf ("\n");
595 }
596