Updated pulse coding to simpler (slightly faster) code included with
[opus.git] / libcelt / celt.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include "os_support.h"
33 #include "mdct.h"
34 #include <math.h>
35 #include "celt.h"
36 #include "pitch.h"
37 #include "fftwrap.h"
38 #include "bands.h"
39 #include "modes.h"
40 #include "probenc.h"
41 #include "quant_pitch.h"
42 #include "quant_bands.h"
43 #include "psy.h"
44
45 #define MAX_PERIOD 1024
46
47
48 struct CELTEncoder {
49    const CELTMode *mode;
50    int frame_size;
51    int block_size;
52    int nb_blocks;
53    int channels;
54    int Fs;
55    
56    ec_byte_buffer buf;
57    ec_enc         enc;
58
59    float preemph;
60    float *preemph_memE;
61    float *preemph_memD;
62    
63    mdct_lookup mdct_lookup;
64    void *fft;
65    
66    float *window;
67    float *in_mem;
68    float *mdct_overlap;
69    float *out_mem;
70
71    float *oldBandE;
72 };
73
74
75
76 CELTEncoder *celt_encoder_new(const CELTMode *mode)
77 {
78    int i, N, B, C;
79    N = mode->mdctSize;
80    B = mode->nbMdctBlocks;
81    C = mode->nbChannels;
82    CELTEncoder *st = celt_alloc(sizeof(CELTEncoder));
83    
84    st->mode = mode;
85    st->frame_size = B*N;
86    st->block_size = N;
87    st->nb_blocks  = B;
88    st->Fs = 44100;
89    
90    ec_byte_writeinit(&st->buf);
91    ec_enc_init(&st->enc,&st->buf);
92
93    mdct_init(&st->mdct_lookup, 2*N);
94    st->fft = spx_fft_init(MAX_PERIOD*C);
95    
96    st->window = celt_alloc(2*N*sizeof(float));
97    st->in_mem = celt_alloc(N*C*sizeof(float));
98    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
99    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
100    for (i=0;i<N;i++)
101       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
102    
103    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
104
105    st->preemph = 0.8;
106    st->preemph_memE = celt_alloc(C*sizeof(float));;
107    st->preemph_memD = celt_alloc(C*sizeof(float));;
108
109    return st;
110 }
111
112 void celt_encoder_destroy(CELTEncoder *st)
113 {
114    if (st == NULL)
115    {
116       celt_warning("NULL passed to celt_encoder_destroy");
117       return;
118    }
119    ec_byte_writeclear(&st->buf);
120
121    mdct_clear(&st->mdct_lookup);
122    spx_fft_destroy(st->fft);
123
124    celt_free(st->window);
125    celt_free(st->in_mem);
126    celt_free(st->mdct_overlap);
127    celt_free(st->out_mem);
128    
129    celt_free(st->oldBandE);
130    celt_free(st);
131 }
132
133 static void haar1(float *X, int N, int stride)
134 {
135    int i, k;
136    for (k=0;k<stride;k++)
137    {
138       for (i=k;i<N*stride;i+=2*stride)
139       {
140          float a, b;
141          a = X[i];
142          b = X[i+stride];
143          X[i] = .707107f*(a+b);
144          X[i+stride] = .707107f*(a-b);
145       }
146    }
147 }
148
149 static void time_dct(float *X, int N, int B, int stride)
150 {
151    switch (B)
152    {
153       case 1:
154          break;
155       case 2:
156          haar1(X, B*N, stride);
157          break;
158       default:
159          celt_warning("time_dct not defined for B > 2");
160    };
161 }
162
163 static void time_idct(float *X, int N, int B, int stride)
164 {
165    switch (B)
166    {
167       case 1:
168          break;
169       case 2:
170          haar1(X, B*N, stride);
171          break;
172       default:
173          celt_warning("time_dct not defined for B > 2");
174    };
175 }
176
177 static void compute_mdcts(mdct_lookup *mdct_lookup, float *window, float *in, float *out, int N, int B, int C)
178 {
179    int i, c;
180    for (c=0;c<C;c++)
181    {
182       for (i=0;i<B;i++)
183       {
184          int j;
185          float x[2*N];
186          float tmp[N];
187          for (j=0;j<2*N;j++)
188             x[j] = window[j]*in[C*i*N+C*j+c];
189          mdct_forward(mdct_lookup, x, tmp);
190          /* Interleaving the sub-frames */
191          for (j=0;j<N;j++)
192             out[C*B*j+C*i+c] = tmp[j];
193       }
194    }
195 }
196
197 static void compute_inv_mdcts(mdct_lookup *mdct_lookup, float *window, float *X, float *out_mem, float *mdct_overlap, int N, int B, int C)
198 {
199    int i, c;
200    for (c=0;c<C;c++)
201    {
202       for (i=0;i<B;i++)
203       {
204          int j;
205          float x[2*N];
206          float tmp[N];
207          /* De-interleaving the sub-frames */
208          for (j=0;j<N;j++)
209             tmp[j] = X[C*B*j+C*i+c];
210          mdct_backward(mdct_lookup, tmp, x);
211          for (j=0;j<2*N;j++)
212             x[j] = window[j]*x[j];
213          for (j=0;j<N;j++)
214             out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] = x[j]+mdct_overlap[C*j+c];
215          for (j=0;j<N;j++)
216             mdct_overlap[C*j+c] = x[N+j];
217       }
218    }
219 }
220
221 int celt_encode(CELTEncoder *st, short *pcm)
222 {
223    int i, c, N, B, C;
224    N = st->block_size;
225    B = st->nb_blocks;
226    C = st->mode->nbChannels;
227    float in[(B+1)*C*N];
228    
229    float X[B*C*N];         /**< Interleaved signal MDCTs */
230    float P[B*C*N];         /**< Interleaved pitch MDCTs*/
231    float mask[B*C*N];      /**< Masking curve */
232    float bandE[st->mode->nbEBands];
233    float gains[st->mode->nbPBands];
234    int pitch_index;
235    
236    for (c=0;c<C;c++)
237    {
238       for (i=0;i<N;i++)
239          in[C*i+c] = st->in_mem[C*i+c];
240       for (;i<(B+1)*N;i++)
241       {
242          float tmp = pcm[C*(i-N)+c];
243          in[C*i+c] = tmp - st->preemph*st->preemph_memE[c];
244          st->preemph_memE[c] = tmp;
245       }
246       for (i=0;i<N;i++)
247          st->in_mem[C*i+c] = in[C*(B*N+i)+c];
248    }
249    //for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");
250    /* Compute MDCTs */
251    compute_mdcts(&st->mdct_lookup, st->window, in, X, N, B, C);
252    
253    compute_masking(X, mask, B*C*N, st->Fs);
254    /* Invert and stretch the mask to length of X */
255    for (i=B*C*N-1;i>=0;i--)
256       mask[i] = 1/(1+mask[i>>1]);
257    
258    /* Pitch analysis */
259    for (c=0;c<C;c++)
260    {
261       for (i=0;i<N;i++)
262       {
263          in[C*i+c] *= st->window[i];
264          in[C*(B*N+i)+c] *= st->window[N+i];
265       }
266    }
267    find_spectral_pitch(st->fft, in, st->out_mem, MAX_PERIOD, (B+1)*N, C, &pitch_index);
268    ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(B+1)*N);
269    
270    /* Compute MDCTs of the pitch part */
271    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
272    
273    /*int j;
274    for (j=0;j<B*N;j++)
275       printf ("%f ", X[j]);
276    for (j=0;j<B*N;j++)
277       printf ("%f ", P[j]);
278    printf ("\n");*/
279    if (C==2)
280    {
281       haar1(X, B*N*C, 1);
282       haar1(P, B*N*C, 1);
283    }
284    
285    /* Get a tiny bit more frequency resolution and prevent unstable energy when quantising */
286    time_dct(X, N, B, C);
287    time_dct(P, N, B, C);
288
289    /* Band normalisation */
290    compute_band_energies(st->mode, X, bandE);
291    normalise_bands(st->mode, X, bandE);
292    //for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");
293
294    /* Normalise the pitch vector as well (discard the energies) */
295    {
296       float bandEp[st->mode->nbEBands];
297       compute_band_energies(st->mode, P, bandEp);
298       normalise_bands(st->mode, P, bandEp);
299    }
300
301    quant_energy(st->mode, bandE, st->oldBandE, &st->enc);
302
303    /* Pitch prediction */
304    compute_pitch_gain(st->mode, X, P, gains, bandE);
305    quant_pitch(gains, st->mode->nbPBands, &st->enc);
306    pitch_quant_bands(st->mode, X, P, gains);
307
308    //for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");
309    /* Compute residual that we're going to encode */
310    for (i=0;i<B*C*N;i++)
311       X[i] -= P[i];
312
313    /*float sum=0;
314    for (i=0;i<B*N;i++)
315       sum += X[i]*X[i];
316    printf ("%f\n", sum);*/
317    /* Residual quantisation */
318    quant_bands(st->mode, X, P, mask, &st->enc);
319
320    /* Synthesis */
321    denormalise_bands(st->mode, X, bandE);
322
323    time_idct(X, N, B, C);
324    if (C==2)
325       haar1(X, B*N*C, 1);
326
327    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
328
329    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
330    /* De-emphasis and put everything back at the right place in the synthesis history */
331    for (c=0;c<C;c++)
332    {
333       for (i=0;i<B;i++)
334       {
335          int j;
336          for (j=0;j<N;j++)
337          {
338             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
339             st->preemph_memD[c] = tmp;
340             if (tmp > 32767) tmp = 32767;
341             if (tmp < -32767) tmp = -32767;
342             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
343          }
344       }
345    }
346    return 0;
347 }
348
349 char *celt_encoder_get_bytes(CELTEncoder *st, int *nbBytes)
350 {
351    char *data;
352    ec_enc_done(&st->enc);
353    *nbBytes = ec_byte_bytes(&st->buf);
354    data = ec_byte_get_buffer(&st->buf);
355    //printf ("%d\n", *nbBytes);
356    
357    /* Reset the packing for the next encoding */
358    ec_byte_reset(&st->buf);
359    ec_enc_init(&st->enc,&st->buf);
360
361    return data;
362 }
363
364
365 /****************************************************************************/
366 /*                                                                          */
367 /*                                DECODER                                   */
368 /*                                                                          */
369 /****************************************************************************/
370
371
372
373 struct CELTDecoder {
374    const CELTMode *mode;
375    int frame_size;
376    int block_size;
377    int nb_blocks;
378    
379    ec_byte_buffer buf;
380    ec_enc         enc;
381
382    float preemph;
383    float *preemph_memD;
384    
385    mdct_lookup mdct_lookup;
386    
387    float *window;
388    float *mdct_overlap;
389    float *out_mem;
390
391    float *oldBandE;
392    
393    int last_pitch_index;
394 };
395
396 CELTDecoder *celt_decoder_new(const CELTMode *mode)
397 {
398    int i, N, B, C;
399    N = mode->mdctSize;
400    B = mode->nbMdctBlocks;
401    C = mode->nbChannels;
402    CELTDecoder *st = celt_alloc(sizeof(CELTDecoder));
403    
404    st->mode = mode;
405    st->frame_size = B*N;
406    st->block_size = N;
407    st->nb_blocks  = B;
408    
409    mdct_init(&st->mdct_lookup, 2*N);
410    
411    st->window = celt_alloc(2*N*sizeof(float));
412    st->mdct_overlap = celt_alloc(N*C*sizeof(float));
413    st->out_mem = celt_alloc(MAX_PERIOD*C*sizeof(float));
414    for (i=0;i<N;i++)
415       st->window[i] = st->window[2*N-i-1] = sin(.5*M_PI* sin(.5*M_PI*(i+.5)/N) * sin(.5*M_PI*(i+.5)/N));
416    
417    st->oldBandE = celt_alloc(mode->nbEBands*sizeof(float));
418
419    st->preemph = 0.8;
420    st->preemph_memD = celt_alloc(C*sizeof(float));;
421
422    st->last_pitch_index = 0;
423    return st;
424 }
425
426 void celt_decoder_destroy(CELTDecoder *st)
427 {
428    if (st == NULL)
429    {
430       celt_warning("NULL passed to celt_encoder_destroy");
431       return;
432    }
433
434    mdct_clear(&st->mdct_lookup);
435
436    celt_free(st->window);
437    celt_free(st->mdct_overlap);
438    celt_free(st->out_mem);
439    
440    celt_free(st->oldBandE);
441    celt_free(st);
442 }
443
444 static void celt_decode_lost(CELTDecoder *st, short *pcm)
445 {
446    int i, c, N, B, C;
447    N = st->block_size;
448    B = st->nb_blocks;
449    C = st->mode->nbChannels;
450    float X[C*B*N];         /**< Interleaved signal MDCTs */
451    int pitch_index;
452    
453    pitch_index = st->last_pitch_index;
454    
455    /* Use the pitch MDCT as the "guessed" signal */
456    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, X, N, B, C);
457
458    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
459    /* Compute inverse MDCTs */
460    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
461
462    for (c=0;c<C;c++)
463    {
464       for (i=0;i<B;i++)
465       {
466          int j;
467          for (j=0;j<N;j++)
468          {
469             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
470             st->preemph_memD[c] = tmp;
471             if (tmp > 32767) tmp = 32767;
472             if (tmp < -32767) tmp = -32767;
473             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
474          }
475       }
476    }
477 }
478
479 int celt_decode(CELTDecoder *st, char *data, int len, short *pcm)
480 {
481    int i, c, N, B, C;
482    N = st->block_size;
483    B = st->nb_blocks;
484    C = st->mode->nbChannels;
485    
486    float X[C*B*N];         /**< Interleaved signal MDCTs */
487    float P[C*B*N];         /**< Interleaved pitch MDCTs*/
488    float bandE[st->mode->nbEBands];
489    float gains[st->mode->nbPBands];
490    int pitch_index;
491    ec_dec dec;
492    ec_byte_buffer buf;
493    
494    if (data == NULL)
495    {
496       celt_decode_lost(st, pcm);
497       return 0;
498    }
499    
500    ec_byte_readinit(&buf,data,len);
501    ec_dec_init(&dec,&buf);
502    
503    /* Get the pitch index */
504    pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(B+1)*N);
505    st->last_pitch_index = pitch_index;
506    
507    /* Get band energies */
508    unquant_energy(st->mode, bandE, st->oldBandE, &dec);
509    
510    /* Pitch MDCT */
511    compute_mdcts(&st->mdct_lookup, st->window, st->out_mem+pitch_index*C, P, N, B, C);
512
513    if (C==2)
514       haar1(P, B*N*C, 1);
515    time_dct(P, N, B, C);
516
517    {
518       float bandEp[st->mode->nbEBands];
519       compute_band_energies(st->mode, P, bandEp);
520       normalise_bands(st->mode, P, bandEp);
521    }
522
523    /* Get the pitch gains */
524    unquant_pitch(gains, st->mode->nbPBands, &dec);
525
526    /* Apply pitch gains */
527    pitch_quant_bands(st->mode, X, P, gains);
528
529    /* Decode fixed codebook and merge with pitch */
530    unquant_bands(st->mode, X, P, &dec);
531
532    /* Synthesis */
533    denormalise_bands(st->mode, X, bandE);
534
535    time_idct(X, N, B, C);
536    if (C==2)
537       haar1(X, B*N*C, 1);
538
539    CELT_MOVE(st->out_mem, st->out_mem+C*B*N, C*(MAX_PERIOD-B*N));
540    /* Compute inverse MDCTs */
541    compute_inv_mdcts(&st->mdct_lookup, st->window, X, st->out_mem, st->mdct_overlap, N, B, C);
542
543    for (c=0;c<C;c++)
544    {
545       for (i=0;i<B;i++)
546       {
547          int j;
548          for (j=0;j<N;j++)
549          {
550             float tmp = st->out_mem[C*(MAX_PERIOD+(i-B)*N)+C*j+c] + st->preemph*st->preemph_memD[c];
551             st->preemph_memD[c] = tmp;
552             if (tmp > 32767) tmp = 32767;
553             if (tmp < -32767) tmp = -32767;
554             pcm[C*i*N+C*j+c] = (short)floor(.5+tmp);
555          }
556       }
557    }
558    return 0;
559    //printf ("\n");
560 }
561