Fixes an encoder bugg when requesting a CBR rate over the allowable limit
[opus.git] / src / opus_decoder.c
1 /* Copyright (c) 2010 Xiph.Org Foundation, Skype Limited
2    Written by Jean-Marc Valin and Koen Vos */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #ifdef HAVE_CONFIG_H
29 #include "config.h"
30 #endif
31
32 #include <string.h>
33 #include <stdlib.h>
34 #include <stdio.h>
35 #include <stdarg.h>
36 #include "celt.h"
37 #include "opus_decoder.h"
38 #include "entdec.h"
39 #include "modes.h"
40 #include "silk_API.h"
41
42 /* Make sure everything's aligned to 4 bytes (this may need to be increased
43    on really weird architectures) */
44 static inline int align(int i)
45 {
46         return (i+3)&-4;
47 }
48
49 int opus_decoder_get_size(int channels)
50 {
51         int silkDecSizeBytes, celtDecSizeBytes;
52         int ret;
53     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
54         if(ret)
55                 return 0;
56         silkDecSizeBytes = align(silkDecSizeBytes);
57     celtDecSizeBytes = celt_decoder_get_size(channels);
58     return align(sizeof(OpusDecoder))+silkDecSizeBytes+celtDecSizeBytes;
59
60 }
61
62 OpusDecoder *opus_decoder_init(OpusDecoder *st, int Fs, int channels)
63 {
64         void *silk_dec;
65         CELTDecoder *celt_dec;
66         int ret, silkDecSizeBytes;
67
68         if (channels<1 || channels > 2)
69             return NULL;
70         memset(st, 0, opus_decoder_get_size(channels));
71         /* Initialize SILK encoder */
72     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
73     if( ret ) {
74         return NULL;
75     }
76     silkDecSizeBytes = align(silkDecSizeBytes);
77     st->silk_dec_offset = align(sizeof(OpusDecoder));
78     st->celt_dec_offset = st->silk_dec_offset+silkDecSizeBytes;
79     silk_dec = (char*)st+st->silk_dec_offset;
80     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
81     st->stream_channels = st->channels = channels;
82
83     st->Fs = Fs;
84
85     /* Reset decoder */
86     ret = silk_InitDecoder( silk_dec );
87     if( ret ) {
88         goto failure;
89     }
90
91         /* Initialize CELT decoder */
92         celt_decoder_init(celt_dec, Fs, channels, &ret);
93         if (ret != CELT_OK)
94                 goto failure;
95     celt_decoder_ctl(celt_dec, CELT_SET_SIGNALLING(0));
96
97         st->prev_mode = 0;
98         return st;
99 failure:
100     free(st);
101     return NULL;
102 }
103
104 OpusDecoder *opus_decoder_create(int Fs, int channels)
105 {
106     char *raw_state = (char*)malloc(opus_decoder_get_size(channels));
107     if (raw_state == NULL)
108         return NULL;
109     return opus_decoder_init((OpusDecoder*)raw_state, Fs, channels);
110 }
111
112 static void smooth_fade(const opus_int16 *in1, const opus_int16 *in2, opus_int16 *out,
113         int overlap, int channels, const opus_val16 *window, int Fs)
114 {
115         int i, c;
116         int inc = 48000/Fs;
117         for (c=0;c<channels;c++)
118         {
119                 for (i=0;i<overlap;i++)
120                 {
121                     opus_val16 w = MULT16_16_Q15(window[i*inc], window[i*inc]);
122                     out[i*channels+c] = SHR32(MAC16_16(MULT16_16(w,in2[i*channels+c]),
123                             Q15ONE-w, in1[i*channels+c]), 15);
124                 }
125         }
126 }
127
128 static int opus_packet_get_mode(const unsigned char *data)
129 {
130         int mode;
131     if (data[0]&0x80)
132     {
133         mode = MODE_CELT_ONLY;
134     } else if ((data[0]&0x60) == 0x60)
135     {
136         mode = MODE_HYBRID;
137     } else {
138
139         mode = MODE_SILK_ONLY;
140     }
141     return mode;
142 }
143
144 static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
145                 int len, opus_int16 *pcm, int frame_size, int decode_fec)
146 {
147         void *silk_dec;
148         CELTDecoder *celt_dec;
149         int i, silk_ret=0, celt_ret=0;
150         ec_dec dec;
151     silk_DecControlStruct DecControl;
152     opus_int32 silk_frame_size;
153     opus_int16 pcm_celt[960*2];
154     opus_int16 pcm_transition[480*2];
155
156     int audiosize;
157     int mode;
158     int transition=0;
159     int start_band;
160     int redundancy=0;
161     int redundancy_bytes = 0;
162     int celt_to_silk=0;
163     opus_int16 redundant_audio[240*2];
164     int c;
165     int F2_5, F5, F10, F20;
166     const opus_val16 *window;
167
168     silk_dec = (char*)st+st->silk_dec_offset;
169     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
170     F20 = st->Fs/50;
171     F10 = F20>>1;
172     F5 = F10>>1;
173     F2_5 = F5>>1;
174     /* Payloads of 1 (2 including ToC) or 0 trigger the PLC/DTX */
175     if (len<=1)
176     {
177         data = NULL;
178         /* In that case, don't conceal more than what the ToC says */
179         frame_size = IMIN(frame_size, st->frame_size);
180     }
181     if (data != NULL)
182     {
183         audiosize = st->frame_size;
184         mode = st->mode;
185         ec_dec_init(&dec,(unsigned char*)data,len);
186     } else {
187         audiosize = frame_size;
188         if (st->prev_mode == 0)
189         {
190                 /* If we haven't got any packet yet, all we can do is return zeros */
191                 for (i=0;i<audiosize;i++)
192                         pcm[i] = 0;
193                 return audiosize;
194         } else {
195                 mode = st->prev_mode;
196         }
197     }
198
199     if (data!=NULL && !st->prev_redundancy && mode != st->prev_mode && st->prev_mode > 0
200                 && !(mode == MODE_SILK_ONLY && st->prev_mode == MODE_HYBRID)
201                 && !(mode == MODE_HYBRID && st->prev_mode == MODE_SILK_ONLY))
202     {
203         transition = 1;
204         if (mode == MODE_CELT_ONLY)
205             opus_decode_frame(st, NULL, 0, pcm_transition, IMIN(F10, audiosize), 0);
206     }
207     if (audiosize > frame_size)
208     {
209         fprintf(stderr, "PCM buffer too small: %d vs %d (mode = %d)\n", audiosize, frame_size, mode);
210         return OPUS_BAD_ARG;
211     } else {
212         frame_size = audiosize;
213     }
214
215     /* SILK processing */
216     if (mode != MODE_CELT_ONLY)
217     {
218         int lost_flag, decoded_samples;
219         opus_int16 *pcm_ptr = pcm;
220
221         if (st->prev_mode==MODE_CELT_ONLY)
222                 silk_InitDecoder( silk_dec );
223
224         DecControl.API_sampleRate = st->Fs;
225         DecControl.nChannelsAPI      = st->channels;
226         DecControl.nChannelsInternal = st->stream_channels;
227         DecControl.payloadSize_ms = 1000 * audiosize / st->Fs;
228         if( mode == MODE_SILK_ONLY ) {
229             if( st->bandwidth == OPUS_BANDWIDTH_NARROWBAND ) {
230                 DecControl.internalSampleRate = 8000;
231             } else if( st->bandwidth == OPUS_BANDWIDTH_MEDIUMBAND ) {
232                 DecControl.internalSampleRate = 12000;
233             } else if( st->bandwidth == OPUS_BANDWIDTH_WIDEBAND ) {
234                 DecControl.internalSampleRate = 16000;
235             } else {
236                 DecControl.internalSampleRate = 16000;
237                 SKP_assert( 0 );
238             }
239         } else {
240             /* Hybrid mode */
241             DecControl.internalSampleRate = 16000;
242         }
243
244         lost_flag = data == NULL ? 1 : 2 * decode_fec;
245         decoded_samples = 0;
246         do {
247             /* Call SILK decoder */
248             int first_frame = decoded_samples == 0;
249             silk_ret = silk_Decode( silk_dec, &DecControl,
250                 lost_flag, first_frame, &dec, pcm_ptr, &silk_frame_size );
251             if( silk_ret ) {
252                 if (lost_flag) {
253                         /* PLC failure should not be fatal */
254                         silk_frame_size = frame_size;
255                         for (i=0;i<frame_size*st->channels;i++)
256                                 pcm_ptr[i] = 0;
257                 } else
258                     return OPUS_CORRUPTED_DATA;
259             }
260             pcm_ptr += silk_frame_size * st->channels;
261             decoded_samples += silk_frame_size;
262         } while( decoded_samples < frame_size );
263     } else {
264         for (i=0;i<frame_size*st->channels;i++)
265             pcm[i] = 0;
266     }
267
268     start_band = 0;
269     if (mode != MODE_CELT_ONLY && data != NULL)
270     {
271         /* Check if we have a redundant 0-8 kHz band */
272         redundancy = ec_dec_bit_logp(&dec, 12);
273         if (redundancy)
274         {
275             celt_to_silk = ec_dec_bit_logp(&dec, 1);
276             if (mode == MODE_HYBRID)
277                 redundancy_bytes = 2 + ec_dec_uint(&dec, 256);
278             else {
279                 redundancy_bytes = len - ((ec_tell(&dec)+7)>>3);
280                 /* Can only happen on an invalid packet */
281                 if (redundancy_bytes<0)
282                 {
283                         redundancy_bytes = 0;
284                         redundancy = 0;
285                 }
286             }
287             len -= redundancy_bytes;
288             if (len<0)
289                 return OPUS_CORRUPTED_DATA;
290             /* Shrink decoder because of raw bits */
291             dec.storage -= redundancy_bytes;
292         }
293     }
294     if (mode != MODE_CELT_ONLY)
295         start_band = 17;
296
297     {
298         int endband=21;
299
300         switch(st->bandwidth)
301         {
302         case OPUS_BANDWIDTH_NARROWBAND:
303             endband = 13;
304             break;
305         case OPUS_BANDWIDTH_MEDIUMBAND:
306         case OPUS_BANDWIDTH_WIDEBAND:
307             endband = 17;
308             break;
309         case OPUS_BANDWIDTH_SUPERWIDEBAND:
310             endband = 19;
311             break;
312         case OPUS_BANDWIDTH_FULLBAND:
313             endband = 21;
314             break;
315         }
316         celt_decoder_ctl(celt_dec, CELT_SET_END_BAND(endband));
317         celt_decoder_ctl(celt_dec, CELT_SET_CHANNELS(st->stream_channels));
318     }
319
320     if (redundancy)
321         transition = 0;
322
323     if (transition && mode != MODE_CELT_ONLY)
324         opus_decode_frame(st, NULL, 0, pcm_transition, IMIN(F10, audiosize), 0);
325
326     /* 5 ms redundant frame for CELT->SILK*/
327     if (redundancy && celt_to_silk)
328     {
329         celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(0));
330         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
331         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
332     }
333
334     /* MUST be after PLC */
335     celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(start_band));
336
337     if (transition)
338         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
339
340     if (mode != MODE_SILK_ONLY)
341     {
342         int celt_frame_size = IMIN(F20, frame_size);
343         /* Decode CELT */
344         celt_ret = celt_decode_with_ec(celt_dec, decode_fec?NULL:data, len, pcm_celt, celt_frame_size, &dec);
345         for (i=0;i<celt_frame_size*st->channels;i++)
346             pcm[i] = SAT16(pcm[i] + (int)pcm_celt[i]);
347     }
348
349     {
350         const CELTMode *celt_mode;
351         celt_decoder_ctl(celt_dec, CELT_GET_MODE(&celt_mode));
352         window = celt_mode->window;
353     }
354
355     /* 5 ms redundant frame for SILK->CELT */
356     if (redundancy && !celt_to_silk)
357     {
358         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
359         celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(0));
360
361         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
362         smooth_fade(pcm+st->channels*(frame_size-F2_5), redundant_audio+st->channels*F2_5,
363                         pcm+st->channels*(frame_size-F2_5), F2_5, st->channels, window, st->Fs);
364     }
365     if (redundancy && celt_to_silk)
366     {
367         for (c=0;c<st->channels;c++)
368         {
369             for (i=0;i<F2_5;i++)
370                 pcm[st->channels*i+c] = redundant_audio[st->channels*i+c];
371         }
372         smooth_fade(redundant_audio+st->channels*F2_5, pcm+st->channels*F2_5,
373                 pcm+st->channels*F2_5, F2_5, st->channels, window, st->Fs);
374     }
375     if (transition)
376     {
377         for (i=0;i<st->channels*F2_5;i++)
378                 pcm[i] = pcm_transition[i];
379         if (audiosize >= F5)
380             smooth_fade(pcm_transition+st->channels*F2_5, pcm+st->channels*F2_5,
381                     pcm+st->channels*F2_5, F2_5,
382                     st->channels, window, st->Fs);
383     }
384
385     st->rangeFinal = dec.rng;
386
387     st->prev_mode = mode;
388     st->prev_redundancy = redundancy;
389         return celt_ret<0 ? celt_ret : audiosize;
390
391 }
392
393 static int parse_size(const unsigned char *data, int len, short *size)
394 {
395         if (len<1)
396         {
397                 *size = -1;
398                 return -1;
399         } else if (data[0]<252)
400         {
401                 *size = data[0];
402                 return 1;
403         } else if (len<2)
404         {
405                 *size = -1;
406                 return -1;
407         } else {
408                 *size = 4*data[1] + data[0];
409                 return 2;
410         }
411 }
412
413 int opus_packet_parse(const unsigned char *data, int len,
414       unsigned char *out_toc, const unsigned char *frames[48],
415       short size[48], const unsigned char **payload)
416 {
417    int i, bytes;
418    int count;
419    unsigned char ch, toc;
420    int framesize;
421
422    if (size==NULL)
423       return OPUS_BAD_ARG;
424
425    framesize = opus_packet_get_samples_per_frame(data, 48000);
426
427    toc = *data++;
428    len--;
429    switch (toc&0x3)
430    {
431    /* One frame */
432    case 0:
433       count=1;
434       size[0] = len;
435       break;
436       /* Two CBR frames */
437    case 1:
438       count=2;
439       if (len&0x1)
440          return OPUS_CORRUPTED_DATA;
441       size[0] = size[1] = len/2;
442       break;
443       /* Two VBR frames */
444    case 2:
445       count = 2;
446       bytes = parse_size(data, len, size);
447       len -= bytes;
448       if (size[0]<0 || size[0] > len)
449          return OPUS_CORRUPTED_DATA;
450       data += bytes;
451       size[1] = len-size[0];
452       break;
453       /* Multiple CBR/VBR frames (from 0 to 120 ms) */
454    case 3:
455       if (len<1)
456          return OPUS_CORRUPTED_DATA;
457       /* Number of frames encoded in bits 0 to 5 */
458       ch = *data++;
459       count = ch&0x3F;
460       if (count <= 0 || framesize*count > 5760)
461           return OPUS_CORRUPTED_DATA;
462       len--;
463       /* Padding flag is bit 6 */
464       if (ch&0x40)
465       {
466          int padding=0;
467          int p;
468          do {
469             if (len<=0)
470                return OPUS_CORRUPTED_DATA;
471             p = *data++;
472             len--;
473             padding += p==255 ? 254: p;
474          } while (p==255);
475          len -= padding;
476       }
477       if (len<0)
478          return OPUS_CORRUPTED_DATA;
479       /* VBR flag is bit 7 */
480       if (ch&0x80)
481       {
482          /* VBR case */
483          int last_size=len;
484          for (i=0;i<count-1;i++)
485          {
486             bytes = parse_size(data, len, size+i);
487             len -= bytes;
488             if (size[i]<0 || size[i] > len)
489                return OPUS_CORRUPTED_DATA;
490             data += bytes;
491             last_size -= bytes+size[i];
492          }
493          if (last_size<0)
494             return OPUS_CORRUPTED_DATA;
495          size[count-1]=last_size;
496       } else {
497          /* CBR case */
498          int sz = len/count;
499          if (sz*count!=len)
500             return OPUS_CORRUPTED_DATA;
501          for (i=0;i<count;i++)
502             size[i] = sz;
503       }
504       break;
505    }
506    /* Because it's not encoded explicitly, it's possible the size of the
507        last packet (or all the packets, for the CBR case) is larger than
508        1275.
509       Reject them here.*/
510    if (size[count-1] > 1275)
511       return OPUS_CORRUPTED_DATA;
512
513    if (frames)
514    {
515       for (i=0;i<count;i++)
516       {
517          frames[i] = data;
518          data += size[i];
519       }
520    }
521
522    if (out_toc)
523       *out_toc = toc;
524
525    if (payload)
526       *payload = data;
527
528    return count;
529 }
530
531
532 int opus_decode(OpusDecoder *st, const unsigned char *data,
533                 int len, opus_int16 *pcm, int frame_size, int decode_fec)
534 {
535         int i, nb_samples;
536         int count;
537         unsigned char toc;
538         /* 48 x 2.5 ms = 120 ms */
539         short size[48];
540         if (len==0 || data==NULL)
541             return opus_decode_frame(st, NULL, 0, pcm, frame_size, 0);
542         else if (len<0)
543                 return OPUS_BAD_ARG;
544         st->mode = opus_packet_get_mode(data);
545         st->bandwidth = opus_packet_get_bandwidth(data);
546         st->frame_size = opus_packet_get_samples_per_frame(data, st->Fs);
547         st->stream_channels = opus_packet_get_nb_channels(data);
548
549         count = opus_packet_parse(data, len, &toc, NULL, size, &data);
550         if (count < 0)
551            return count;
552
553         if (count*st->frame_size > frame_size)
554                 return OPUS_BAD_ARG;
555         nb_samples=0;
556         for (i=0;i<count;i++)
557         {
558                 int ret;
559                 ret = opus_decode_frame(st, data, size[i], pcm, frame_size-nb_samples, decode_fec);
560                 if (ret<0)
561                         return ret;
562                 data += size[i];
563                 pcm += ret*st->channels;
564                 nb_samples += ret;
565         }
566         return nb_samples;
567 }
568
569
570 int opus_decoder_ctl(OpusDecoder *st, int request, ...)
571 {
572     va_list ap;
573
574     va_start(ap, request);
575
576     switch (request)
577     {
578         case OPUS_GET_MODE_REQUEST:
579         {
580             int *value = va_arg(ap, int*);
581             *value = st->prev_mode;
582         }
583         break;
584         case OPUS_SET_BANDWIDTH_REQUEST:
585         {
586             int value = va_arg(ap, int);
587             st->bandwidth = value;
588         }
589         break;
590         case OPUS_GET_BANDWIDTH_REQUEST:
591         {
592             int *value = va_arg(ap, int*);
593             *value = st->bandwidth;
594         }
595         break;
596         default:
597             fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);
598             break;
599     }
600
601     va_end(ap);
602     return OPUS_OK;
603 }
604
605 void opus_decoder_destroy(OpusDecoder *st)
606 {
607         free(st);
608 }
609
610 int opus_decoder_get_final_range(OpusDecoder *st)
611 {
612     return st->rangeFinal;
613 }
614
615 int opus_packet_get_bandwidth(const unsigned char *data)
616 {
617         int bandwidth;
618     if (data[0]&0x80)
619     {
620         bandwidth = OPUS_BANDWIDTH_MEDIUMBAND + ((data[0]>>5)&0x3);
621         if (bandwidth == OPUS_BANDWIDTH_MEDIUMBAND)
622             bandwidth = OPUS_BANDWIDTH_NARROWBAND;
623     } else if ((data[0]&0x60) == 0x60)
624     {
625         bandwidth = (data[0]&0x10) ? OPUS_BANDWIDTH_FULLBAND : OPUS_BANDWIDTH_SUPERWIDEBAND;
626     } else {
627
628         bandwidth = OPUS_BANDWIDTH_NARROWBAND + ((data[0]>>5)&0x3);
629     }
630     return bandwidth;
631 }
632
633 int opus_packet_get_samples_per_frame(const unsigned char *data, int Fs)
634 {
635         int audiosize;
636     if (data[0]&0x80)
637     {
638         audiosize = ((data[0]>>3)&0x3);
639         audiosize = (Fs<<audiosize)/400;
640     } else if ((data[0]&0x60) == 0x60)
641     {
642         audiosize = (data[0]&0x08) ? Fs/50 : Fs/100;
643     } else {
644
645         audiosize = ((data[0]>>3)&0x3);
646         if (audiosize == 3)
647             audiosize = Fs*60/1000;
648         else
649             audiosize = (Fs<<audiosize)/100;
650     }
651     return audiosize;
652 }
653
654 int opus_packet_get_nb_channels(const unsigned char *data)
655 {
656     return (data[0]&0x4) ? 2 : 1;
657 }
658
659 int opus_packet_get_nb_frames(const unsigned char packet[], int len)
660 {
661         int count;
662         if (len<1)
663                 return OPUS_BAD_ARG;
664         count = packet[0]&0x3;
665         if (count==0)
666                 return 1;
667         else if (count!=3)
668                 return 2;
669         else if (len<2)
670                 return OPUS_CORRUPTED_DATA;
671         else
672                 return packet[1]&0x3F;
673 }
674
675 int opus_decoder_get_nb_samples(const OpusDecoder *dec, const unsigned char packet[], int len)
676 {
677         int samples;
678         int count = opus_packet_get_nb_frames(packet, len);
679         samples = count*opus_packet_get_samples_per_frame(packet, dec->Fs);
680         /* Can't have more than 120 ms */
681         if (samples*25 > dec->Fs*3)
682                 return OPUS_CORRUPTED_DATA;
683         else
684                 return samples;
685 }
686