Fixes an issue when triggering PLC before receiving any packet
[opus.git] / src / opus_decoder.c
1 /* Copyright (c) 2010 Xiph.Org Foundation, Skype Limited
2    Written by Jean-Marc Valin and Koen Vos */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #ifdef HAVE_CONFIG_H
29 #include "config.h"
30 #endif
31
32 #include <string.h>
33 #include <stdlib.h>
34 #include <stdio.h>
35 #include <stdarg.h>
36 #include "celt.h"
37 #include "opus_decoder.h"
38 #include "entdec.h"
39 #include "modes.h"
40 #include "silk_API.h"
41
42 #define MAX_PACKET (1275)
43
44 /* Make sure everything's aligned to 4 bytes (this may need to be increased
45    on really weird architectures) */
46 static inline int align(int i)
47 {
48         return (i+3)&-4;
49 }
50
51 int opus_decoder_get_size(int channels)
52 {
53         int silkDecSizeBytes, celtDecSizeBytes;
54         int ret;
55     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
56         if(ret)
57                 return 0;
58         silkDecSizeBytes = align(silkDecSizeBytes);
59     celtDecSizeBytes = celt_decoder_get_size(channels);
60     return align(sizeof(OpusDecoder))+silkDecSizeBytes+celtDecSizeBytes;
61
62 }
63
64 OpusDecoder *opus_decoder_init(OpusDecoder *st, int Fs, int channels)
65 {
66         void *silk_dec;
67         CELTDecoder *celt_dec;
68         int ret, silkDecSizeBytes;
69
70         memset(st, 0, opus_decoder_get_size(channels));
71         /* Initialize SILK encoder */
72     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
73     if( ret ) {
74         return NULL;
75     }
76     silkDecSizeBytes = align(silkDecSizeBytes);
77     st->silk_dec_offset = align(sizeof(OpusDecoder));
78     st->celt_dec_offset = st->silk_dec_offset+silkDecSizeBytes;
79     silk_dec = (char*)st+st->silk_dec_offset;
80     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
81     st->stream_channels = st->channels = channels;
82
83     st->Fs = Fs;
84
85     /* Reset decoder */
86     ret = silk_InitDecoder( silk_dec );
87     if( ret ) {
88         goto failure;
89     }
90
91         /* Initialize CELT decoder */
92         celt_decoder_init(celt_dec, Fs, channels, &ret);
93         if (ret != CELT_OK)
94                 goto failure;
95     celt_decoder_ctl(celt_dec, CELT_SET_SIGNALLING(0));
96
97         st->prev_mode = 0;
98         return st;
99 failure:
100     free(st);
101     return NULL;
102 }
103
104 OpusDecoder *opus_decoder_create(int Fs, int channels)
105 {
106     char *raw_state = (char*)malloc(opus_decoder_get_size(channels));
107     if (raw_state == NULL)
108         return NULL;
109     return opus_decoder_init((OpusDecoder*)raw_state, Fs, channels);
110 }
111
112 static void smooth_fade(const short *in1, const short *in2, short *out,
113         int overlap, int channels, const celt_word16 *window, int Fs)
114 {
115         int i, c;
116         int inc = 48000/Fs;
117         for (c=0;c<channels;c++)
118         {
119                 for (i=0;i<overlap;i++)
120                 {
121                     celt_word16 w = MULT16_16_Q15(window[i*inc], window[i*inc]);
122                     out[i*channels+c] = SHR32(MAC16_16(MULT16_16(w,in2[i*channels+c]),
123                             Q15ONE-w, in1[i*channels+c]), 15);
124                 }
125         }
126 }
127
128 static int opus_packet_get_mode(const unsigned char *data)
129 {
130         int mode;
131     if (data[0]&0x80)
132     {
133         mode = MODE_CELT_ONLY;
134     } else if ((data[0]&0x60) == 0x60)
135     {
136         mode = MODE_HYBRID;
137     } else {
138
139         mode = MODE_SILK_ONLY;
140     }
141     return mode;
142 }
143
144 static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
145                 int len, short *pcm, int frame_size, int decode_fec)
146 {
147         void *silk_dec;
148         CELTDecoder *celt_dec;
149         int i, silk_ret=0, celt_ret=0;
150         ec_dec dec;
151     silk_DecControlStruct DecControl;
152     SKP_int32 silk_frame_size;
153     short pcm_celt[960*2];
154     short pcm_transition[960*2];
155     int audiosize;
156     int mode;
157     int transition=0;
158     int start_band;
159     int redundancy=0;
160     int redundancy_bytes = 0;
161     int celt_to_silk=0;
162     short redundant_audio[240*2];
163     int c;
164     int F2_5, F5, F10;
165     const celt_word16 *window;
166
167     silk_dec = (char*)st+st->silk_dec_offset;
168     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
169     F10 = st->Fs/100;
170     F5 = F10>>1;
171     F2_5 = F5>>1;
172     /* Payloads of 1 (2 including ToC) or 0 trigger the PLC/DTX */
173     if (len<=1)
174         data = NULL;
175
176         audiosize = st->frame_size;
177     if (data != NULL)
178     {
179         mode = st->mode;
180         ec_dec_init(&dec,(unsigned char*)data,len);
181     } else {
182         if (st->prev_mode == 0)
183         {
184                 /* If we haven't got any packet yet, all we can do is return zeros */
185                 for (i=0;i<frame_size;i++)
186                         pcm[i] = 0;
187                 return audiosize;
188         } else {
189                 mode = st->prev_mode;
190         }
191     }
192
193     if (data!=NULL && !st->prev_redundancy && mode != st->prev_mode && st->prev_mode > 0
194                 && !(mode == MODE_SILK_ONLY && st->prev_mode == MODE_HYBRID)
195                 && !(mode == MODE_HYBRID && st->prev_mode == MODE_SILK_ONLY))
196     {
197         transition = 1;
198         if (mode == MODE_CELT_ONLY)
199             opus_decode_frame(st, NULL, 0, pcm_transition, IMAX(F10, audiosize), 0);
200     }
201     if (audiosize > frame_size)
202     {
203         fprintf(stderr, "PCM buffer too small: %d vs %d (mode = %d)\n", audiosize, frame_size, mode);
204         return OPUS_BAD_ARG;
205     } else {
206         frame_size = audiosize;
207     }
208
209     /* SILK processing */
210     if (mode != MODE_CELT_ONLY)
211     {
212         int lost_flag, decoded_samples;
213         SKP_int16 *pcm_ptr = pcm;
214
215         if (st->prev_mode==MODE_CELT_ONLY)
216                 silk_InitDecoder( silk_dec );
217
218         DecControl.API_sampleRate = st->Fs;
219         DecControl.nChannelsAPI      = st->channels;
220         DecControl.nChannelsInternal = st->stream_channels;
221         DecControl.payloadSize_ms = 1000 * audiosize / st->Fs;
222         if( mode == MODE_SILK_ONLY ) {
223             if( st->bandwidth == OPUS_BANDWIDTH_NARROWBAND ) {
224                 DecControl.internalSampleRate = 8000;
225             } else if( st->bandwidth == OPUS_BANDWIDTH_MEDIUMBAND ) {
226                 DecControl.internalSampleRate = 12000;
227             } else if( st->bandwidth == OPUS_BANDWIDTH_WIDEBAND ) {
228                 DecControl.internalSampleRate = 16000;
229             } else {
230                 DecControl.internalSampleRate = 16000;
231                 SKP_assert( 0 );
232             }
233         } else {
234             /* Hybrid mode */
235             DecControl.internalSampleRate = 16000;
236         }
237
238         lost_flag = data == NULL ? 1 : 2 * decode_fec;
239         decoded_samples = 0;
240         do {
241             /* Call SILK decoder */
242             int first_frame = decoded_samples == 0;
243             silk_ret = silk_Decode( silk_dec, &DecControl, 
244                 lost_flag, first_frame, &dec, pcm_ptr, &silk_frame_size );
245             if( silk_ret ) {
246                 fprintf (stderr, "SILK decode error\n");
247                 /* Handle error */
248             }
249             pcm_ptr += silk_frame_size * st->channels;
250             decoded_samples += silk_frame_size;
251         } while( decoded_samples < frame_size );
252     } else {
253         for (i=0;i<frame_size*st->channels;i++)
254             pcm[i] = 0;
255     }
256
257     start_band = 0;
258     if (mode != MODE_CELT_ONLY && data != NULL)
259     {
260         /* Check if we have a redundant 0-8 kHz band */
261         redundancy = ec_dec_bit_logp(&dec, 12);
262         if (redundancy)
263         {
264             celt_to_silk = ec_dec_bit_logp(&dec, 1);
265             if (mode == MODE_HYBRID)
266                 redundancy_bytes = 2 + ec_dec_uint(&dec, 256);
267             else
268                 redundancy_bytes = len - ((ec_tell(&dec)+7)>>3);
269             len -= redundancy_bytes;
270             if (len<0)
271                 return CELT_CORRUPTED_DATA;
272             /* Shrink decoder because of raw bits */
273             dec.storage -= redundancy_bytes;
274         }
275     }
276     if (mode != MODE_CELT_ONLY)
277         start_band = 17;
278
279     if (mode != MODE_SILK_ONLY)
280     {
281         int endband=21;
282
283         switch(st->bandwidth)
284         {
285         case OPUS_BANDWIDTH_NARROWBAND:
286             endband = 13;
287             break;
288         case OPUS_BANDWIDTH_WIDEBAND:
289             endband = 17;
290             break;
291         case OPUS_BANDWIDTH_SUPERWIDEBAND:
292             endband = 19;
293             break;
294         case OPUS_BANDWIDTH_FULLBAND:
295             endband = 21;
296             break;
297         }
298         celt_decoder_ctl(celt_dec, CELT_SET_END_BAND(endband));
299         celt_decoder_ctl(celt_dec, CELT_SET_CHANNELS(st->stream_channels));
300     }
301
302     if (redundancy)
303         transition = 0;
304
305     if (transition && mode != MODE_CELT_ONLY)
306         opus_decode_frame(st, NULL, 0, pcm_transition, IMAX(F10, audiosize), 0);
307
308     /* 5 ms redundant frame for CELT->SILK*/
309     if (redundancy && celt_to_silk)
310     {
311         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
312         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
313     }
314
315     /* MUST be after PLC */
316     celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(start_band));
317
318     if (transition)
319         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
320
321     if (mode != MODE_SILK_ONLY)
322     {
323         /* Decode CELT */
324         celt_ret = celt_decode_with_ec(celt_dec, decode_fec?NULL:data, len, pcm_celt, frame_size, &dec);
325         for (i=0;i<frame_size*st->channels;i++)
326             pcm[i] = SAT16(pcm[i] + (int)pcm_celt[i]);
327     }
328
329     {
330         const CELTMode *celt_mode;
331         celt_decoder_ctl(celt_dec, CELT_GET_MODE(&celt_mode));
332         window = celt_mode->window;
333     }
334
335     /* 5 ms redundant frame for SILK->CELT */
336     if (redundancy && !celt_to_silk)
337     {
338         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
339         celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(0));
340
341         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
342         smooth_fade(pcm+st->channels*(frame_size-F2_5), redundant_audio+st->channels*F2_5,
343                         pcm+st->channels*(frame_size-F2_5), F2_5, st->channels, window, st->Fs);
344     }
345     if (redundancy && celt_to_silk)
346     {
347         for (c=0;c<st->channels;c++)
348         {
349             for (i=0;i<F2_5;i++)
350                 pcm[st->channels*i+c] = redundant_audio[st->channels*i];
351         }
352         smooth_fade(redundant_audio+st->channels*F2_5, pcm+st->channels*F2_5,
353                 pcm+st->channels*F2_5, F2_5, st->channels, window, st->Fs);
354     }
355     if (transition)
356     {
357         for (i=0;i<F2_5;i++)
358                 pcm[i] = pcm_transition[i];
359         if (audiosize >= F5)
360             smooth_fade(pcm_transition+F2_5, pcm+F2_5, pcm+F2_5, F2_5,
361                     st->channels, window, st->Fs);
362     }
363 #if OPUS_TEST_RANGE_CODER_STATE
364     st->rangeFinal = dec.rng;
365 #endif
366
367     st->prev_mode = mode;
368     st->prev_redundancy = redundancy;
369         return celt_ret<0 ? celt_ret : audiosize;
370
371 }
372
373 static int parse_size(const unsigned char *data, int len, short *size)
374 {
375         if (len<1)
376         {
377                 *size = -1;
378                 return -1;
379         } else if (data[0]<252)
380         {
381                 *size = data[0];
382                 return 1;
383         } else if (len<2)
384         {
385                 *size = -1;
386                 return -1;
387         } else {
388                 *size = 4*data[1] + data[0];
389                 return 2;
390         }
391 }
392
393 int opus_decode(OpusDecoder *st, const unsigned char *data,
394                 int len, short *pcm, int frame_size, int decode_fec)
395 {
396         int i, bytes, nb_samples;
397         int count;
398         unsigned char ch, toc;
399         /* 48 x 2.5 ms = 120 ms */
400         short size[48];
401         if (len==0 || data==NULL)
402             return opus_decode_frame(st, NULL, 0, pcm, frame_size, 0);
403         else if (len<0)
404                 return CELT_BAD_ARG;
405         st->mode = opus_packet_get_mode(data);
406         st->bandwidth = opus_packet_get_bandwidth(data);
407         st->frame_size = opus_packet_get_samples_per_frame(data, st->Fs);
408         st->stream_channels = opus_packet_get_nb_channels(data);
409         toc = *data++;
410         len--;
411         switch (toc&0x3)
412         {
413         /* One frame */
414         case 0:
415                 count=1;
416                 size[0] = len;
417                 break;
418                 /* Two CBR frames */
419         case 1:
420                 count=2;
421                 if (len&0x1)
422                         return OPUS_CORRUPTED_DATA;
423                 size[0] = size[1] = len/2;
424                 break;
425                 /* Two VBR frames */
426         case 2:
427                 count = 2;
428                 bytes = parse_size(data, len, size);
429                 len -= bytes;
430                 if (size[0]<0 || size[0] > len)
431                         return OPUS_CORRUPTED_DATA;
432                 data += bytes;
433                 size[1] = len-size[0];
434                 break;
435                 /* Multiple CBR/VBR frames (from 0 to 120 ms) */
436         case 3:
437                 if (len<1)
438                         return OPUS_CORRUPTED_DATA;
439                 /* Number of frames encoded in bits 0 to 5 */
440                 ch = *data++;
441                 count = ch&0x3F;
442                 if (count <= 0 || st->frame_size*count*25 > 3*st->Fs)
443                     return OPUS_CORRUPTED_DATA;
444                 len--;
445                 /* Padding bit */
446                 if (ch&0x40)
447                 {
448                         int padding=0;
449                         int p;
450                         do {
451                                 if (len<=0)
452                                         return OPUS_CORRUPTED_DATA;
453                                 p = *data++;
454                                 len--;
455                                 padding += p==255 ? 254: p;
456                         } while (p==255);
457                         len -= padding;
458                 }
459                 if (len<0)
460                         return OPUS_CORRUPTED_DATA;
461                 /* Bit 7 is VBR flag (bit 6 is ignored) */
462                 if (ch&0x80)
463                 {
464                         /* VBR case */
465                         int last_size=len;
466                         for (i=0;i<count-1;i++)
467                         {
468                                 bytes = parse_size(data, len, size+i);
469                                 len -= bytes;
470                                 if (size[i]<0 || size[i] > len)
471                                         return OPUS_CORRUPTED_DATA;
472                                 data += bytes;
473                                 last_size -= bytes+size[i];
474                         }
475                         if (last_size<0)
476                                 return OPUS_CORRUPTED_DATA;
477                         size[count-1]=last_size;
478                 } else {
479                         /* CBR case */
480                         int sz = len/count;
481                         if (sz*count!=len)
482                                 return OPUS_CORRUPTED_DATA;
483                         for (i=0;i<count;i++)
484                                 size[i] = sz;
485                 }
486                 break;
487         }
488         /* Because it's not encoded explicitly, it's possible the size of the
489             last packet (or all the packets, for the CBR case) is larger than
490             1275.
491            Reject them here.*/
492         if (size[count-1] > MAX_PACKET)
493                 return OPUS_CORRUPTED_DATA;
494         if (count*st->frame_size > frame_size)
495                 return OPUS_BAD_ARG;
496         nb_samples=0;
497         for (i=0;i<count;i++)
498         {
499                 int ret;
500                 ret = opus_decode_frame(st, data, len, pcm, frame_size-nb_samples, decode_fec);
501                 if (ret<0)
502                         return ret;
503                 data += size[i];
504                 pcm += ret;
505                 nb_samples += ret;
506         }
507         return nb_samples;
508 }
509 int opus_decoder_ctl(OpusDecoder *st, int request, ...)
510 {
511     va_list ap;
512
513     va_start(ap, request);
514
515     switch (request)
516     {
517         case OPUS_GET_MODE_REQUEST:
518         {
519             int *value = va_arg(ap, int*);
520             *value = st->prev_mode;
521         }
522         break;
523         case OPUS_SET_BANDWIDTH_REQUEST:
524         {
525             int value = va_arg(ap, int);
526             st->bandwidth = value;
527         }
528         break;
529         case OPUS_GET_BANDWIDTH_REQUEST:
530         {
531             int *value = va_arg(ap, int*);
532             *value = st->bandwidth;
533         }
534         break;
535         default:
536             fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);
537             break;
538     }
539
540     va_end(ap);
541     return OPUS_OK;
542 }
543
544 void opus_decoder_destroy(OpusDecoder *st)
545 {
546         free(st);
547 }
548
549 #if OPUS_TEST_RANGE_CODER_STATE
550 int opus_decoder_get_final_range(OpusDecoder *st)
551 {
552     return st->rangeFinal;
553 }
554 #endif
555
556
557 int opus_packet_get_bandwidth(const unsigned char *data)
558 {
559         int bandwidth;
560     if (data[0]&0x80)
561     {
562         bandwidth = OPUS_BANDWIDTH_MEDIUMBAND + ((data[0]>>5)&0x3);
563         if (bandwidth == OPUS_BANDWIDTH_MEDIUMBAND)
564             bandwidth = OPUS_BANDWIDTH_NARROWBAND;
565     } else if ((data[0]&0x60) == 0x60)
566     {
567         bandwidth = (data[0]&0x10) ? OPUS_BANDWIDTH_FULLBAND : OPUS_BANDWIDTH_SUPERWIDEBAND;
568     } else {
569
570         bandwidth = OPUS_BANDWIDTH_NARROWBAND + ((data[0]>>5)&0x3);
571     }
572     return bandwidth;
573 }
574
575 int opus_packet_get_samples_per_frame(const unsigned char *data, int Fs)
576 {
577         int audiosize;
578     if (data[0]&0x80)
579     {
580         audiosize = ((data[0]>>3)&0x3);
581         audiosize = (Fs<<audiosize)/400;
582     } else if ((data[0]&0x60) == 0x60)
583     {
584         audiosize = (data[0]&0x08) ? Fs/50 : Fs/100;
585     } else {
586
587         audiosize = ((data[0]>>3)&0x3);
588         if (audiosize == 3)
589             audiosize = Fs*60/1000;
590         else
591             audiosize = (Fs<<audiosize)/100;
592     }
593     return audiosize;
594 }
595
596 int opus_packet_get_nb_channels(const unsigned char *data)
597 {
598     return (data[0]&0x4) ? 2 : 1;
599 }
600
601 int opus_packet_get_nb_frames(const unsigned char packet[], int len)
602 {
603         int count;
604         if (len<1)
605                 return OPUS_BAD_ARG;
606         count = packet[0]&0x3;
607         if (count==0)
608                 return 1;
609         else if (count!=3)
610                 return 2;
611         else if (len<2)
612                 return OPUS_CORRUPTED_DATA;
613         else
614                 return packet[1]&0x3F;
615 }
616
617 int opus_decoder_get_nb_samples(const OpusDecoder *dec, const unsigned char packet[], int len)
618 {
619         int samples;
620         int count = opus_packet_get_nb_frames(packet, len);
621         samples = count*opus_packet_get_samples_per_frame(packet, dec->Fs);
622         /* Can't have more than 120 ms */
623         if (samples*25 > dec->Fs*3)
624                 return OPUS_CORRUPTED_DATA;
625         else
626                 return samples;
627 }
628