Fixes a crash in silk prefill (used for mode switching)
[opus.git] / src / opus_decoder.c
1 /* Copyright (c) 2010 Xiph.Org Foundation, Skype Limited
2    Written by Jean-Marc Valin and Koen Vos */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #ifdef HAVE_CONFIG_H
29 #include "config.h"
30 #endif
31
32 #include <string.h>
33 #include <stdlib.h>
34 #include <stdio.h>
35 #include <stdarg.h>
36 #include "celt.h"
37 #include "opus_decoder.h"
38 #include "entdec.h"
39 #include "modes.h"
40 #include "silk_API.h"
41
42 #define MAX_PACKET (1275)
43
44 /* Make sure everything's aligned to 4 bytes (this may need to be increased
45    on really weird architectures) */
46 static inline int align(int i)
47 {
48         return (i+3)&-4;
49 }
50
51 int opus_decoder_get_size(int channels)
52 {
53         int silkDecSizeBytes, celtDecSizeBytes;
54         int ret;
55     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
56         if(ret)
57                 return 0;
58         silkDecSizeBytes = align(silkDecSizeBytes);
59     celtDecSizeBytes = celt_decoder_get_size(channels);
60     return align(sizeof(OpusDecoder))+silkDecSizeBytes+celtDecSizeBytes;
61
62 }
63
64 OpusDecoder *opus_decoder_init(OpusDecoder *st, int Fs, int channels)
65 {
66         void *silk_dec;
67         CELTDecoder *celt_dec;
68         int ret, silkDecSizeBytes;
69
70         memset(st, 0, opus_decoder_get_size(channels));
71         /* Initialize SILK encoder */
72     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
73     if( ret ) {
74         return NULL;
75     }
76     silkDecSizeBytes = align(silkDecSizeBytes);
77     st->silk_dec_offset = align(sizeof(OpusDecoder));
78     st->celt_dec_offset = st->silk_dec_offset+silkDecSizeBytes;
79     silk_dec = (char*)st+st->silk_dec_offset;
80     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
81     st->stream_channels = st->channels = channels;
82
83     st->Fs = Fs;
84
85     /* Reset decoder */
86     ret = silk_InitDecoder( silk_dec );
87     if( ret ) {
88         goto failure;
89     }
90
91         /* Initialize CELT decoder */
92         celt_decoder_init(celt_dec, Fs, channels, &ret);
93         if (ret != CELT_OK)
94                 goto failure;
95     celt_decoder_ctl(celt_dec, CELT_SET_SIGNALLING(0));
96
97         st->prev_mode = 0;
98         return st;
99 failure:
100     free(st);
101     return NULL;
102 }
103
104 OpusDecoder *opus_decoder_create(int Fs, int channels)
105 {
106     char *raw_state = (char*)malloc(opus_decoder_get_size(channels));
107     if (raw_state == NULL)
108         return NULL;
109     return opus_decoder_init((OpusDecoder*)raw_state, Fs, channels);
110 }
111
112 static void smooth_fade(const short *in1, const short *in2, short *out,
113         int overlap, int channels, const celt_word16 *window, int Fs)
114 {
115         int i, c;
116         int inc = 48000/Fs;
117         for (c=0;c<channels;c++)
118         {
119                 for (i=0;i<overlap;i++)
120                 {
121                     celt_word16 w = MULT16_16_Q15(window[i*inc], window[i*inc]);
122                     out[i*channels+c] = SHR32(MAC16_16(MULT16_16(w,in2[i*channels+c]),
123                             Q15ONE-w, in1[i*channels+c]), 15);
124                 }
125         }
126 }
127
128 static int opus_packet_get_mode(const unsigned char *data)
129 {
130         int mode;
131     if (data[0]&0x80)
132     {
133         mode = MODE_CELT_ONLY;
134     } else if ((data[0]&0x60) == 0x60)
135     {
136         mode = MODE_HYBRID;
137     } else {
138
139         mode = MODE_SILK_ONLY;
140     }
141     return mode;
142 }
143
144 static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
145                 int len, short *pcm, int frame_size, int decode_fec)
146 {
147         void *silk_dec;
148         CELTDecoder *celt_dec;
149         int i, silk_ret=0, celt_ret=0;
150         ec_dec dec;
151     silk_DecControlStruct DecControl;
152     SKP_int32 silk_frame_size;
153     short pcm_celt[960*2];
154     short pcm_transition[960*2];
155     int audiosize;
156     int mode;
157     int transition=0;
158     int start_band;
159     int redundancy=0;
160     int redundancy_bytes = 0;
161     int celt_to_silk=0;
162     short redundant_audio[240*2];
163     int c;
164     int F2_5, F5, F10;
165     const celt_word16 *window;
166
167     silk_dec = (char*)st+st->silk_dec_offset;
168     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
169     F10 = st->Fs/100;
170     F5 = F10>>1;
171     F2_5 = F5>>1;
172     /* Payloads of 1 (2 including ToC) or 0 trigger the PLC/DTX */
173     if (len<=1)
174         data = NULL;
175
176         audiosize = st->frame_size;
177     if (data != NULL)
178     {
179         mode = st->mode;
180         ec_dec_init(&dec,(unsigned char*)data,len);
181     } else {
182         if (st->prev_mode == 0)
183         {
184                 /* If we haven't got any packet yet, all we can do is return zeros */
185                 for (i=0;i<frame_size;i++)
186                         pcm[i] = 0;
187                 return audiosize;
188         } else {
189                 mode = st->prev_mode;
190         }
191     }
192
193     if (data!=NULL && !st->prev_redundancy && mode != st->prev_mode && st->prev_mode > 0
194                 && !(mode == MODE_SILK_ONLY && st->prev_mode == MODE_HYBRID)
195                 && !(mode == MODE_HYBRID && st->prev_mode == MODE_SILK_ONLY))
196     {
197         transition = 1;
198         if (mode == MODE_CELT_ONLY)
199             opus_decode_frame(st, NULL, 0, pcm_transition, IMAX(F10, audiosize), 0);
200     }
201     if (audiosize > frame_size)
202     {
203         fprintf(stderr, "PCM buffer too small: %d vs %d (mode = %d)\n", audiosize, frame_size, mode);
204         return OPUS_BAD_ARG;
205     } else {
206         frame_size = audiosize;
207     }
208
209     /* SILK processing */
210     if (mode != MODE_CELT_ONLY)
211     {
212         int lost_flag, decoded_samples;
213         SKP_int16 *pcm_ptr = pcm;
214
215         if (st->prev_mode==MODE_CELT_ONLY)
216                 silk_InitDecoder( silk_dec );
217
218         DecControl.API_sampleRate = st->Fs;
219         DecControl.nChannelsAPI      = st->channels;
220         DecControl.nChannelsInternal = st->stream_channels;
221         DecControl.payloadSize_ms = 1000 * audiosize / st->Fs;
222         if( mode == MODE_SILK_ONLY ) {
223             if( st->bandwidth == OPUS_BANDWIDTH_NARROWBAND ) {
224                 DecControl.internalSampleRate = 8000;
225             } else if( st->bandwidth == OPUS_BANDWIDTH_MEDIUMBAND ) {
226                 DecControl.internalSampleRate = 12000;
227             } else if( st->bandwidth == OPUS_BANDWIDTH_WIDEBAND ) {
228                 DecControl.internalSampleRate = 16000;
229             } else {
230                 DecControl.internalSampleRate = 16000;
231                 SKP_assert( 0 );
232             }
233         } else {
234             /* Hybrid mode */
235             DecControl.internalSampleRate = 16000;
236         }
237
238         lost_flag = data == NULL ? 1 : 2 * decode_fec;
239         decoded_samples = 0;
240         do {
241             /* Call SILK decoder */
242             int first_frame = decoded_samples == 0;
243             silk_ret = silk_Decode( silk_dec, &DecControl, 
244                 lost_flag, first_frame, &dec, pcm_ptr, &silk_frame_size );
245             if( silk_ret ) {
246                 fprintf (stderr, "SILK decode error\n");
247                 /* Handle error */
248             }
249             pcm_ptr += silk_frame_size * st->channels;
250             decoded_samples += silk_frame_size;
251         } while( decoded_samples < frame_size );
252     } else {
253         for (i=0;i<frame_size*st->channels;i++)
254             pcm[i] = 0;
255     }
256
257     start_band = 0;
258     if (mode != MODE_CELT_ONLY && data != NULL)
259     {
260         /* Check if we have a redundant 0-8 kHz band */
261         redundancy = ec_dec_bit_logp(&dec, 12);
262         if (redundancy)
263         {
264             celt_to_silk = ec_dec_bit_logp(&dec, 1);
265             if (mode == MODE_HYBRID)
266                 redundancy_bytes = 2 + ec_dec_uint(&dec, 256);
267             else
268                 redundancy_bytes = len - ((ec_tell(&dec)+7)>>3);
269             len -= redundancy_bytes;
270             if (len<0)
271                 return CELT_CORRUPTED_DATA;
272             /* Shrink decoder because of raw bits */
273             dec.storage -= redundancy_bytes;
274         }
275     }
276     if (mode != MODE_CELT_ONLY)
277         start_band = 17;
278
279     if (mode != MODE_SILK_ONLY)
280     {
281         int endband=21;
282
283         switch(st->bandwidth)
284         {
285         case OPUS_BANDWIDTH_NARROWBAND:
286             endband = 13;
287             break;
288         case OPUS_BANDWIDTH_WIDEBAND:
289             endband = 17;
290             break;
291         case OPUS_BANDWIDTH_SUPERWIDEBAND:
292             endband = 19;
293             break;
294         case OPUS_BANDWIDTH_FULLBAND:
295             endband = 21;
296             break;
297         }
298         celt_decoder_ctl(celt_dec, CELT_SET_END_BAND(endband));
299         celt_decoder_ctl(celt_dec, CELT_SET_CHANNELS(st->stream_channels));
300     }
301
302     if (redundancy)
303         transition = 0;
304
305     if (transition && mode != MODE_CELT_ONLY)
306         opus_decode_frame(st, NULL, 0, pcm_transition, IMAX(F10, audiosize), 0);
307
308     /* 5 ms redundant frame for CELT->SILK*/
309     if (redundancy && celt_to_silk)
310     {
311         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
312         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
313     }
314
315     /* MUST be after PLC */
316     celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(start_band));
317
318     if (transition)
319         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
320
321     if (mode != MODE_SILK_ONLY)
322     {
323         /* Decode CELT */
324         celt_ret = celt_decode_with_ec(celt_dec, decode_fec?NULL:data, len, pcm_celt, frame_size, &dec);
325         for (i=0;i<frame_size*st->channels;i++)
326             pcm[i] = SAT16(pcm[i] + (int)pcm_celt[i]);
327     }
328
329     {
330         const CELTMode *celt_mode;
331         celt_decoder_ctl(celt_dec, CELT_GET_MODE(&celt_mode));
332         window = celt_mode->window;
333     }
334
335     /* 5 ms redundant frame for SILK->CELT */
336     if (redundancy && !celt_to_silk)
337     {
338         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
339         celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(0));
340
341         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
342         smooth_fade(pcm+st->channels*(frame_size-F2_5), redundant_audio+st->channels*F2_5,
343                         pcm+st->channels*(frame_size-F2_5), F2_5, st->channels, window, st->Fs);
344     }
345     if (redundancy && celt_to_silk)
346     {
347         for (c=0;c<st->channels;c++)
348         {
349             for (i=0;i<F2_5;i++)
350                 pcm[st->channels*i+c] = redundant_audio[st->channels*i+c];
351         }
352         smooth_fade(redundant_audio+st->channels*F2_5, pcm+st->channels*F2_5,
353                 pcm+st->channels*F2_5, F2_5, st->channels, window, st->Fs);
354     }
355     if (transition)
356     {
357         for (i=0;i<st->channels*F2_5;i++)
358                 pcm[i] = pcm_transition[i];
359         if (audiosize >= F5)
360             smooth_fade(pcm_transition+st->channels*F2_5, pcm+st->channels*F2_5,
361                     pcm+st->channels*F2_5, F2_5,
362                     st->channels, window, st->Fs);
363     }
364 #if OPUS_TEST_RANGE_CODER_STATE
365     st->rangeFinal = dec.rng;
366 #endif
367
368     st->prev_mode = mode;
369     st->prev_redundancy = redundancy;
370         return celt_ret<0 ? celt_ret : audiosize;
371
372 }
373
374 static int parse_size(const unsigned char *data, int len, short *size)
375 {
376         if (len<1)
377         {
378                 *size = -1;
379                 return -1;
380         } else if (data[0]<252)
381         {
382                 *size = data[0];
383                 return 1;
384         } else if (len<2)
385         {
386                 *size = -1;
387                 return -1;
388         } else {
389                 *size = 4*data[1] + data[0];
390                 return 2;
391         }
392 }
393
394 int opus_decode(OpusDecoder *st, const unsigned char *data,
395                 int len, short *pcm, int frame_size, int decode_fec)
396 {
397         int i, bytes, nb_samples;
398         int count;
399         unsigned char ch, toc;
400         /* 48 x 2.5 ms = 120 ms */
401         short size[48];
402         if (len==0 || data==NULL)
403             return opus_decode_frame(st, NULL, 0, pcm, frame_size, 0);
404         else if (len<0)
405                 return CELT_BAD_ARG;
406         st->mode = opus_packet_get_mode(data);
407         st->bandwidth = opus_packet_get_bandwidth(data);
408         st->frame_size = opus_packet_get_samples_per_frame(data, st->Fs);
409         st->stream_channels = opus_packet_get_nb_channels(data);
410         toc = *data++;
411         len--;
412         switch (toc&0x3)
413         {
414         /* One frame */
415         case 0:
416                 count=1;
417                 size[0] = len;
418                 break;
419                 /* Two CBR frames */
420         case 1:
421                 count=2;
422                 if (len&0x1)
423                         return OPUS_CORRUPTED_DATA;
424                 size[0] = size[1] = len/2;
425                 break;
426                 /* Two VBR frames */
427         case 2:
428                 count = 2;
429                 bytes = parse_size(data, len, size);
430                 len -= bytes;
431                 if (size[0]<0 || size[0] > len)
432                         return OPUS_CORRUPTED_DATA;
433                 data += bytes;
434                 size[1] = len-size[0];
435                 break;
436                 /* Multiple CBR/VBR frames (from 0 to 120 ms) */
437         case 3:
438                 if (len<1)
439                         return OPUS_CORRUPTED_DATA;
440                 /* Number of frames encoded in bits 0 to 5 */
441                 ch = *data++;
442                 count = ch&0x3F;
443                 if (count <= 0 || st->frame_size*count*25 > 3*st->Fs)
444                     return OPUS_CORRUPTED_DATA;
445                 len--;
446                 /* Padding bit */
447                 if (ch&0x40)
448                 {
449                         int padding=0;
450                         int p;
451                         do {
452                                 if (len<=0)
453                                         return OPUS_CORRUPTED_DATA;
454                                 p = *data++;
455                                 len--;
456                                 padding += p==255 ? 254: p;
457                         } while (p==255);
458                         len -= padding;
459                 }
460                 if (len<0)
461                         return OPUS_CORRUPTED_DATA;
462                 /* Bit 7 is VBR flag (bit 6 is ignored) */
463                 if (ch&0x80)
464                 {
465                         /* VBR case */
466                         int last_size=len;
467                         for (i=0;i<count-1;i++)
468                         {
469                                 bytes = parse_size(data, len, size+i);
470                                 len -= bytes;
471                                 if (size[i]<0 || size[i] > len)
472                                         return OPUS_CORRUPTED_DATA;
473                                 data += bytes;
474                                 last_size -= bytes+size[i];
475                         }
476                         if (last_size<0)
477                                 return OPUS_CORRUPTED_DATA;
478                         size[count-1]=last_size;
479                 } else {
480                         /* CBR case */
481                         int sz = len/count;
482                         if (sz*count!=len)
483                                 return OPUS_CORRUPTED_DATA;
484                         for (i=0;i<count;i++)
485                                 size[i] = sz;
486                 }
487                 break;
488         }
489         /* Because it's not encoded explicitly, it's possible the size of the
490             last packet (or all the packets, for the CBR case) is larger than
491             1275.
492            Reject them here.*/
493         if (size[count-1] > MAX_PACKET)
494                 return OPUS_CORRUPTED_DATA;
495         if (count*st->frame_size > frame_size)
496                 return OPUS_BAD_ARG;
497         nb_samples=0;
498         for (i=0;i<count;i++)
499         {
500                 int ret;
501                 ret = opus_decode_frame(st, data, len, pcm, frame_size-nb_samples, decode_fec);
502                 if (ret<0)
503                         return ret;
504                 data += size[i];
505                 pcm += ret;
506                 nb_samples += ret;
507         }
508         return nb_samples;
509 }
510 int opus_decoder_ctl(OpusDecoder *st, int request, ...)
511 {
512     va_list ap;
513
514     va_start(ap, request);
515
516     switch (request)
517     {
518         case OPUS_GET_MODE_REQUEST:
519         {
520             int *value = va_arg(ap, int*);
521             *value = st->prev_mode;
522         }
523         break;
524         case OPUS_SET_BANDWIDTH_REQUEST:
525         {
526             int value = va_arg(ap, int);
527             st->bandwidth = value;
528         }
529         break;
530         case OPUS_GET_BANDWIDTH_REQUEST:
531         {
532             int *value = va_arg(ap, int*);
533             *value = st->bandwidth;
534         }
535         break;
536         default:
537             fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);
538             break;
539     }
540
541     va_end(ap);
542     return OPUS_OK;
543 }
544
545 void opus_decoder_destroy(OpusDecoder *st)
546 {
547         free(st);
548 }
549
550 #if OPUS_TEST_RANGE_CODER_STATE
551 int opus_decoder_get_final_range(OpusDecoder *st)
552 {
553     return st->rangeFinal;
554 }
555 #endif
556
557
558 int opus_packet_get_bandwidth(const unsigned char *data)
559 {
560         int bandwidth;
561     if (data[0]&0x80)
562     {
563         bandwidth = OPUS_BANDWIDTH_MEDIUMBAND + ((data[0]>>5)&0x3);
564         if (bandwidth == OPUS_BANDWIDTH_MEDIUMBAND)
565             bandwidth = OPUS_BANDWIDTH_NARROWBAND;
566     } else if ((data[0]&0x60) == 0x60)
567     {
568         bandwidth = (data[0]&0x10) ? OPUS_BANDWIDTH_FULLBAND : OPUS_BANDWIDTH_SUPERWIDEBAND;
569     } else {
570
571         bandwidth = OPUS_BANDWIDTH_NARROWBAND + ((data[0]>>5)&0x3);
572     }
573     return bandwidth;
574 }
575
576 int opus_packet_get_samples_per_frame(const unsigned char *data, int Fs)
577 {
578         int audiosize;
579     if (data[0]&0x80)
580     {
581         audiosize = ((data[0]>>3)&0x3);
582         audiosize = (Fs<<audiosize)/400;
583     } else if ((data[0]&0x60) == 0x60)
584     {
585         audiosize = (data[0]&0x08) ? Fs/50 : Fs/100;
586     } else {
587
588         audiosize = ((data[0]>>3)&0x3);
589         if (audiosize == 3)
590             audiosize = Fs*60/1000;
591         else
592             audiosize = (Fs<<audiosize)/100;
593     }
594     return audiosize;
595 }
596
597 int opus_packet_get_nb_channels(const unsigned char *data)
598 {
599     return (data[0]&0x4) ? 2 : 1;
600 }
601
602 int opus_packet_get_nb_frames(const unsigned char packet[], int len)
603 {
604         int count;
605         if (len<1)
606                 return OPUS_BAD_ARG;
607         count = packet[0]&0x3;
608         if (count==0)
609                 return 1;
610         else if (count!=3)
611                 return 2;
612         else if (len<2)
613                 return OPUS_CORRUPTED_DATA;
614         else
615                 return packet[1]&0x3F;
616 }
617
618 int opus_decoder_get_nb_samples(const OpusDecoder *dec, const unsigned char packet[], int len)
619 {
620         int samples;
621         int count = opus_packet_get_nb_frames(packet, len);
622         samples = count*opus_packet_get_samples_per_frame(packet, dec->Fs);
623         /* Can't have more than 120 ms */
624         if (samples*25 > dec->Fs*3)
625                 return OPUS_CORRUPTED_DATA;
626         else
627                 return samples;
628 }
629