Fixes a bunch of valgrind errors when decoding random junk
[opus.git] / src / opus_decoder.c
1 /* Copyright (c) 2010 Xiph.Org Foundation, Skype Limited
2    Written by Jean-Marc Valin and Koen Vos */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #ifdef HAVE_CONFIG_H
29 #include "config.h"
30 #endif
31
32 #include <string.h>
33 #include <stdlib.h>
34 #include <stdio.h>
35 #include <stdarg.h>
36 #include "celt.h"
37 #include "opus_decoder.h"
38 #include "entdec.h"
39 #include "modes.h"
40 #include "silk_API.h"
41
42 #define MAX_PACKET (1275)
43
44 /* Make sure everything's aligned to 4 bytes (this may need to be increased
45    on really weird architectures) */
46 static inline int align(int i)
47 {
48         return (i+3)&-4;
49 }
50
51 int opus_decoder_get_size(int channels)
52 {
53         int silkDecSizeBytes, celtDecSizeBytes;
54         int ret;
55     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
56         if(ret)
57                 return 0;
58         silkDecSizeBytes = align(silkDecSizeBytes);
59     celtDecSizeBytes = celt_decoder_get_size(channels);
60     return align(sizeof(OpusDecoder))+silkDecSizeBytes+celtDecSizeBytes;
61
62 }
63
64 OpusDecoder *opus_decoder_init(OpusDecoder *st, int Fs, int channels)
65 {
66         void *silk_dec;
67         CELTDecoder *celt_dec;
68         int ret, silkDecSizeBytes;
69
70         if (channels<1 || channels > 2)
71             return NULL;
72         memset(st, 0, opus_decoder_get_size(channels));
73         /* Initialize SILK encoder */
74     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
75     if( ret ) {
76         return NULL;
77     }
78     silkDecSizeBytes = align(silkDecSizeBytes);
79     st->silk_dec_offset = align(sizeof(OpusDecoder));
80     st->celt_dec_offset = st->silk_dec_offset+silkDecSizeBytes;
81     silk_dec = (char*)st+st->silk_dec_offset;
82     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
83     st->stream_channels = st->channels = channels;
84
85     st->Fs = Fs;
86
87     /* Reset decoder */
88     ret = silk_InitDecoder( silk_dec );
89     if( ret ) {
90         goto failure;
91     }
92
93         /* Initialize CELT decoder */
94         celt_decoder_init(celt_dec, Fs, channels, &ret);
95         if (ret != CELT_OK)
96                 goto failure;
97     celt_decoder_ctl(celt_dec, CELT_SET_SIGNALLING(0));
98
99         st->prev_mode = 0;
100         return st;
101 failure:
102     free(st);
103     return NULL;
104 }
105
106 OpusDecoder *opus_decoder_create(int Fs, int channels)
107 {
108     char *raw_state = (char*)malloc(opus_decoder_get_size(channels));
109     if (raw_state == NULL)
110         return NULL;
111     return opus_decoder_init((OpusDecoder*)raw_state, Fs, channels);
112 }
113
114 static void smooth_fade(const short *in1, const short *in2, short *out,
115         int overlap, int channels, const celt_word16 *window, int Fs)
116 {
117         int i, c;
118         int inc = 48000/Fs;
119         for (c=0;c<channels;c++)
120         {
121                 for (i=0;i<overlap;i++)
122                 {
123                     celt_word16 w = MULT16_16_Q15(window[i*inc], window[i*inc]);
124                     out[i*channels+c] = SHR32(MAC16_16(MULT16_16(w,in2[i*channels+c]),
125                             Q15ONE-w, in1[i*channels+c]), 15);
126                 }
127         }
128 }
129
130 static int opus_packet_get_mode(const unsigned char *data)
131 {
132         int mode;
133     if (data[0]&0x80)
134     {
135         mode = MODE_CELT_ONLY;
136     } else if ((data[0]&0x60) == 0x60)
137     {
138         mode = MODE_HYBRID;
139     } else {
140
141         mode = MODE_SILK_ONLY;
142     }
143     return mode;
144 }
145
146 static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
147                 int len, short *pcm, int frame_size, int decode_fec)
148 {
149         void *silk_dec;
150         CELTDecoder *celt_dec;
151         int i, silk_ret=0, celt_ret=0;
152         ec_dec dec;
153     silk_DecControlStruct DecControl;
154     SKP_int32 silk_frame_size;
155     short pcm_celt[960*2];
156     short pcm_transition[480*2];
157
158     int audiosize;
159     int mode;
160     int transition=0;
161     int start_band;
162     int redundancy=0;
163     int redundancy_bytes = 0;
164     int celt_to_silk=0;
165     short redundant_audio[240*2];
166     int c;
167     int F2_5, F5, F10, F20;
168     const celt_word16 *window;
169
170     silk_dec = (char*)st+st->silk_dec_offset;
171     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
172     F20 = st->Fs/50;
173     F10 = F20>>1;
174     F5 = F10>>1;
175     F2_5 = F5>>1;
176     /* Payloads of 1 (2 including ToC) or 0 trigger the PLC/DTX */
177     if (len<=1)
178     {
179         data = NULL;
180         /* In that case, don't conceal more than what the ToC says */
181         frame_size = IMIN(frame_size, st->frame_size);
182     }
183     if (data != NULL)
184     {
185         audiosize = st->frame_size;
186         mode = st->mode;
187         ec_dec_init(&dec,(unsigned char*)data,len);
188     } else {
189         audiosize = frame_size;
190         if (st->prev_mode == 0)
191         {
192                 /* If we haven't got any packet yet, all we can do is return zeros */
193                 for (i=0;i<audiosize;i++)
194                         pcm[i] = 0;
195                 return audiosize;
196         } else {
197                 mode = st->prev_mode;
198         }
199     }
200
201     if (data!=NULL && !st->prev_redundancy && mode != st->prev_mode && st->prev_mode > 0
202                 && !(mode == MODE_SILK_ONLY && st->prev_mode == MODE_HYBRID)
203                 && !(mode == MODE_HYBRID && st->prev_mode == MODE_SILK_ONLY))
204     {
205         transition = 1;
206         if (mode == MODE_CELT_ONLY)
207             opus_decode_frame(st, NULL, 0, pcm_transition, IMIN(F10, audiosize), 0);
208     }
209     if (audiosize > frame_size)
210     {
211         fprintf(stderr, "PCM buffer too small: %d vs %d (mode = %d)\n", audiosize, frame_size, mode);
212         return OPUS_BAD_ARG;
213     } else {
214         frame_size = audiosize;
215     }
216
217     /* SILK processing */
218     if (mode != MODE_CELT_ONLY)
219     {
220         int lost_flag, decoded_samples;
221         SKP_int16 *pcm_ptr = pcm;
222
223         if (st->prev_mode==MODE_CELT_ONLY)
224                 silk_InitDecoder( silk_dec );
225
226         DecControl.API_sampleRate = st->Fs;
227         DecControl.nChannelsAPI      = st->channels;
228         DecControl.nChannelsInternal = st->stream_channels;
229         DecControl.payloadSize_ms = 1000 * audiosize / st->Fs;
230         if( mode == MODE_SILK_ONLY ) {
231             if( st->bandwidth == OPUS_BANDWIDTH_NARROWBAND ) {
232                 DecControl.internalSampleRate = 8000;
233             } else if( st->bandwidth == OPUS_BANDWIDTH_MEDIUMBAND ) {
234                 DecControl.internalSampleRate = 12000;
235             } else if( st->bandwidth == OPUS_BANDWIDTH_WIDEBAND ) {
236                 DecControl.internalSampleRate = 16000;
237             } else {
238                 DecControl.internalSampleRate = 16000;
239                 SKP_assert( 0 );
240             }
241         } else {
242             /* Hybrid mode */
243             DecControl.internalSampleRate = 16000;
244         }
245
246         lost_flag = data == NULL ? 1 : 2 * decode_fec;
247         decoded_samples = 0;
248         do {
249             /* Call SILK decoder */
250             int first_frame = decoded_samples == 0;
251             silk_ret = silk_Decode( silk_dec, &DecControl, 
252                 lost_flag, first_frame, &dec, pcm_ptr, &silk_frame_size );
253             if( silk_ret ) {
254                 if (lost_flag) {
255                         /* PLC failure should not be fatal */
256                         silk_frame_size = frame_size;
257                         for (i=0;i<frame_size*st->channels;i++)
258                                 pcm_ptr[i] = 0;
259                 } else
260                     return OPUS_CORRUPTED_DATA;
261             }
262             pcm_ptr += silk_frame_size * st->channels;
263             decoded_samples += silk_frame_size;
264         } while( decoded_samples < frame_size );
265     } else {
266         for (i=0;i<frame_size*st->channels;i++)
267             pcm[i] = 0;
268     }
269
270     start_band = 0;
271     if (mode != MODE_CELT_ONLY && data != NULL)
272     {
273         /* Check if we have a redundant 0-8 kHz band */
274         redundancy = ec_dec_bit_logp(&dec, 12);
275         if (redundancy)
276         {
277             celt_to_silk = ec_dec_bit_logp(&dec, 1);
278             if (mode == MODE_HYBRID)
279                 redundancy_bytes = 2 + ec_dec_uint(&dec, 256);
280             else {
281                 redundancy_bytes = len - ((ec_tell(&dec)+7)>>3);
282                 /* Can only happen on an invalid packet */
283                 if (redundancy_bytes<0)
284                 {
285                         redundancy_bytes = 0;
286                         redundancy = 0;
287                 }
288             }
289             len -= redundancy_bytes;
290             if (len<0)
291                 return OPUS_CORRUPTED_DATA;
292             /* Shrink decoder because of raw bits */
293             dec.storage -= redundancy_bytes;
294         }
295     }
296     if (mode != MODE_CELT_ONLY)
297         start_band = 17;
298
299     if (mode != MODE_SILK_ONLY)
300     {
301         int endband=21;
302
303         switch(st->bandwidth)
304         {
305         case OPUS_BANDWIDTH_NARROWBAND:
306             endband = 13;
307             break;
308         case OPUS_BANDWIDTH_WIDEBAND:
309             endband = 17;
310             break;
311         case OPUS_BANDWIDTH_SUPERWIDEBAND:
312             endband = 19;
313             break;
314         case OPUS_BANDWIDTH_FULLBAND:
315             endband = 21;
316             break;
317         }
318         celt_decoder_ctl(celt_dec, CELT_SET_END_BAND(endband));
319         celt_decoder_ctl(celt_dec, CELT_SET_CHANNELS(st->stream_channels));
320     }
321
322     if (redundancy)
323         transition = 0;
324
325     if (transition && mode != MODE_CELT_ONLY)
326         opus_decode_frame(st, NULL, 0, pcm_transition, IMIN(F10, audiosize), 0);
327
328     /* 5 ms redundant frame for CELT->SILK*/
329     if (redundancy && celt_to_silk)
330     {
331         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
332         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
333     }
334
335     /* MUST be after PLC */
336     celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(start_band));
337
338     if (transition)
339         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
340
341     if (mode != MODE_SILK_ONLY)
342     {
343         int celt_frame_size = IMIN(F20, frame_size);
344         /* Decode CELT */
345         celt_ret = celt_decode_with_ec(celt_dec, decode_fec?NULL:data, len, pcm_celt, celt_frame_size, &dec);
346         for (i=0;i<celt_frame_size*st->channels;i++)
347             pcm[i] = SAT16(pcm[i] + (int)pcm_celt[i]);
348     }
349
350     {
351         const CELTMode *celt_mode;
352         celt_decoder_ctl(celt_dec, CELT_GET_MODE(&celt_mode));
353         window = celt_mode->window;
354     }
355
356     /* 5 ms redundant frame for SILK->CELT */
357     if (redundancy && !celt_to_silk)
358     {
359         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
360         celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(0));
361
362         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
363         smooth_fade(pcm+st->channels*(frame_size-F2_5), redundant_audio+st->channels*F2_5,
364                         pcm+st->channels*(frame_size-F2_5), F2_5, st->channels, window, st->Fs);
365     }
366     if (redundancy && celt_to_silk)
367     {
368         for (c=0;c<st->channels;c++)
369         {
370             for (i=0;i<F2_5;i++)
371                 pcm[st->channels*i+c] = redundant_audio[st->channels*i+c];
372         }
373         smooth_fade(redundant_audio+st->channels*F2_5, pcm+st->channels*F2_5,
374                 pcm+st->channels*F2_5, F2_5, st->channels, window, st->Fs);
375     }
376     if (transition)
377     {
378         for (i=0;i<st->channels*F2_5;i++)
379                 pcm[i] = pcm_transition[i];
380         if (audiosize >= F5)
381             smooth_fade(pcm_transition+st->channels*F2_5, pcm+st->channels*F2_5,
382                     pcm+st->channels*F2_5, F2_5,
383                     st->channels, window, st->Fs);
384     }
385 #if OPUS_TEST_RANGE_CODER_STATE
386     st->rangeFinal = dec.rng;
387 #endif
388
389     st->prev_mode = mode;
390     st->prev_redundancy = redundancy;
391         return celt_ret<0 ? celt_ret : audiosize;
392
393 }
394
395 static int parse_size(const unsigned char *data, int len, short *size)
396 {
397         if (len<1)
398         {
399                 *size = -1;
400                 return -1;
401         } else if (data[0]<252)
402         {
403                 *size = data[0];
404                 return 1;
405         } else if (len<2)
406         {
407                 *size = -1;
408                 return -1;
409         } else {
410                 *size = 4*data[1] + data[0];
411                 return 2;
412         }
413 }
414
415 int opus_decode(OpusDecoder *st, const unsigned char *data,
416                 int len, short *pcm, int frame_size, int decode_fec)
417 {
418         int i, bytes, nb_samples;
419         int count;
420         unsigned char ch, toc;
421         /* 48 x 2.5 ms = 120 ms */
422         short size[48];
423         if (len==0 || data==NULL)
424             return opus_decode_frame(st, NULL, 0, pcm, frame_size, 0);
425         else if (len<0)
426                 return OPUS_BAD_ARG;
427         st->mode = opus_packet_get_mode(data);
428         st->bandwidth = opus_packet_get_bandwidth(data);
429         st->frame_size = opus_packet_get_samples_per_frame(data, st->Fs);
430         st->stream_channels = opus_packet_get_nb_channels(data);
431         toc = *data++;
432         len--;
433         switch (toc&0x3)
434         {
435         /* One frame */
436         case 0:
437                 count=1;
438                 size[0] = len;
439                 break;
440                 /* Two CBR frames */
441         case 1:
442                 count=2;
443                 if (len&0x1)
444                         return OPUS_CORRUPTED_DATA;
445                 size[0] = size[1] = len/2;
446                 break;
447                 /* Two VBR frames */
448         case 2:
449                 count = 2;
450                 bytes = parse_size(data, len, size);
451                 len -= bytes;
452                 if (size[0]<0 || size[0] > len)
453                         return OPUS_CORRUPTED_DATA;
454                 data += bytes;
455                 size[1] = len-size[0];
456                 break;
457                 /* Multiple CBR/VBR frames (from 0 to 120 ms) */
458         case 3:
459                 if (len<1)
460                         return OPUS_CORRUPTED_DATA;
461                 /* Number of frames encoded in bits 0 to 5 */
462                 ch = *data++;
463                 count = ch&0x3F;
464                 if (count <= 0 || st->frame_size*count*25 > 3*st->Fs)
465                     return OPUS_CORRUPTED_DATA;
466                 len--;
467                 /* Padding bit */
468                 if (ch&0x40)
469                 {
470                         int padding=0;
471                         int p;
472                         do {
473                                 if (len<=0)
474                                         return OPUS_CORRUPTED_DATA;
475                                 p = *data++;
476                                 len--;
477                                 padding += p==255 ? 254: p;
478                         } while (p==255);
479                         len -= padding;
480                 }
481                 if (len<0)
482                         return OPUS_CORRUPTED_DATA;
483                 /* Bit 7 is VBR flag (bit 6 is ignored) */
484                 if (ch&0x80)
485                 {
486                         /* VBR case */
487                         int last_size=len;
488                         for (i=0;i<count-1;i++)
489                         {
490                                 bytes = parse_size(data, len, size+i);
491                                 len -= bytes;
492                                 if (size[i]<0 || size[i] > len)
493                                         return OPUS_CORRUPTED_DATA;
494                                 data += bytes;
495                                 last_size -= bytes+size[i];
496                         }
497                         if (last_size<0)
498                                 return OPUS_CORRUPTED_DATA;
499                         size[count-1]=last_size;
500                 } else {
501                         /* CBR case */
502                         int sz = len/count;
503                         if (sz*count!=len)
504                                 return OPUS_CORRUPTED_DATA;
505                         for (i=0;i<count;i++)
506                                 size[i] = sz;
507                 }
508                 break;
509         }
510         /* Because it's not encoded explicitly, it's possible the size of the
511             last packet (or all the packets, for the CBR case) is larger than
512             1275.
513            Reject them here.*/
514         if (size[count-1] > MAX_PACKET)
515                 return OPUS_CORRUPTED_DATA;
516         if (count*st->frame_size > frame_size)
517                 return OPUS_BAD_ARG;
518         nb_samples=0;
519         for (i=0;i<count;i++)
520         {
521                 int ret;
522                 ret = opus_decode_frame(st, data, len, pcm, frame_size-nb_samples, decode_fec);
523                 if (ret<0)
524                         return ret;
525                 data += size[i];
526                 pcm += ret;
527                 nb_samples += ret;
528         }
529         return nb_samples;
530 }
531 int opus_decoder_ctl(OpusDecoder *st, int request, ...)
532 {
533     va_list ap;
534
535     va_start(ap, request);
536
537     switch (request)
538     {
539         case OPUS_GET_MODE_REQUEST:
540         {
541             int *value = va_arg(ap, int*);
542             *value = st->prev_mode;
543         }
544         break;
545         case OPUS_SET_BANDWIDTH_REQUEST:
546         {
547             int value = va_arg(ap, int);
548             st->bandwidth = value;
549         }
550         break;
551         case OPUS_GET_BANDWIDTH_REQUEST:
552         {
553             int *value = va_arg(ap, int*);
554             *value = st->bandwidth;
555         }
556         break;
557         default:
558             fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);
559             break;
560     }
561
562     va_end(ap);
563     return OPUS_OK;
564 }
565
566 void opus_decoder_destroy(OpusDecoder *st)
567 {
568         free(st);
569 }
570
571 #if OPUS_TEST_RANGE_CODER_STATE
572 int opus_decoder_get_final_range(OpusDecoder *st)
573 {
574     return st->rangeFinal;
575 }
576 #endif
577
578
579 int opus_packet_get_bandwidth(const unsigned char *data)
580 {
581         int bandwidth;
582     if (data[0]&0x80)
583     {
584         bandwidth = OPUS_BANDWIDTH_MEDIUMBAND + ((data[0]>>5)&0x3);
585         if (bandwidth == OPUS_BANDWIDTH_MEDIUMBAND)
586             bandwidth = OPUS_BANDWIDTH_NARROWBAND;
587     } else if ((data[0]&0x60) == 0x60)
588     {
589         bandwidth = (data[0]&0x10) ? OPUS_BANDWIDTH_FULLBAND : OPUS_BANDWIDTH_SUPERWIDEBAND;
590     } else {
591
592         bandwidth = OPUS_BANDWIDTH_NARROWBAND + ((data[0]>>5)&0x3);
593     }
594     return bandwidth;
595 }
596
597 int opus_packet_get_samples_per_frame(const unsigned char *data, int Fs)
598 {
599         int audiosize;
600     if (data[0]&0x80)
601     {
602         audiosize = ((data[0]>>3)&0x3);
603         audiosize = (Fs<<audiosize)/400;
604     } else if ((data[0]&0x60) == 0x60)
605     {
606         audiosize = (data[0]&0x08) ? Fs/50 : Fs/100;
607     } else {
608
609         audiosize = ((data[0]>>3)&0x3);
610         if (audiosize == 3)
611             audiosize = Fs*60/1000;
612         else
613             audiosize = (Fs<<audiosize)/100;
614     }
615     return audiosize;
616 }
617
618 int opus_packet_get_nb_channels(const unsigned char *data)
619 {
620     return (data[0]&0x4) ? 2 : 1;
621 }
622
623 int opus_packet_get_nb_frames(const unsigned char packet[], int len)
624 {
625         int count;
626         if (len<1)
627                 return OPUS_BAD_ARG;
628         count = packet[0]&0x3;
629         if (count==0)
630                 return 1;
631         else if (count!=3)
632                 return 2;
633         else if (len<2)
634                 return OPUS_CORRUPTED_DATA;
635         else
636                 return packet[1]&0x3F;
637 }
638
639 int opus_decoder_get_nb_samples(const OpusDecoder *dec, const unsigned char packet[], int len)
640 {
641         int samples;
642         int count = opus_packet_get_nb_frames(packet, len);
643         samples = count*opus_packet_get_samples_per_frame(packet, dec->Fs);
644         /* Can't have more than 120 ms */
645         if (samples*25 > dec->Fs*3)
646                 return OPUS_CORRUPTED_DATA;
647         else
648                 return samples;
649 }
650