Reject more invalid packets in the decoder.
[opus.git] / src / opus_decoder.c
1 /* Copyright (c) 2010 Xiph.Org Foundation, Skype Limited
2    Written by Jean-Marc Valin and Koen Vos */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #ifdef HAVE_CONFIG_H
29 #include "config.h"
30 #endif
31
32 #include <string.h>
33 #include <stdlib.h>
34 #include <stdio.h>
35 #include <stdarg.h>
36 #include "celt.h"
37 #include "opus_decoder.h"
38 #include "entdec.h"
39 #include "modes.h"
40 #include "SKP_Silk_SDK_API.h"
41
42 #define MAX_PACKET (1275)
43
44 /* Make sure everything's aligned to 4 bytes (this may need to be increased
45    on really weird architectures) */
46 static inline int align(int i)
47 {
48         return (i+3)&-4;
49 }
50
51 int opus_decoder_get_size(int channels)
52 {
53         int silkDecSizeBytes, celtDecSizeBytes;
54         int ret;
55     ret = SKP_Silk_SDK_Get_Decoder_Size( &silkDecSizeBytes );
56         if(ret)
57                 return 0;
58         silkDecSizeBytes = align(silkDecSizeBytes);
59     celtDecSizeBytes = celt_decoder_get_size(channels);
60     return align(sizeof(OpusDecoder))+silkDecSizeBytes+celtDecSizeBytes;
61
62 }
63
64 OpusDecoder *opus_decoder_init(OpusDecoder *st, int Fs, int channels)
65 {
66         void *silk_dec;
67         CELTDecoder *celt_dec;
68         int ret, silkDecSizeBytes, celtDecSizeBytes;
69
70         memset(st, 0, sizeof(OpusDecoder));
71         /* Initialize SILK encoder */
72     ret = SKP_Silk_SDK_Get_Decoder_Size( &silkDecSizeBytes );
73     if( ret ) {
74         return NULL;
75     }
76     silkDecSizeBytes = align(silkDecSizeBytes);
77     celtDecSizeBytes = celt_decoder_get_size(channels);
78     st->silk_dec_offset = align(sizeof(OpusDecoder));
79     st->celt_dec_offset = st->silk_dec_offset+silkDecSizeBytes;
80     silk_dec = (char*)st+st->silk_dec_offset;
81     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
82     st->stream_channels = st->channels = channels;
83
84     st->Fs = Fs;
85
86     /* Reset decoder */
87     ret = SKP_Silk_SDK_InitDecoder( silk_dec );
88     if( ret ) {
89         goto failure;
90     }
91
92         /* Initialize CELT decoder */
93         celt_decoder_init(celt_dec, Fs, channels, &ret);
94         if (ret != CELT_OK)
95                 goto failure;
96     celt_decoder_ctl(celt_dec, CELT_SET_SIGNALLING(0));
97
98         st->prev_mode = 0;
99         return st;
100 failure:
101     free(st);
102     return NULL;
103 }
104
105 OpusDecoder *opus_decoder_create(int Fs, int channels)
106 {
107     char *raw_state = malloc(opus_decoder_get_size(channels));
108     if (raw_state == NULL)
109         return NULL;
110     return opus_decoder_init((OpusDecoder*)raw_state, Fs, channels);
111 }
112
113 static void smooth_fade(const short *in1, const short *in2, short *out,
114         int overlap, int channels, const celt_word16 *window, int Fs)
115 {
116         int i, c;
117         int inc = 48000/Fs;
118         for (c=0;c<channels;c++)
119         {
120                 for (i=0;i<overlap;i++)
121                 {
122                     celt_word16 w = MULT16_16_Q15(window[i*inc], window[i*inc]);
123                     out[i*channels+c] = SHR32(MAC16_16(MULT16_16(w,in2[i*channels+c]),
124                             Q15ONE-w, in1[i*channels+c]), 15);
125                 }
126         }
127 }
128
129 static int opus_packet_get_mode(const unsigned char *data)
130 {
131         int mode;
132     if (data[0]&0x80)
133     {
134         mode = MODE_CELT_ONLY;
135     } else if ((data[0]&0x60) == 0x60)
136     {
137         mode = MODE_HYBRID;
138     } else {
139
140         mode = MODE_SILK_ONLY;
141     }
142     return mode;
143 }
144
145 static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
146                 int len, short *pcm, int frame_size, int decode_fec)
147 {
148         void *silk_dec;
149         CELTDecoder *celt_dec;
150         int i, silk_ret=0, celt_ret=0;
151         ec_dec dec;
152     SKP_SILK_SDK_DecControlStruct DecControl;
153     SKP_int32 silk_frame_size;
154     short pcm_celt[960*2];
155     short pcm_transition[960*2];
156     int audiosize;
157     int mode;
158     int transition=0;
159     int start_band;
160     int redundancy=0;
161     int redundancy_bytes = 0;
162     int celt_to_silk=0;
163     short redundant_audio[240*2];
164     int c;
165     int F2_5, F5, F10;
166     const celt_word16 *window;
167
168     silk_dec = (char*)st+st->silk_dec_offset;
169     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
170     F10 = st->Fs/100;
171     F5 = F10>>1;
172     F2_5 = F5>>1;
173     /* Payloads of 1 (2 including ToC) or 0 trigger the PLC/DTX */
174     if (len<=1)
175         data = NULL;
176
177         audiosize = st->frame_size;
178     if (data != NULL)
179     {
180         mode = st->mode;
181         ec_dec_init(&dec,(unsigned char*)data,len);
182     } else {
183         mode = st->prev_mode;
184     }
185
186     if (st->stream_channels > st->channels)
187         return OPUS_CORRUPTED_DATA;
188
189     if (data!=NULL && !st->prev_redundancy && mode != st->prev_mode && st->prev_mode > 0
190                 && !(mode == MODE_SILK_ONLY && st->prev_mode == MODE_HYBRID)
191                 && !(mode == MODE_HYBRID && st->prev_mode == MODE_SILK_ONLY))
192     {
193         transition = 1;
194         if (mode == MODE_CELT_ONLY)
195             opus_decode_frame(st, NULL, 0, pcm_transition, IMAX(F10, audiosize), 0);
196     }
197     if (audiosize > frame_size)
198     {
199         fprintf(stderr, "PCM buffer too small: %d vs %d (mode = %d)\n", audiosize, frame_size, mode);
200         return OPUS_BAD_ARG;
201     } else {
202         frame_size = audiosize;
203     }
204
205     /* SILK processing */
206     if (mode != MODE_CELT_ONLY)
207     {
208         int lost_flag, decoded_samples;
209         SKP_int16 *pcm_ptr = pcm;
210
211         if (st->prev_mode==MODE_CELT_ONLY)
212                 SKP_Silk_SDK_InitDecoder( silk_dec );
213
214         DecControl.API_sampleRate = st->Fs;
215         DecControl.payloadSize_ms = 1000 * audiosize / st->Fs;
216         if( mode == MODE_SILK_ONLY ) {
217             if( st->bandwidth == BANDWIDTH_NARROWBAND ) {
218                 DecControl.internalSampleRate = 8000;
219             } else if( st->bandwidth == BANDWIDTH_MEDIUMBAND ) {
220                 DecControl.internalSampleRate = 12000;
221             } else if( st->bandwidth == BANDWIDTH_WIDEBAND ) {
222                 DecControl.internalSampleRate = 16000;
223             } else {
224                 DecControl.internalSampleRate = 16000;
225                 SKP_assert( 0 );
226             }
227         } else {
228             /* Hybrid mode */
229             DecControl.internalSampleRate = 16000;
230         }
231         DecControl.nChannels = st->channels;
232
233         lost_flag = data == NULL ? 1 : 2 * decode_fec;
234         decoded_samples = 0;
235         do {
236             /* Call SILK decoder */
237             int first_frame = decoded_samples == 0;
238             silk_ret = SKP_Silk_SDK_Decode( silk_dec, &DecControl,
239                 lost_flag, first_frame, &dec, pcm_ptr, &silk_frame_size );
240             if( silk_ret ) {
241                 fprintf (stderr, "SILK decode error\n");
242                 /* Handle error */
243             }
244             pcm_ptr += silk_frame_size * st->channels;
245             decoded_samples += silk_frame_size;
246         } while( decoded_samples < frame_size );
247     } else {
248         for (i=0;i<frame_size*st->channels;i++)
249             pcm[i] = 0;
250     }
251
252     start_band = 0;
253     if (mode != MODE_CELT_ONLY && data != NULL)
254     {
255         /* Check if we have a redundant 0-8 kHz band */
256         redundancy = ec_dec_bit_logp(&dec, 12);
257         if (redundancy)
258         {
259             celt_to_silk = ec_dec_bit_logp(&dec, 1);
260             if (mode == MODE_HYBRID)
261                 redundancy_bytes = 2 + ec_dec_uint(&dec, 256);
262             else
263                 redundancy_bytes = len - ((ec_tell(&dec)+7)>>3);
264             len -= redundancy_bytes;
265             if (len<0)
266                 return CELT_CORRUPTED_DATA;
267             /* Shrink decoder because of raw bits */
268             dec.storage -= redundancy_bytes;
269         }
270     }
271     if (mode != MODE_CELT_ONLY)
272         start_band = 17;
273
274     if (mode != MODE_SILK_ONLY)
275     {
276         int endband=21;
277
278         switch(st->bandwidth)
279         {
280         case BANDWIDTH_NARROWBAND:
281             endband = 13;
282             break;
283         case BANDWIDTH_WIDEBAND:
284             endband = 17;
285             break;
286         case BANDWIDTH_SUPERWIDEBAND:
287             endband = 19;
288             break;
289         case BANDWIDTH_FULLBAND:
290             endband = 21;
291             break;
292         }
293         celt_decoder_ctl(celt_dec, CELT_SET_END_BAND(endband));
294         celt_decoder_ctl(celt_dec, CELT_SET_CHANNELS(st->stream_channels));
295     }
296
297     if (redundancy)
298         transition = 0;
299
300     if (transition && mode != MODE_CELT_ONLY)
301         opus_decode_frame(st, NULL, 0, pcm_transition, IMAX(F10, audiosize), 0);
302
303     /* 5 ms redundant frame for CELT->SILK*/
304     if (redundancy && celt_to_silk)
305     {
306         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
307         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
308     }
309
310     /* MUST be after PLC */
311     celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(start_band));
312
313     if (transition)
314         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
315
316     if (mode != MODE_SILK_ONLY)
317     {
318         /* Decode CELT */
319         celt_ret = celt_decode_with_ec(celt_dec, decode_fec?NULL:data, len, pcm_celt, frame_size, &dec);
320         for (i=0;i<frame_size*st->channels;i++)
321             pcm[i] = ADD_SAT16(pcm[i], pcm_celt[i]);
322     }
323
324
325     {
326         const CELTMode *celt_mode;
327         celt_decoder_ctl(celt_dec, CELT_GET_MODE(&celt_mode));
328         window = celt_mode->window;
329     }
330
331     /* 5 ms redundant frame for SILK->CELT */
332     if (redundancy && !celt_to_silk)
333     {
334         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
335         celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(0));
336
337         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
338         smooth_fade(pcm+st->channels*(frame_size-F2_5), redundant_audio+st->channels*F2_5,
339                         pcm+st->channels*(frame_size-F2_5), F2_5, st->channels, window, st->Fs);
340     }
341     if (redundancy && celt_to_silk)
342     {
343         for (c=0;c<st->channels;c++)
344         {
345             for (i=0;i<F2_5;i++)
346                 pcm[st->channels*i+c] = redundant_audio[st->channels*i];
347         }
348         smooth_fade(redundant_audio+st->channels*F2_5, pcm+st->channels*F2_5,
349                 pcm+st->channels*F2_5, F2_5, st->channels, window, st->Fs);
350     }
351     if (transition)
352     {
353         for (i=0;i<F2_5;i++)
354                 pcm[i] = pcm_transition[i];
355         if (audiosize >= F5)
356             smooth_fade(pcm_transition+F2_5, pcm+F2_5, pcm+F2_5, F2_5,
357                     st->channels, window, st->Fs);
358     }
359 #if OPUS_TEST_RANGE_CODER_STATE
360     st->rangeFinal = dec.rng;
361 #endif
362
363     st->prev_mode = mode;
364     st->prev_redundancy = redundancy;
365         return celt_ret<0 ? celt_ret : audiosize;
366
367 }
368
369 static int parse_size(const unsigned char *data, int len, short *size)
370 {
371         if (len<1)
372         {
373                 *size = -1;
374                 return -1;
375         } else if (data[0]<252)
376         {
377                 *size = data[0];
378                 return 1;
379         } else if (len<2)
380         {
381                 *size = -1;
382                 return -1;
383         } else {
384                 *size = 4*data[1] + data[0];
385                 return 2;
386         }
387 }
388
389 int opus_decode(OpusDecoder *st, const unsigned char *data,
390                 int len, short *pcm, int frame_size, int decode_fec)
391 {
392         int i, bytes, nb_samples;
393         int count;
394         unsigned char ch, toc;
395         /* 48 x 2.5 ms = 120 ms */
396         short size[48];
397         if (len==0 || data==NULL)
398             return opus_decode_frame(st, NULL, 0, pcm, frame_size, 0);
399         else if (len<0)
400                 return CELT_BAD_ARG;
401         st->mode = opus_packet_get_mode(data);
402         st->bandwidth = opus_packet_get_bandwidth(data);
403         st->frame_size = opus_packet_get_samples_per_frame(data, st->Fs);
404         st->stream_channels = opus_packet_get_nb_channels(data);
405         toc = *data++;
406         len--;
407         switch (toc&0x3)
408         {
409         /* One frame */
410         case 0:
411                 count=1;
412                 size[0] = len;
413                 break;
414                 /* Two CBR frames */
415         case 1:
416                 count=2;
417                 if (len&0x1)
418                         return OPUS_CORRUPTED_DATA;
419                 size[0] = size[1] = len/2;
420                 break;
421                 /* Two VBR frames */
422         case 2:
423                 count = 2;
424                 bytes = parse_size(data, len, size);
425                 len -= bytes;
426                 if (size[0]<0 || size[0] > len)
427                         return OPUS_CORRUPTED_DATA;
428                 data += bytes;
429                 size[1] = len-size[0];
430                 break;
431                 /* Multiple CBR/VBR frames (from 0 to 120 ms) */
432         case 3:
433                 if (len<1)
434                         return OPUS_CORRUPTED_DATA;
435                 /* Number of frames encoded in bits 0 to 5 */
436                 ch = *data++;
437                 count = ch&0x3F;
438                 if (count <= 0 || st->frame_size*count*25 > 3*st->Fs)
439                     return OPUS_CORRUPTED_DATA;
440                 len--;
441                 /* Padding bit */
442                 if (ch&0x40)
443                 {
444                         int padding=0;
445                         int p;
446                         do {
447                                 if (len<=0)
448                                         return OPUS_CORRUPTED_DATA;
449                                 p = *data++;
450                                 len--;
451                                 padding += p==255 ? 254: p;
452                         } while (p==255);
453                         len -= padding;
454                 }
455                 if (len<0)
456                         return OPUS_CORRUPTED_DATA;
457                 /* Bit 7 is VBR flag (bit 6 is ignored) */
458                 if (ch&0x80)
459                 {
460                         /* VBR case */
461                         int last_size=len;
462                         for (i=0;i<count-1;i++)
463                         {
464                                 bytes = parse_size(data, len, size+i);
465                                 len -= bytes;
466                                 if (size[i]<0 || size[i] > len)
467                                         return OPUS_CORRUPTED_DATA;
468                                 data += bytes;
469                                 last_size -= bytes+size[i];
470                         }
471                         if (last_size<0)
472                                 return OPUS_CORRUPTED_DATA;
473                         size[count-1]=last_size;
474                 } else {
475                         /* CBR case */
476                         int sz = len/count;
477                         if (sz*count!=len)
478                                 return OPUS_CORRUPTED_DATA;
479                         for (i=0;i<count;i++)
480                                 size[i] = sz;
481                 }
482                 break;
483         }
484         /* Because it's not encoded explicitly, it's possible the size of the
485             last packet (or all the packets, for the CBR case) is larger than
486             1275.
487            Reject them here.*/
488         if (size[count-1] > MAX_PACKET)
489                 return OPUS_CORRUPTED_DATA;
490         if (count*st->frame_size > frame_size)
491                 return OPUS_BAD_ARG;
492         nb_samples=0;
493         for (i=0;i<count;i++)
494         {
495                 int ret;
496                 ret = opus_decode_frame(st, data, len, pcm, frame_size-nb_samples, decode_fec);
497                 if (ret<0)
498                         return ret;
499                 data += size[i];
500                 pcm += ret;
501                 nb_samples += ret;
502         }
503         return nb_samples;
504 }
505 int opus_decoder_ctl(OpusDecoder *st, int request, ...)
506 {
507     va_list ap;
508
509     va_start(ap, request);
510
511     switch (request)
512     {
513         case OPUS_GET_MODE_REQUEST:
514         {
515             int *value = va_arg(ap, int*);
516             *value = st->prev_mode;
517         }
518         break;
519         case OPUS_SET_BANDWIDTH_REQUEST:
520         {
521             int value = va_arg(ap, int);
522             st->bandwidth = value;
523         }
524         break;
525         case OPUS_GET_BANDWIDTH_REQUEST:
526         {
527             int *value = va_arg(ap, int*);
528             *value = st->bandwidth;
529         }
530         break;
531         default:
532             fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);
533             break;
534     }
535
536     va_end(ap);
537     return OPUS_OK;
538 }
539
540 void opus_decoder_destroy(OpusDecoder *st)
541 {
542         free(st);
543 }
544
545 #if OPUS_TEST_RANGE_CODER_STATE
546 int opus_decoder_get_final_range(OpusDecoder *st)
547 {
548     return st->rangeFinal;
549 }
550 #endif
551
552
553 int opus_packet_get_bandwidth(const unsigned char *data)
554 {
555         int bandwidth;
556     if (data[0]&0x80)
557     {
558         bandwidth = BANDWIDTH_MEDIUMBAND + ((data[0]>>5)&0x3);
559         if (bandwidth == BANDWIDTH_MEDIUMBAND)
560             bandwidth = BANDWIDTH_NARROWBAND;
561     } else if ((data[0]&0x60) == 0x60)
562     {
563         bandwidth = (data[0]&0x10) ? BANDWIDTH_FULLBAND : BANDWIDTH_SUPERWIDEBAND;
564     } else {
565
566         bandwidth = BANDWIDTH_NARROWBAND + ((data[0]>>5)&0x3);
567     }
568     return bandwidth;
569 }
570
571 int opus_packet_get_samples_per_frame(const unsigned char *data, int Fs)
572 {
573         int audiosize;
574     if (data[0]&0x80)
575     {
576         audiosize = ((data[0]>>3)&0x3);
577         audiosize = (Fs<<audiosize)/400;
578     } else if ((data[0]&0x60) == 0x60)
579     {
580         audiosize = (data[0]&0x08) ? Fs/50 : Fs/100;
581     } else {
582
583         audiosize = ((data[0]>>3)&0x3);
584         if (audiosize == 3)
585             audiosize = Fs*60/1000;
586         else
587             audiosize = (Fs<<audiosize)/100;
588     }
589     return audiosize;
590 }
591
592 int opus_packet_get_nb_channels(const unsigned char *data)
593 {
594     return (data[0]&0x4) ? 2 : 1;
595 }
596
597 int opus_packet_get_nb_frames(const unsigned char packet[], int len)
598 {
599         int count;
600         if (len<1)
601                 return OPUS_BAD_ARG;
602         count = packet[0]&0x3;
603         if (count==0)
604                 return 1;
605         else if (count!=3)
606                 return 2;
607         else if (len<2)
608                 return OPUS_CORRUPTED_DATA;
609         else
610                 return packet[1]&0x3F;
611 }
612
613 int opus_decoder_get_nb_samples(const OpusDecoder *dec, const unsigned char packet[], int len)
614 {
615         int samples;
616         int count = opus_packet_get_nb_frames(packet, len);
617         samples = count*opus_packet_get_samples_per_frame(packet, dec->Fs);
618         /* Can't have more than 120 ms */
619         if (samples*25 > dec->Fs*3)
620                 return OPUS_CORRUPTED_DATA;
621         else
622                 return samples;
623 }
624