Fixes a bug in the init() functions where were weren't zeroing the entire state
[opus.git] / src / opus_decoder.c
1 /* Copyright (c) 2010 Xiph.Org Foundation, Skype Limited
2    Written by Jean-Marc Valin and Koen Vos */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #ifdef HAVE_CONFIG_H
29 #include "config.h"
30 #endif
31
32 #include <string.h>
33 #include <stdlib.h>
34 #include <stdio.h>
35 #include <stdarg.h>
36 #include "celt.h"
37 #include "opus_decoder.h"
38 #include "entdec.h"
39 #include "modes.h"
40 #include "silk_API.h"
41
42 #define MAX_PACKET (1275)
43
44 /* Make sure everything's aligned to 4 bytes (this may need to be increased
45    on really weird architectures) */
46 static inline int align(int i)
47 {
48         return (i+3)&-4;
49 }
50
51 int opus_decoder_get_size(int channels)
52 {
53         int silkDecSizeBytes, celtDecSizeBytes;
54         int ret;
55     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
56         if(ret)
57                 return 0;
58         silkDecSizeBytes = align(silkDecSizeBytes);
59     celtDecSizeBytes = celt_decoder_get_size(channels);
60     return align(sizeof(OpusDecoder))+silkDecSizeBytes+celtDecSizeBytes;
61
62 }
63
64 OpusDecoder *opus_decoder_init(OpusDecoder *st, int Fs, int channels)
65 {
66         void *silk_dec;
67         CELTDecoder *celt_dec;
68         int ret, silkDecSizeBytes;
69
70         memset(st, 0, opus_decoder_get_size(channels));
71         /* Initialize SILK encoder */
72     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
73     if( ret ) {
74         return NULL;
75     }
76     silkDecSizeBytes = align(silkDecSizeBytes);
77     st->silk_dec_offset = align(sizeof(OpusDecoder));
78     st->celt_dec_offset = st->silk_dec_offset+silkDecSizeBytes;
79     silk_dec = (char*)st+st->silk_dec_offset;
80     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
81     st->stream_channels = st->channels = channels;
82
83     st->Fs = Fs;
84
85     /* Reset decoder */
86     ret = silk_InitDecoder( silk_dec );
87     if( ret ) {
88         goto failure;
89     }
90
91         /* Initialize CELT decoder */
92         celt_decoder_init(celt_dec, Fs, channels, &ret);
93         if (ret != CELT_OK)
94                 goto failure;
95     celt_decoder_ctl(celt_dec, CELT_SET_SIGNALLING(0));
96
97         st->prev_mode = 0;
98         return st;
99 failure:
100     free(st);
101     return NULL;
102 }
103
104 OpusDecoder *opus_decoder_create(int Fs, int channels)
105 {
106     char *raw_state = malloc(opus_decoder_get_size(channels));
107     if (raw_state == NULL)
108         return NULL;
109     return opus_decoder_init((OpusDecoder*)raw_state, Fs, channels);
110 }
111
112 static void smooth_fade(const short *in1, const short *in2, short *out,
113         int overlap, int channels, const celt_word16 *window, int Fs)
114 {
115         int i, c;
116         int inc = 48000/Fs;
117         for (c=0;c<channels;c++)
118         {
119                 for (i=0;i<overlap;i++)
120                 {
121                     celt_word16 w = MULT16_16_Q15(window[i*inc], window[i*inc]);
122                     out[i*channels+c] = SHR32(MAC16_16(MULT16_16(w,in2[i*channels+c]),
123                             Q15ONE-w, in1[i*channels+c]), 15);
124                 }
125         }
126 }
127
128 static int opus_packet_get_mode(const unsigned char *data)
129 {
130         int mode;
131     if (data[0]&0x80)
132     {
133         mode = MODE_CELT_ONLY;
134     } else if ((data[0]&0x60) == 0x60)
135     {
136         mode = MODE_HYBRID;
137     } else {
138
139         mode = MODE_SILK_ONLY;
140     }
141     return mode;
142 }
143
144 static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
145                 int len, short *pcm, int frame_size, int decode_fec)
146 {
147         void *silk_dec;
148         CELTDecoder *celt_dec;
149         int i, silk_ret=0, celt_ret=0;
150         ec_dec dec;
151     silk_DecControlStruct DecControl;
152     SKP_int32 silk_frame_size;
153     short pcm_celt[960*2];
154     short pcm_transition[960*2];
155     int audiosize;
156     int mode;
157     int transition=0;
158     int start_band;
159     int redundancy=0;
160     int redundancy_bytes = 0;
161     int celt_to_silk=0;
162     short redundant_audio[240*2];
163     int c;
164     int F2_5, F5, F10;
165     const celt_word16 *window;
166
167     silk_dec = (char*)st+st->silk_dec_offset;
168     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
169     F10 = st->Fs/100;
170     F5 = F10>>1;
171     F2_5 = F5>>1;
172     /* Payloads of 1 (2 including ToC) or 0 trigger the PLC/DTX */
173     if (len<=1)
174         data = NULL;
175
176         audiosize = st->frame_size;
177     if (data != NULL)
178     {
179         mode = st->mode;
180         ec_dec_init(&dec,(unsigned char*)data,len);
181     } else {
182         mode = st->prev_mode;
183     }
184
185     if (data!=NULL && !st->prev_redundancy && mode != st->prev_mode && st->prev_mode > 0
186                 && !(mode == MODE_SILK_ONLY && st->prev_mode == MODE_HYBRID)
187                 && !(mode == MODE_HYBRID && st->prev_mode == MODE_SILK_ONLY))
188     {
189         transition = 1;
190         if (mode == MODE_CELT_ONLY)
191             opus_decode_frame(st, NULL, 0, pcm_transition, IMAX(F10, audiosize), 0);
192     }
193     if (audiosize > frame_size)
194     {
195         fprintf(stderr, "PCM buffer too small: %d vs %d (mode = %d)\n", audiosize, frame_size, mode);
196         return OPUS_BAD_ARG;
197     } else {
198         frame_size = audiosize;
199     }
200
201     /* SILK processing */
202     if (mode != MODE_CELT_ONLY)
203     {
204         int lost_flag, decoded_samples;
205         SKP_int16 *pcm_ptr = pcm;
206
207         if (st->prev_mode==MODE_CELT_ONLY)
208                 silk_InitDecoder( silk_dec );
209
210         DecControl.API_sampleRate = st->Fs;
211         DecControl.nChannelsAPI      = st->channels;
212         DecControl.nChannelsInternal = st->stream_channels;
213         DecControl.payloadSize_ms = 1000 * audiosize / st->Fs;
214         if( mode == MODE_SILK_ONLY ) {
215             if( st->bandwidth == BANDWIDTH_NARROWBAND ) {
216                 DecControl.internalSampleRate = 8000;
217             } else if( st->bandwidth == BANDWIDTH_MEDIUMBAND ) {
218                 DecControl.internalSampleRate = 12000;
219             } else if( st->bandwidth == BANDWIDTH_WIDEBAND ) {
220                 DecControl.internalSampleRate = 16000;
221             } else {
222                 DecControl.internalSampleRate = 16000;
223                 SKP_assert( 0 );
224             }
225         } else {
226             /* Hybrid mode */
227             DecControl.internalSampleRate = 16000;
228         }
229
230         lost_flag = data == NULL ? 1 : 2 * decode_fec;
231         decoded_samples = 0;
232         do {
233             /* Call SILK decoder */
234             int first_frame = decoded_samples == 0;
235             silk_ret = silk_Decode( silk_dec, &DecControl, 
236                 lost_flag, first_frame, &dec, pcm_ptr, &silk_frame_size );
237             if( silk_ret ) {
238                 fprintf (stderr, "SILK decode error\n");
239                 /* Handle error */
240             }
241             pcm_ptr += silk_frame_size * st->channels;
242             decoded_samples += silk_frame_size;
243         } while( decoded_samples < frame_size );
244     } else {
245         for (i=0;i<frame_size*st->channels;i++)
246             pcm[i] = 0;
247     }
248
249     start_band = 0;
250     if (mode != MODE_CELT_ONLY && data != NULL)
251     {
252         /* Check if we have a redundant 0-8 kHz band */
253         redundancy = ec_dec_bit_logp(&dec, 12);
254         if (redundancy)
255         {
256             celt_to_silk = ec_dec_bit_logp(&dec, 1);
257             if (mode == MODE_HYBRID)
258                 redundancy_bytes = 2 + ec_dec_uint(&dec, 256);
259             else
260                 redundancy_bytes = len - ((ec_tell(&dec)+7)>>3);
261             len -= redundancy_bytes;
262             if (len<0)
263                 return CELT_CORRUPTED_DATA;
264             /* Shrink decoder because of raw bits */
265             dec.storage -= redundancy_bytes;
266         }
267     }
268     if (mode != MODE_CELT_ONLY)
269         start_band = 17;
270
271     if (mode != MODE_SILK_ONLY)
272     {
273         int endband=21;
274
275         switch(st->bandwidth)
276         {
277         case BANDWIDTH_NARROWBAND:
278             endband = 13;
279             break;
280         case BANDWIDTH_WIDEBAND:
281             endband = 17;
282             break;
283         case BANDWIDTH_SUPERWIDEBAND:
284             endband = 19;
285             break;
286         case BANDWIDTH_FULLBAND:
287             endband = 21;
288             break;
289         }
290         celt_decoder_ctl(celt_dec, CELT_SET_END_BAND(endband));
291         celt_decoder_ctl(celt_dec, CELT_SET_CHANNELS(st->stream_channels));
292     }
293
294     if (redundancy)
295         transition = 0;
296
297     if (transition && mode != MODE_CELT_ONLY)
298         opus_decode_frame(st, NULL, 0, pcm_transition, IMAX(F10, audiosize), 0);
299
300     /* 5 ms redundant frame for CELT->SILK*/
301     if (redundancy && celt_to_silk)
302     {
303         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
304         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
305     }
306
307     /* MUST be after PLC */
308     celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(start_band));
309
310     if (transition)
311         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
312
313     if (mode != MODE_SILK_ONLY)
314     {
315         /* Decode CELT */
316         celt_ret = celt_decode_with_ec(celt_dec, decode_fec?NULL:data, len, pcm_celt, frame_size, &dec);
317         for (i=0;i<frame_size*st->channels;i++)
318             pcm[i] = SAT16(pcm[i] + (int)pcm_celt[i]);
319     }
320
321     {
322         const CELTMode *celt_mode;
323         celt_decoder_ctl(celt_dec, CELT_GET_MODE(&celt_mode));
324         window = celt_mode->window;
325     }
326
327     /* 5 ms redundant frame for SILK->CELT */
328     if (redundancy && !celt_to_silk)
329     {
330         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
331         celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(0));
332
333         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
334         smooth_fade(pcm+st->channels*(frame_size-F2_5), redundant_audio+st->channels*F2_5,
335                         pcm+st->channels*(frame_size-F2_5), F2_5, st->channels, window, st->Fs);
336     }
337     if (redundancy && celt_to_silk)
338     {
339         for (c=0;c<st->channels;c++)
340         {
341             for (i=0;i<F2_5;i++)
342                 pcm[st->channels*i+c] = redundant_audio[st->channels*i];
343         }
344         smooth_fade(redundant_audio+st->channels*F2_5, pcm+st->channels*F2_5,
345                 pcm+st->channels*F2_5, F2_5, st->channels, window, st->Fs);
346     }
347     if (transition)
348     {
349         for (i=0;i<F2_5;i++)
350                 pcm[i] = pcm_transition[i];
351         if (audiosize >= F5)
352             smooth_fade(pcm_transition+F2_5, pcm+F2_5, pcm+F2_5, F2_5,
353                     st->channels, window, st->Fs);
354     }
355 #if OPUS_TEST_RANGE_CODER_STATE
356     st->rangeFinal = dec.rng;
357 #endif
358
359     st->prev_mode = mode;
360     st->prev_redundancy = redundancy;
361         return celt_ret<0 ? celt_ret : audiosize;
362
363 }
364
365 static int parse_size(const unsigned char *data, int len, short *size)
366 {
367         if (len<1)
368         {
369                 *size = -1;
370                 return -1;
371         } else if (data[0]<252)
372         {
373                 *size = data[0];
374                 return 1;
375         } else if (len<2)
376         {
377                 *size = -1;
378                 return -1;
379         } else {
380                 *size = 4*data[1] + data[0];
381                 return 2;
382         }
383 }
384
385 int opus_decode(OpusDecoder *st, const unsigned char *data,
386                 int len, short *pcm, int frame_size, int decode_fec)
387 {
388         int i, bytes, nb_samples;
389         int count;
390         unsigned char ch, toc;
391         /* 48 x 2.5 ms = 120 ms */
392         short size[48];
393         if (len==0 || data==NULL)
394             return opus_decode_frame(st, NULL, 0, pcm, frame_size, 0);
395         else if (len<0)
396                 return CELT_BAD_ARG;
397         st->mode = opus_packet_get_mode(data);
398         st->bandwidth = opus_packet_get_bandwidth(data);
399         st->frame_size = opus_packet_get_samples_per_frame(data, st->Fs);
400         st->stream_channels = opus_packet_get_nb_channels(data);
401         toc = *data++;
402         len--;
403         switch (toc&0x3)
404         {
405         /* One frame */
406         case 0:
407                 count=1;
408                 size[0] = len;
409                 break;
410                 /* Two CBR frames */
411         case 1:
412                 count=2;
413                 if (len&0x1)
414                         return OPUS_CORRUPTED_DATA;
415                 size[0] = size[1] = len/2;
416                 break;
417                 /* Two VBR frames */
418         case 2:
419                 count = 2;
420                 bytes = parse_size(data, len, size);
421                 len -= bytes;
422                 if (size[0]<0 || size[0] > len)
423                         return OPUS_CORRUPTED_DATA;
424                 data += bytes;
425                 size[1] = len-size[0];
426                 break;
427                 /* Multiple CBR/VBR frames (from 0 to 120 ms) */
428         case 3:
429                 if (len<1)
430                         return OPUS_CORRUPTED_DATA;
431                 /* Number of frames encoded in bits 0 to 5 */
432                 ch = *data++;
433                 count = ch&0x3F;
434                 if (count <= 0 || st->frame_size*count*25 > 3*st->Fs)
435                     return OPUS_CORRUPTED_DATA;
436                 len--;
437                 /* Padding bit */
438                 if (ch&0x40)
439                 {
440                         int padding=0;
441                         int p;
442                         do {
443                                 if (len<=0)
444                                         return OPUS_CORRUPTED_DATA;
445                                 p = *data++;
446                                 len--;
447                                 padding += p==255 ? 254: p;
448                         } while (p==255);
449                         len -= padding;
450                 }
451                 if (len<0)
452                         return OPUS_CORRUPTED_DATA;
453                 /* Bit 7 is VBR flag (bit 6 is ignored) */
454                 if (ch&0x80)
455                 {
456                         /* VBR case */
457                         int last_size=len;
458                         for (i=0;i<count-1;i++)
459                         {
460                                 bytes = parse_size(data, len, size+i);
461                                 len -= bytes;
462                                 if (size[i]<0 || size[i] > len)
463                                         return OPUS_CORRUPTED_DATA;
464                                 data += bytes;
465                                 last_size -= bytes+size[i];
466                         }
467                         if (last_size<0)
468                                 return OPUS_CORRUPTED_DATA;
469                         size[count-1]=last_size;
470                 } else {
471                         /* CBR case */
472                         int sz = len/count;
473                         if (sz*count!=len)
474                                 return OPUS_CORRUPTED_DATA;
475                         for (i=0;i<count;i++)
476                                 size[i] = sz;
477                 }
478                 break;
479         }
480         /* Because it's not encoded explicitly, it's possible the size of the
481             last packet (or all the packets, for the CBR case) is larger than
482             1275.
483            Reject them here.*/
484         if (size[count-1] > MAX_PACKET)
485                 return OPUS_CORRUPTED_DATA;
486         if (count*st->frame_size > frame_size)
487                 return OPUS_BAD_ARG;
488         nb_samples=0;
489         for (i=0;i<count;i++)
490         {
491                 int ret;
492                 ret = opus_decode_frame(st, data, len, pcm, frame_size-nb_samples, decode_fec);
493                 if (ret<0)
494                         return ret;
495                 data += size[i];
496                 pcm += ret;
497                 nb_samples += ret;
498         }
499         return nb_samples;
500 }
501 int opus_decoder_ctl(OpusDecoder *st, int request, ...)
502 {
503     va_list ap;
504
505     va_start(ap, request);
506
507     switch (request)
508     {
509         case OPUS_GET_MODE_REQUEST:
510         {
511             int *value = va_arg(ap, int*);
512             *value = st->prev_mode;
513         }
514         break;
515         case OPUS_SET_BANDWIDTH_REQUEST:
516         {
517             int value = va_arg(ap, int);
518             st->bandwidth = value;
519         }
520         break;
521         case OPUS_GET_BANDWIDTH_REQUEST:
522         {
523             int *value = va_arg(ap, int*);
524             *value = st->bandwidth;
525         }
526         break;
527         default:
528             fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);
529             break;
530     }
531
532     va_end(ap);
533     return OPUS_OK;
534 }
535
536 void opus_decoder_destroy(OpusDecoder *st)
537 {
538         free(st);
539 }
540
541 #if OPUS_TEST_RANGE_CODER_STATE
542 int opus_decoder_get_final_range(OpusDecoder *st)
543 {
544     return st->rangeFinal;
545 }
546 #endif
547
548
549 int opus_packet_get_bandwidth(const unsigned char *data)
550 {
551         int bandwidth;
552     if (data[0]&0x80)
553     {
554         bandwidth = BANDWIDTH_MEDIUMBAND + ((data[0]>>5)&0x3);
555         if (bandwidth == BANDWIDTH_MEDIUMBAND)
556             bandwidth = BANDWIDTH_NARROWBAND;
557     } else if ((data[0]&0x60) == 0x60)
558     {
559         bandwidth = (data[0]&0x10) ? BANDWIDTH_FULLBAND : BANDWIDTH_SUPERWIDEBAND;
560     } else {
561
562         bandwidth = BANDWIDTH_NARROWBAND + ((data[0]>>5)&0x3);
563     }
564     return bandwidth;
565 }
566
567 int opus_packet_get_samples_per_frame(const unsigned char *data, int Fs)
568 {
569         int audiosize;
570     if (data[0]&0x80)
571     {
572         audiosize = ((data[0]>>3)&0x3);
573         audiosize = (Fs<<audiosize)/400;
574     } else if ((data[0]&0x60) == 0x60)
575     {
576         audiosize = (data[0]&0x08) ? Fs/50 : Fs/100;
577     } else {
578
579         audiosize = ((data[0]>>3)&0x3);
580         if (audiosize == 3)
581             audiosize = Fs*60/1000;
582         else
583             audiosize = (Fs<<audiosize)/100;
584     }
585     return audiosize;
586 }
587
588 int opus_packet_get_nb_channels(const unsigned char *data)
589 {
590     return (data[0]&0x4) ? 2 : 1;
591 }
592
593 int opus_packet_get_nb_frames(const unsigned char packet[], int len)
594 {
595         int count;
596         if (len<1)
597                 return OPUS_BAD_ARG;
598         count = packet[0]&0x3;
599         if (count==0)
600                 return 1;
601         else if (count!=3)
602                 return 2;
603         else if (len<2)
604                 return OPUS_CORRUPTED_DATA;
605         else
606                 return packet[1]&0x3F;
607 }
608
609 int opus_decoder_get_nb_samples(const OpusDecoder *dec, const unsigned char packet[], int len)
610 {
611         int samples;
612         int count = opus_packet_get_nb_frames(packet, len);
613         samples = count*opus_packet_get_samples_per_frame(packet, dec->Fs);
614         /* Can't have more than 120 ms */
615         if (samples*25 > dec->Fs*3)
616                 return OPUS_CORRUPTED_DATA;
617         else
618                 return samples;
619 }
620