More decoder corner case fixes
[opus.git] / src / opus_decoder.c
1 /* Copyright (c) 2010 Xiph.Org Foundation, Skype Limited
2    Written by Jean-Marc Valin and Koen Vos */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #ifdef HAVE_CONFIG_H
29 #include "config.h"
30 #endif
31
32 #include <string.h>
33 #include <stdlib.h>
34 #include <stdio.h>
35 #include <stdarg.h>
36 #include "celt.h"
37 #include "opus_decoder.h"
38 #include "entdec.h"
39 #include "modes.h"
40 #include "silk_API.h"
41
42 #define MAX_PACKET (1275)
43
44 /* Make sure everything's aligned to 4 bytes (this may need to be increased
45    on really weird architectures) */
46 static inline int align(int i)
47 {
48         return (i+3)&-4;
49 }
50
51 int opus_decoder_get_size(int channels)
52 {
53         int silkDecSizeBytes, celtDecSizeBytes;
54         int ret;
55     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
56         if(ret)
57                 return 0;
58         silkDecSizeBytes = align(silkDecSizeBytes);
59     celtDecSizeBytes = celt_decoder_get_size(channels);
60     return align(sizeof(OpusDecoder))+silkDecSizeBytes+celtDecSizeBytes;
61
62 }
63
64 OpusDecoder *opus_decoder_init(OpusDecoder *st, int Fs, int channels)
65 {
66         void *silk_dec;
67         CELTDecoder *celt_dec;
68         int ret, silkDecSizeBytes;
69
70         if (channels<1 || channels > 2)
71             return NULL;
72         memset(st, 0, opus_decoder_get_size(channels));
73         /* Initialize SILK encoder */
74     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
75     if( ret ) {
76         return NULL;
77     }
78     silkDecSizeBytes = align(silkDecSizeBytes);
79     st->silk_dec_offset = align(sizeof(OpusDecoder));
80     st->celt_dec_offset = st->silk_dec_offset+silkDecSizeBytes;
81     silk_dec = (char*)st+st->silk_dec_offset;
82     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
83     st->stream_channels = st->channels = channels;
84
85     st->Fs = Fs;
86
87     /* Reset decoder */
88     ret = silk_InitDecoder( silk_dec );
89     if( ret ) {
90         goto failure;
91     }
92
93         /* Initialize CELT decoder */
94         celt_decoder_init(celt_dec, Fs, channels, &ret);
95         if (ret != CELT_OK)
96                 goto failure;
97     celt_decoder_ctl(celt_dec, CELT_SET_SIGNALLING(0));
98
99         st->prev_mode = 0;
100         return st;
101 failure:
102     free(st);
103     return NULL;
104 }
105
106 OpusDecoder *opus_decoder_create(int Fs, int channels)
107 {
108     char *raw_state = (char*)malloc(opus_decoder_get_size(channels));
109     if (raw_state == NULL)
110         return NULL;
111     return opus_decoder_init((OpusDecoder*)raw_state, Fs, channels);
112 }
113
114 static void smooth_fade(const short *in1, const short *in2, short *out,
115         int overlap, int channels, const celt_word16 *window, int Fs)
116 {
117         int i, c;
118         int inc = 48000/Fs;
119         for (c=0;c<channels;c++)
120         {
121                 for (i=0;i<overlap;i++)
122                 {
123                     celt_word16 w = MULT16_16_Q15(window[i*inc], window[i*inc]);
124                     out[i*channels+c] = SHR32(MAC16_16(MULT16_16(w,in2[i*channels+c]),
125                             Q15ONE-w, in1[i*channels+c]), 15);
126                 }
127         }
128 }
129
130 static int opus_packet_get_mode(const unsigned char *data)
131 {
132         int mode;
133     if (data[0]&0x80)
134     {
135         mode = MODE_CELT_ONLY;
136     } else if ((data[0]&0x60) == 0x60)
137     {
138         mode = MODE_HYBRID;
139     } else {
140
141         mode = MODE_SILK_ONLY;
142     }
143     return mode;
144 }
145
146 static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
147                 int len, short *pcm, int frame_size, int decode_fec)
148 {
149         void *silk_dec;
150         CELTDecoder *celt_dec;
151         int i, silk_ret=0, celt_ret=0;
152         ec_dec dec;
153     silk_DecControlStruct DecControl;
154     SKP_int32 silk_frame_size;
155     short pcm_celt[960*2];
156     short pcm_transition[480*2];
157
158     int audiosize;
159     int mode;
160     int transition=0;
161     int start_band;
162     int redundancy=0;
163     int redundancy_bytes = 0;
164     int celt_to_silk=0;
165     short redundant_audio[240*2];
166     int c;
167     int F2_5, F5, F10, F20;
168     const celt_word16 *window;
169
170     silk_dec = (char*)st+st->silk_dec_offset;
171     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
172     F20 = st->Fs/50;
173     F10 = F20>>1;
174     F5 = F10>>1;
175     F2_5 = F5>>1;
176     /* Payloads of 1 (2 including ToC) or 0 trigger the PLC/DTX */
177     if (len<=1)
178     {
179         data = NULL;
180         /* In that case, don't conceal more than what the ToC says */
181         frame_size = IMIN(frame_size, st->frame_size);
182     }
183     if (data != NULL)
184     {
185         audiosize = st->frame_size;
186         mode = st->mode;
187         ec_dec_init(&dec,(unsigned char*)data,len);
188     } else {
189         audiosize = frame_size;
190         if (st->prev_mode == 0)
191         {
192                 /* If we haven't got any packet yet, all we can do is return zeros */
193                 for (i=0;i<audiosize;i++)
194                         pcm[i] = 0;
195                 return audiosize;
196         } else {
197                 mode = st->prev_mode;
198         }
199     }
200
201     if (data!=NULL && !st->prev_redundancy && mode != st->prev_mode && st->prev_mode > 0
202                 && !(mode == MODE_SILK_ONLY && st->prev_mode == MODE_HYBRID)
203                 && !(mode == MODE_HYBRID && st->prev_mode == MODE_SILK_ONLY))
204     {
205         transition = 1;
206         if (mode == MODE_CELT_ONLY)
207             opus_decode_frame(st, NULL, 0, pcm_transition, IMIN(F10, audiosize), 0);
208     }
209     if (audiosize > frame_size)
210     {
211         fprintf(stderr, "PCM buffer too small: %d vs %d (mode = %d)\n", audiosize, frame_size, mode);
212         return OPUS_BAD_ARG;
213     } else {
214         frame_size = audiosize;
215     }
216
217     /* SILK processing */
218     if (mode != MODE_CELT_ONLY)
219     {
220         int lost_flag, decoded_samples;
221         SKP_int16 *pcm_ptr = pcm;
222
223         if (st->prev_mode==MODE_CELT_ONLY)
224                 silk_InitDecoder( silk_dec );
225
226         DecControl.API_sampleRate = st->Fs;
227         DecControl.nChannelsAPI      = st->channels;
228         DecControl.nChannelsInternal = st->stream_channels;
229         DecControl.payloadSize_ms = 1000 * audiosize / st->Fs;
230         if( mode == MODE_SILK_ONLY ) {
231             if( st->bandwidth == OPUS_BANDWIDTH_NARROWBAND ) {
232                 DecControl.internalSampleRate = 8000;
233             } else if( st->bandwidth == OPUS_BANDWIDTH_MEDIUMBAND ) {
234                 DecControl.internalSampleRate = 12000;
235             } else if( st->bandwidth == OPUS_BANDWIDTH_WIDEBAND ) {
236                 DecControl.internalSampleRate = 16000;
237             } else {
238                 DecControl.internalSampleRate = 16000;
239                 SKP_assert( 0 );
240             }
241         } else {
242             /* Hybrid mode */
243             DecControl.internalSampleRate = 16000;
244         }
245
246         lost_flag = data == NULL ? 1 : 2 * decode_fec;
247         decoded_samples = 0;
248         do {
249             /* Call SILK decoder */
250             int first_frame = decoded_samples == 0;
251             silk_ret = silk_Decode( silk_dec, &DecControl, 
252                 lost_flag, first_frame, &dec, pcm_ptr, &silk_frame_size );
253             if( silk_ret ) {
254                 if (lost_flag) {
255                         /* PLC failure should not be fatal */
256                         silk_frame_size = frame_size;
257                         for (i=0;i<frame_size*st->channels;i++)
258                                 pcm_ptr[i] = 0;
259                 } else
260                     return OPUS_CORRUPTED_DATA;
261             }
262             pcm_ptr += silk_frame_size * st->channels;
263             decoded_samples += silk_frame_size;
264         } while( decoded_samples < frame_size );
265     } else {
266         for (i=0;i<frame_size*st->channels;i++)
267             pcm[i] = 0;
268     }
269
270     start_band = 0;
271     if (mode != MODE_CELT_ONLY && data != NULL)
272     {
273         /* Check if we have a redundant 0-8 kHz band */
274         redundancy = ec_dec_bit_logp(&dec, 12);
275         if (redundancy)
276         {
277             celt_to_silk = ec_dec_bit_logp(&dec, 1);
278             if (mode == MODE_HYBRID)
279                 redundancy_bytes = 2 + ec_dec_uint(&dec, 256);
280             else {
281                 redundancy_bytes = len - ((ec_tell(&dec)+7)>>3);
282                 /* Can only happen on an invalid packet */
283                 if (redundancy_bytes<0)
284                 {
285                         redundancy_bytes = 0;
286                         redundancy = 0;
287                 }
288             }
289             len -= redundancy_bytes;
290             if (len<0)
291                 return OPUS_CORRUPTED_DATA;
292             /* Shrink decoder because of raw bits */
293             dec.storage -= redundancy_bytes;
294         }
295     }
296     if (mode != MODE_CELT_ONLY)
297         start_band = 17;
298
299     if (mode != MODE_SILK_ONLY)
300     {
301         int endband=21;
302
303         switch(st->bandwidth)
304         {
305         case OPUS_BANDWIDTH_NARROWBAND:
306             endband = 13;
307             break;
308         case OPUS_BANDWIDTH_WIDEBAND:
309             endband = 17;
310             break;
311         case OPUS_BANDWIDTH_SUPERWIDEBAND:
312             endband = 19;
313             break;
314         case OPUS_BANDWIDTH_FULLBAND:
315             endband = 21;
316             break;
317         }
318         celt_decoder_ctl(celt_dec, CELT_SET_END_BAND(endband));
319         celt_decoder_ctl(celt_dec, CELT_SET_CHANNELS(st->stream_channels));
320     }
321
322     if (redundancy)
323         transition = 0;
324
325     if (transition && mode != MODE_CELT_ONLY)
326         opus_decode_frame(st, NULL, 0, pcm_transition, IMIN(F10, audiosize), 0);
327
328     /* 5 ms redundant frame for CELT->SILK*/
329     if (redundancy && celt_to_silk)
330     {
331         celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(0));
332         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
333         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
334     }
335
336     /* MUST be after PLC */
337     celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(start_band));
338
339     if (transition)
340         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
341
342     if (mode != MODE_SILK_ONLY)
343     {
344         int celt_frame_size = IMIN(F20, frame_size);
345         /* Decode CELT */
346         celt_ret = celt_decode_with_ec(celt_dec, decode_fec?NULL:data, len, pcm_celt, celt_frame_size, &dec);
347         for (i=0;i<celt_frame_size*st->channels;i++)
348             pcm[i] = SAT16(pcm[i] + (int)pcm_celt[i]);
349     }
350
351     {
352         const CELTMode *celt_mode;
353         celt_decoder_ctl(celt_dec, CELT_GET_MODE(&celt_mode));
354         window = celt_mode->window;
355     }
356
357     /* 5 ms redundant frame for SILK->CELT */
358     if (redundancy && !celt_to_silk)
359     {
360         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
361         celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(0));
362
363         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
364         smooth_fade(pcm+st->channels*(frame_size-F2_5), redundant_audio+st->channels*F2_5,
365                         pcm+st->channels*(frame_size-F2_5), F2_5, st->channels, window, st->Fs);
366     }
367     if (redundancy && celt_to_silk)
368     {
369         for (c=0;c<st->channels;c++)
370         {
371             for (i=0;i<F2_5;i++)
372                 pcm[st->channels*i+c] = redundant_audio[st->channels*i+c];
373         }
374         smooth_fade(redundant_audio+st->channels*F2_5, pcm+st->channels*F2_5,
375                 pcm+st->channels*F2_5, F2_5, st->channels, window, st->Fs);
376     }
377     if (transition)
378     {
379         for (i=0;i<st->channels*F2_5;i++)
380                 pcm[i] = pcm_transition[i];
381         if (audiosize >= F5)
382             smooth_fade(pcm_transition+st->channels*F2_5, pcm+st->channels*F2_5,
383                     pcm+st->channels*F2_5, F2_5,
384                     st->channels, window, st->Fs);
385     }
386 #if OPUS_TEST_RANGE_CODER_STATE
387     st->rangeFinal = dec.rng;
388 #endif
389
390     st->prev_mode = mode;
391     st->prev_redundancy = redundancy;
392         return celt_ret<0 ? celt_ret : audiosize;
393
394 }
395
396 static int parse_size(const unsigned char *data, int len, short *size)
397 {
398         if (len<1)
399         {
400                 *size = -1;
401                 return -1;
402         } else if (data[0]<252)
403         {
404                 *size = data[0];
405                 return 1;
406         } else if (len<2)
407         {
408                 *size = -1;
409                 return -1;
410         } else {
411                 *size = 4*data[1] + data[0];
412                 return 2;
413         }
414 }
415
416 int opus_decode(OpusDecoder *st, const unsigned char *data,
417                 int len, short *pcm, int frame_size, int decode_fec)
418 {
419         int i, bytes, nb_samples;
420         int count;
421         unsigned char ch, toc;
422         /* 48 x 2.5 ms = 120 ms */
423         short size[48];
424         if (len==0 || data==NULL)
425             return opus_decode_frame(st, NULL, 0, pcm, frame_size, 0);
426         else if (len<0)
427                 return OPUS_BAD_ARG;
428         st->mode = opus_packet_get_mode(data);
429         st->bandwidth = opus_packet_get_bandwidth(data);
430         st->frame_size = opus_packet_get_samples_per_frame(data, st->Fs);
431         st->stream_channels = opus_packet_get_nb_channels(data);
432         toc = *data++;
433         len--;
434         switch (toc&0x3)
435         {
436         /* One frame */
437         case 0:
438                 count=1;
439                 size[0] = len;
440                 break;
441                 /* Two CBR frames */
442         case 1:
443                 count=2;
444                 if (len&0x1)
445                         return OPUS_CORRUPTED_DATA;
446                 size[0] = size[1] = len/2;
447                 break;
448                 /* Two VBR frames */
449         case 2:
450                 count = 2;
451                 bytes = parse_size(data, len, size);
452                 len -= bytes;
453                 if (size[0]<0 || size[0] > len)
454                         return OPUS_CORRUPTED_DATA;
455                 data += bytes;
456                 size[1] = len-size[0];
457                 break;
458                 /* Multiple CBR/VBR frames (from 0 to 120 ms) */
459         case 3:
460                 if (len<1)
461                         return OPUS_CORRUPTED_DATA;
462                 /* Number of frames encoded in bits 0 to 5 */
463                 ch = *data++;
464                 count = ch&0x3F;
465                 if (count <= 0 || st->frame_size*count*25 > 3*st->Fs)
466                     return OPUS_CORRUPTED_DATA;
467                 len--;
468                 /* Padding bit */
469                 if (ch&0x40)
470                 {
471                         int padding=0;
472                         int p;
473                         do {
474                                 if (len<=0)
475                                         return OPUS_CORRUPTED_DATA;
476                                 p = *data++;
477                                 len--;
478                                 padding += p==255 ? 254: p;
479                         } while (p==255);
480                         len -= padding;
481                 }
482                 if (len<0)
483                         return OPUS_CORRUPTED_DATA;
484                 /* Bit 7 is VBR flag (bit 6 is ignored) */
485                 if (ch&0x80)
486                 {
487                         /* VBR case */
488                         int last_size=len;
489                         for (i=0;i<count-1;i++)
490                         {
491                                 bytes = parse_size(data, len, size+i);
492                                 len -= bytes;
493                                 if (size[i]<0 || size[i] > len)
494                                         return OPUS_CORRUPTED_DATA;
495                                 data += bytes;
496                                 last_size -= bytes+size[i];
497                         }
498                         if (last_size<0)
499                                 return OPUS_CORRUPTED_DATA;
500                         size[count-1]=last_size;
501                 } else {
502                         /* CBR case */
503                         int sz = len/count;
504                         if (sz*count!=len)
505                                 return OPUS_CORRUPTED_DATA;
506                         for (i=0;i<count;i++)
507                                 size[i] = sz;
508                 }
509                 break;
510         }
511         /* Because it's not encoded explicitly, it's possible the size of the
512             last packet (or all the packets, for the CBR case) is larger than
513             1275.
514            Reject them here.*/
515         if (size[count-1] > MAX_PACKET)
516                 return OPUS_CORRUPTED_DATA;
517         if (count*st->frame_size > frame_size)
518                 return OPUS_BAD_ARG;
519         nb_samples=0;
520         for (i=0;i<count;i++)
521         {
522                 int ret;
523                 ret = opus_decode_frame(st, data, size[i], pcm, frame_size-nb_samples, decode_fec);
524                 if (ret<0)
525                         return ret;
526                 data += size[i];
527                 pcm += ret;
528                 nb_samples += ret;
529         }
530         return nb_samples;
531 }
532 int opus_decoder_ctl(OpusDecoder *st, int request, ...)
533 {
534     va_list ap;
535
536     va_start(ap, request);
537
538     switch (request)
539     {
540         case OPUS_GET_MODE_REQUEST:
541         {
542             int *value = va_arg(ap, int*);
543             *value = st->prev_mode;
544         }
545         break;
546         case OPUS_SET_BANDWIDTH_REQUEST:
547         {
548             int value = va_arg(ap, int);
549             st->bandwidth = value;
550         }
551         break;
552         case OPUS_GET_BANDWIDTH_REQUEST:
553         {
554             int *value = va_arg(ap, int*);
555             *value = st->bandwidth;
556         }
557         break;
558         default:
559             fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);
560             break;
561     }
562
563     va_end(ap);
564     return OPUS_OK;
565 }
566
567 void opus_decoder_destroy(OpusDecoder *st)
568 {
569         free(st);
570 }
571
572 #if OPUS_TEST_RANGE_CODER_STATE
573 int opus_decoder_get_final_range(OpusDecoder *st)
574 {
575     return st->rangeFinal;
576 }
577 #endif
578
579
580 int opus_packet_get_bandwidth(const unsigned char *data)
581 {
582         int bandwidth;
583     if (data[0]&0x80)
584     {
585         bandwidth = OPUS_BANDWIDTH_MEDIUMBAND + ((data[0]>>5)&0x3);
586         if (bandwidth == OPUS_BANDWIDTH_MEDIUMBAND)
587             bandwidth = OPUS_BANDWIDTH_NARROWBAND;
588     } else if ((data[0]&0x60) == 0x60)
589     {
590         bandwidth = (data[0]&0x10) ? OPUS_BANDWIDTH_FULLBAND : OPUS_BANDWIDTH_SUPERWIDEBAND;
591     } else {
592
593         bandwidth = OPUS_BANDWIDTH_NARROWBAND + ((data[0]>>5)&0x3);
594     }
595     return bandwidth;
596 }
597
598 int opus_packet_get_samples_per_frame(const unsigned char *data, int Fs)
599 {
600         int audiosize;
601     if (data[0]&0x80)
602     {
603         audiosize = ((data[0]>>3)&0x3);
604         audiosize = (Fs<<audiosize)/400;
605     } else if ((data[0]&0x60) == 0x60)
606     {
607         audiosize = (data[0]&0x08) ? Fs/50 : Fs/100;
608     } else {
609
610         audiosize = ((data[0]>>3)&0x3);
611         if (audiosize == 3)
612             audiosize = Fs*60/1000;
613         else
614             audiosize = (Fs<<audiosize)/100;
615     }
616     return audiosize;
617 }
618
619 int opus_packet_get_nb_channels(const unsigned char *data)
620 {
621     return (data[0]&0x4) ? 2 : 1;
622 }
623
624 int opus_packet_get_nb_frames(const unsigned char packet[], int len)
625 {
626         int count;
627         if (len<1)
628                 return OPUS_BAD_ARG;
629         count = packet[0]&0x3;
630         if (count==0)
631                 return 1;
632         else if (count!=3)
633                 return 2;
634         else if (len<2)
635                 return OPUS_CORRUPTED_DATA;
636         else
637                 return packet[1]&0x3F;
638 }
639
640 int opus_decoder_get_nb_samples(const OpusDecoder *dec, const unsigned char packet[], int len)
641 {
642         int samples;
643         int count = opus_packet_get_nb_frames(packet, len);
644         samples = count*opus_packet_get_samples_per_frame(packet, dec->Fs);
645         /* Can't have more than 120 ms */
646         if (samples*25 > dec->Fs*3)
647                 return OPUS_CORRUPTED_DATA;
648         else
649                 return samples;
650 }
651