Renamed celt_word* to opus_val*
[opus.git] / src / opus_decoder.c
1 /* Copyright (c) 2010 Xiph.Org Foundation, Skype Limited
2    Written by Jean-Marc Valin and Koen Vos */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #ifdef HAVE_CONFIG_H
29 #include "config.h"
30 #endif
31
32 #include <string.h>
33 #include <stdlib.h>
34 #include <stdio.h>
35 #include <stdarg.h>
36 #include "celt.h"
37 #include "opus_decoder.h"
38 #include "entdec.h"
39 #include "modes.h"
40 #include "silk_API.h"
41
42 #define MAX_PACKET (1275)
43
44 /* Make sure everything's aligned to 4 bytes (this may need to be increased
45    on really weird architectures) */
46 static inline int align(int i)
47 {
48         return (i+3)&-4;
49 }
50
51 int opus_decoder_get_size(int channels)
52 {
53         int silkDecSizeBytes, celtDecSizeBytes;
54         int ret;
55     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
56         if(ret)
57                 return 0;
58         silkDecSizeBytes = align(silkDecSizeBytes);
59     celtDecSizeBytes = celt_decoder_get_size(channels);
60     return align(sizeof(OpusDecoder))+silkDecSizeBytes+celtDecSizeBytes;
61
62 }
63
64 OpusDecoder *opus_decoder_init(OpusDecoder *st, int Fs, int channels)
65 {
66         void *silk_dec;
67         CELTDecoder *celt_dec;
68         int ret, silkDecSizeBytes;
69
70         if (channels<1 || channels > 2)
71             return NULL;
72         memset(st, 0, opus_decoder_get_size(channels));
73         /* Initialize SILK encoder */
74     ret = silk_Get_Decoder_Size( &silkDecSizeBytes );
75     if( ret ) {
76         return NULL;
77     }
78     silkDecSizeBytes = align(silkDecSizeBytes);
79     st->silk_dec_offset = align(sizeof(OpusDecoder));
80     st->celt_dec_offset = st->silk_dec_offset+silkDecSizeBytes;
81     silk_dec = (char*)st+st->silk_dec_offset;
82     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
83     st->stream_channels = st->channels = channels;
84
85     st->Fs = Fs;
86
87     /* Reset decoder */
88     ret = silk_InitDecoder( silk_dec );
89     if( ret ) {
90         goto failure;
91     }
92
93         /* Initialize CELT decoder */
94         celt_decoder_init(celt_dec, Fs, channels, &ret);
95         if (ret != CELT_OK)
96                 goto failure;
97     celt_decoder_ctl(celt_dec, CELT_SET_SIGNALLING(0));
98
99         st->prev_mode = 0;
100         return st;
101 failure:
102     free(st);
103     return NULL;
104 }
105
106 OpusDecoder *opus_decoder_create(int Fs, int channels)
107 {
108     char *raw_state = (char*)malloc(opus_decoder_get_size(channels));
109     if (raw_state == NULL)
110         return NULL;
111     return opus_decoder_init((OpusDecoder*)raw_state, Fs, channels);
112 }
113
114 static void smooth_fade(const short *in1, const short *in2, short *out,
115         int overlap, int channels, const opus_val16 *window, int Fs)
116 {
117         int i, c;
118         int inc = 48000/Fs;
119         for (c=0;c<channels;c++)
120         {
121                 for (i=0;i<overlap;i++)
122                 {
123                     opus_val16 w = MULT16_16_Q15(window[i*inc], window[i*inc]);
124                     out[i*channels+c] = SHR32(MAC16_16(MULT16_16(w,in2[i*channels+c]),
125                             Q15ONE-w, in1[i*channels+c]), 15);
126                 }
127         }
128 }
129
130 static int opus_packet_get_mode(const unsigned char *data)
131 {
132         int mode;
133     if (data[0]&0x80)
134     {
135         mode = MODE_CELT_ONLY;
136     } else if ((data[0]&0x60) == 0x60)
137     {
138         mode = MODE_HYBRID;
139     } else {
140
141         mode = MODE_SILK_ONLY;
142     }
143     return mode;
144 }
145
146 static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
147                 int len, short *pcm, int frame_size, int decode_fec)
148 {
149         void *silk_dec;
150         CELTDecoder *celt_dec;
151         int i, silk_ret=0, celt_ret=0;
152         ec_dec dec;
153     silk_DecControlStruct DecControl;
154     opus_int32 silk_frame_size;
155     short pcm_celt[960*2];
156     short pcm_transition[480*2];
157
158     int audiosize;
159     int mode;
160     int transition=0;
161     int start_band;
162     int redundancy=0;
163     int redundancy_bytes = 0;
164     int celt_to_silk=0;
165     short redundant_audio[240*2];
166     int c;
167     int F2_5, F5, F10, F20;
168     const opus_val16 *window;
169
170     silk_dec = (char*)st+st->silk_dec_offset;
171     celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
172     F20 = st->Fs/50;
173     F10 = F20>>1;
174     F5 = F10>>1;
175     F2_5 = F5>>1;
176     /* Payloads of 1 (2 including ToC) or 0 trigger the PLC/DTX */
177     if (len<=1)
178     {
179         data = NULL;
180         /* In that case, don't conceal more than what the ToC says */
181         frame_size = IMIN(frame_size, st->frame_size);
182     }
183     if (data != NULL)
184     {
185         audiosize = st->frame_size;
186         mode = st->mode;
187         ec_dec_init(&dec,(unsigned char*)data,len);
188     } else {
189         audiosize = frame_size;
190         if (st->prev_mode == 0)
191         {
192                 /* If we haven't got any packet yet, all we can do is return zeros */
193                 for (i=0;i<audiosize;i++)
194                         pcm[i] = 0;
195                 return audiosize;
196         } else {
197                 mode = st->prev_mode;
198         }
199     }
200
201     if (data!=NULL && !st->prev_redundancy && mode != st->prev_mode && st->prev_mode > 0
202                 && !(mode == MODE_SILK_ONLY && st->prev_mode == MODE_HYBRID)
203                 && !(mode == MODE_HYBRID && st->prev_mode == MODE_SILK_ONLY))
204     {
205         transition = 1;
206         if (mode == MODE_CELT_ONLY)
207             opus_decode_frame(st, NULL, 0, pcm_transition, IMIN(F10, audiosize), 0);
208     }
209     if (audiosize > frame_size)
210     {
211         fprintf(stderr, "PCM buffer too small: %d vs %d (mode = %d)\n", audiosize, frame_size, mode);
212         return OPUS_BAD_ARG;
213     } else {
214         frame_size = audiosize;
215     }
216
217     /* SILK processing */
218     if (mode != MODE_CELT_ONLY)
219     {
220         int lost_flag, decoded_samples;
221         opus_int16 *pcm_ptr = pcm;
222
223         if (st->prev_mode==MODE_CELT_ONLY)
224                 silk_InitDecoder( silk_dec );
225
226         DecControl.API_sampleRate = st->Fs;
227         DecControl.nChannelsAPI      = st->channels;
228         DecControl.nChannelsInternal = st->stream_channels;
229         DecControl.payloadSize_ms = 1000 * audiosize / st->Fs;
230         if( mode == MODE_SILK_ONLY ) {
231             if( st->bandwidth == OPUS_BANDWIDTH_NARROWBAND ) {
232                 DecControl.internalSampleRate = 8000;
233             } else if( st->bandwidth == OPUS_BANDWIDTH_MEDIUMBAND ) {
234                 DecControl.internalSampleRate = 12000;
235             } else if( st->bandwidth == OPUS_BANDWIDTH_WIDEBAND ) {
236                 DecControl.internalSampleRate = 16000;
237             } else {
238                 DecControl.internalSampleRate = 16000;
239                 SKP_assert( 0 );
240             }
241         } else {
242             /* Hybrid mode */
243             DecControl.internalSampleRate = 16000;
244         }
245
246         lost_flag = data == NULL ? 1 : 2 * decode_fec;
247         decoded_samples = 0;
248         do {
249             /* Call SILK decoder */
250             int first_frame = decoded_samples == 0;
251             silk_ret = silk_Decode( silk_dec, &DecControl, 
252                 lost_flag, first_frame, &dec, pcm_ptr, &silk_frame_size );
253             if( silk_ret ) {
254                 if (lost_flag) {
255                         /* PLC failure should not be fatal */
256                         silk_frame_size = frame_size;
257                         for (i=0;i<frame_size*st->channels;i++)
258                                 pcm_ptr[i] = 0;
259                 } else
260                     return OPUS_CORRUPTED_DATA;
261             }
262             pcm_ptr += silk_frame_size * st->channels;
263             decoded_samples += silk_frame_size;
264         } while( decoded_samples < frame_size );
265     } else {
266         for (i=0;i<frame_size*st->channels;i++)
267             pcm[i] = 0;
268     }
269
270     start_band = 0;
271     if (mode != MODE_CELT_ONLY && data != NULL)
272     {
273         /* Check if we have a redundant 0-8 kHz band */
274         redundancy = ec_dec_bit_logp(&dec, 12);
275         if (redundancy)
276         {
277             celt_to_silk = ec_dec_bit_logp(&dec, 1);
278             if (mode == MODE_HYBRID)
279                 redundancy_bytes = 2 + ec_dec_uint(&dec, 256);
280             else {
281                 redundancy_bytes = len - ((ec_tell(&dec)+7)>>3);
282                 /* Can only happen on an invalid packet */
283                 if (redundancy_bytes<0)
284                 {
285                         redundancy_bytes = 0;
286                         redundancy = 0;
287                 }
288             }
289             len -= redundancy_bytes;
290             if (len<0)
291                 return OPUS_CORRUPTED_DATA;
292             /* Shrink decoder because of raw bits */
293             dec.storage -= redundancy_bytes;
294         }
295     }
296     if (mode != MODE_CELT_ONLY)
297         start_band = 17;
298
299     {
300         int endband=21;
301
302         switch(st->bandwidth)
303         {
304         case OPUS_BANDWIDTH_NARROWBAND:
305             endband = 13;
306             break;
307         case OPUS_BANDWIDTH_MEDIUMBAND:
308         case OPUS_BANDWIDTH_WIDEBAND:
309             endband = 17;
310             break;
311         case OPUS_BANDWIDTH_SUPERWIDEBAND:
312             endband = 19;
313             break;
314         case OPUS_BANDWIDTH_FULLBAND:
315             endband = 21;
316             break;
317         }
318         celt_decoder_ctl(celt_dec, CELT_SET_END_BAND(endband));
319         celt_decoder_ctl(celt_dec, CELT_SET_CHANNELS(st->stream_channels));
320     }
321
322     if (redundancy)
323         transition = 0;
324
325     if (transition && mode != MODE_CELT_ONLY)
326         opus_decode_frame(st, NULL, 0, pcm_transition, IMIN(F10, audiosize), 0);
327
328     /* 5 ms redundant frame for CELT->SILK*/
329     if (redundancy && celt_to_silk)
330     {
331         celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(0));
332         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
333         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
334     }
335
336     /* MUST be after PLC */
337     celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(start_band));
338
339     if (transition)
340         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
341
342     if (mode != MODE_SILK_ONLY)
343     {
344         int celt_frame_size = IMIN(F20, frame_size);
345         /* Decode CELT */
346         celt_ret = celt_decode_with_ec(celt_dec, decode_fec?NULL:data, len, pcm_celt, celt_frame_size, &dec);
347         for (i=0;i<celt_frame_size*st->channels;i++)
348             pcm[i] = SAT16(pcm[i] + (int)pcm_celt[i]);
349     }
350
351     {
352         const CELTMode *celt_mode;
353         celt_decoder_ctl(celt_dec, CELT_GET_MODE(&celt_mode));
354         window = celt_mode->window;
355     }
356
357     /* 5 ms redundant frame for SILK->CELT */
358     if (redundancy && !celt_to_silk)
359     {
360         celt_decoder_ctl(celt_dec, CELT_RESET_STATE);
361         celt_decoder_ctl(celt_dec, CELT_SET_START_BAND(0));
362
363         celt_decode(celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
364         smooth_fade(pcm+st->channels*(frame_size-F2_5), redundant_audio+st->channels*F2_5,
365                         pcm+st->channels*(frame_size-F2_5), F2_5, st->channels, window, st->Fs);
366     }
367     if (redundancy && celt_to_silk)
368     {
369         for (c=0;c<st->channels;c++)
370         {
371             for (i=0;i<F2_5;i++)
372                 pcm[st->channels*i+c] = redundant_audio[st->channels*i+c];
373         }
374         smooth_fade(redundant_audio+st->channels*F2_5, pcm+st->channels*F2_5,
375                 pcm+st->channels*F2_5, F2_5, st->channels, window, st->Fs);
376     }
377     if (transition)
378     {
379         for (i=0;i<st->channels*F2_5;i++)
380                 pcm[i] = pcm_transition[i];
381         if (audiosize >= F5)
382             smooth_fade(pcm_transition+st->channels*F2_5, pcm+st->channels*F2_5,
383                     pcm+st->channels*F2_5, F2_5,
384                     st->channels, window, st->Fs);
385     }
386
387     st->rangeFinal = dec.rng;
388
389     st->prev_mode = mode;
390     st->prev_redundancy = redundancy;
391         return celt_ret<0 ? celt_ret : audiosize;
392
393 }
394
395 static int parse_size(const unsigned char *data, int len, short *size)
396 {
397         if (len<1)
398         {
399                 *size = -1;
400                 return -1;
401         } else if (data[0]<252)
402         {
403                 *size = data[0];
404                 return 1;
405         } else if (len<2)
406         {
407                 *size = -1;
408                 return -1;
409         } else {
410                 *size = 4*data[1] + data[0];
411                 return 2;
412         }
413 }
414
415 int opus_decode(OpusDecoder *st, const unsigned char *data,
416                 int len, short *pcm, int frame_size, int decode_fec)
417 {
418         int i, bytes, nb_samples;
419         int count;
420         unsigned char ch, toc;
421         /* 48 x 2.5 ms = 120 ms */
422         short size[48];
423         if (len==0 || data==NULL)
424             return opus_decode_frame(st, NULL, 0, pcm, frame_size, 0);
425         else if (len<0)
426                 return OPUS_BAD_ARG;
427         st->mode = opus_packet_get_mode(data);
428         st->bandwidth = opus_packet_get_bandwidth(data);
429         st->frame_size = opus_packet_get_samples_per_frame(data, st->Fs);
430         st->stream_channels = opus_packet_get_nb_channels(data);
431         toc = *data++;
432         len--;
433         switch (toc&0x3)
434         {
435         /* One frame */
436         case 0:
437                 count=1;
438                 size[0] = len;
439                 break;
440                 /* Two CBR frames */
441         case 1:
442                 count=2;
443                 if (len&0x1)
444                         return OPUS_CORRUPTED_DATA;
445                 size[0] = size[1] = len/2;
446                 break;
447                 /* Two VBR frames */
448         case 2:
449                 count = 2;
450                 bytes = parse_size(data, len, size);
451                 len -= bytes;
452                 if (size[0]<0 || size[0] > len)
453                         return OPUS_CORRUPTED_DATA;
454                 data += bytes;
455                 size[1] = len-size[0];
456                 break;
457                 /* Multiple CBR/VBR frames (from 0 to 120 ms) */
458         case 3:
459                 if (len<1)
460                         return OPUS_CORRUPTED_DATA;
461                 /* Number of frames encoded in bits 0 to 5 */
462                 ch = *data++;
463                 count = ch&0x3F;
464                 if (count <= 0 || st->frame_size*count*25 > 3*st->Fs)
465                     return OPUS_CORRUPTED_DATA;
466                 len--;
467                 /* Padding flag is bit 6 */
468                 if (ch&0x40)
469                 {
470                         int padding=0;
471                         int p;
472                         do {
473                                 if (len<=0)
474                                         return OPUS_CORRUPTED_DATA;
475                                 p = *data++;
476                                 len--;
477                                 padding += p==255 ? 254: p;
478                         } while (p==255);
479                         len -= padding;
480                 }
481                 if (len<0)
482                         return OPUS_CORRUPTED_DATA;
483                 /* VBR flag is bit 7 */
484                 if (ch&0x80)
485                 {
486                         /* VBR case */
487                         int last_size=len;
488                         for (i=0;i<count-1;i++)
489                         {
490                                 bytes = parse_size(data, len, size+i);
491                                 len -= bytes;
492                                 if (size[i]<0 || size[i] > len)
493                                         return OPUS_CORRUPTED_DATA;
494                                 data += bytes;
495                                 last_size -= bytes+size[i];
496                         }
497                         if (last_size<0)
498                                 return OPUS_CORRUPTED_DATA;
499                         size[count-1]=last_size;
500                 } else {
501                         /* CBR case */
502                         int sz = len/count;
503                         if (sz*count!=len)
504                                 return OPUS_CORRUPTED_DATA;
505                         for (i=0;i<count;i++)
506                                 size[i] = sz;
507                 }
508                 break;
509         }
510         /* Because it's not encoded explicitly, it's possible the size of the
511             last packet (or all the packets, for the CBR case) is larger than
512             1275.
513            Reject them here.*/
514         if (size[count-1] > MAX_PACKET)
515                 return OPUS_CORRUPTED_DATA;
516         if (count*st->frame_size > frame_size)
517                 return OPUS_BAD_ARG;
518         nb_samples=0;
519         for (i=0;i<count;i++)
520         {
521                 int ret;
522                 ret = opus_decode_frame(st, data, size[i], pcm, frame_size-nb_samples, decode_fec);
523                 if (ret<0)
524                         return ret;
525                 data += size[i];
526                 pcm += ret*st->channels;
527                 nb_samples += ret;
528         }
529         return nb_samples;
530 }
531 int opus_decoder_ctl(OpusDecoder *st, int request, ...)
532 {
533     va_list ap;
534
535     va_start(ap, request);
536
537     switch (request)
538     {
539         case OPUS_GET_MODE_REQUEST:
540         {
541             int *value = va_arg(ap, int*);
542             *value = st->prev_mode;
543         }
544         break;
545         case OPUS_SET_BANDWIDTH_REQUEST:
546         {
547             int value = va_arg(ap, int);
548             st->bandwidth = value;
549         }
550         break;
551         case OPUS_GET_BANDWIDTH_REQUEST:
552         {
553             int *value = va_arg(ap, int*);
554             *value = st->bandwidth;
555         }
556         break;
557         default:
558             fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);
559             break;
560     }
561
562     va_end(ap);
563     return OPUS_OK;
564 }
565
566 void opus_decoder_destroy(OpusDecoder *st)
567 {
568         free(st);
569 }
570
571 int opus_decoder_get_final_range(OpusDecoder *st)
572 {
573     return st->rangeFinal;
574 }
575
576 int opus_packet_get_bandwidth(const unsigned char *data)
577 {
578         int bandwidth;
579     if (data[0]&0x80)
580     {
581         bandwidth = OPUS_BANDWIDTH_MEDIUMBAND + ((data[0]>>5)&0x3);
582         if (bandwidth == OPUS_BANDWIDTH_MEDIUMBAND)
583             bandwidth = OPUS_BANDWIDTH_NARROWBAND;
584     } else if ((data[0]&0x60) == 0x60)
585     {
586         bandwidth = (data[0]&0x10) ? OPUS_BANDWIDTH_FULLBAND : OPUS_BANDWIDTH_SUPERWIDEBAND;
587     } else {
588
589         bandwidth = OPUS_BANDWIDTH_NARROWBAND + ((data[0]>>5)&0x3);
590     }
591     return bandwidth;
592 }
593
594 int opus_packet_get_samples_per_frame(const unsigned char *data, int Fs)
595 {
596         int audiosize;
597     if (data[0]&0x80)
598     {
599         audiosize = ((data[0]>>3)&0x3);
600         audiosize = (Fs<<audiosize)/400;
601     } else if ((data[0]&0x60) == 0x60)
602     {
603         audiosize = (data[0]&0x08) ? Fs/50 : Fs/100;
604     } else {
605
606         audiosize = ((data[0]>>3)&0x3);
607         if (audiosize == 3)
608             audiosize = Fs*60/1000;
609         else
610             audiosize = (Fs<<audiosize)/100;
611     }
612     return audiosize;
613 }
614
615 int opus_packet_get_nb_channels(const unsigned char *data)
616 {
617     return (data[0]&0x4) ? 2 : 1;
618 }
619
620 int opus_packet_get_nb_frames(const unsigned char packet[], int len)
621 {
622         int count;
623         if (len<1)
624                 return OPUS_BAD_ARG;
625         count = packet[0]&0x3;
626         if (count==0)
627                 return 1;
628         else if (count!=3)
629                 return 2;
630         else if (len<2)
631                 return OPUS_CORRUPTED_DATA;
632         else
633                 return packet[1]&0x3F;
634 }
635
636 int opus_decoder_get_nb_samples(const OpusDecoder *dec, const unsigned char packet[], int len)
637 {
638         int samples;
639         int count = opus_packet_get_nb_frames(packet, len);
640         samples = count*opus_packet_get_samples_per_frame(packet, dec->Fs);
641         /* Can't have more than 120 ms */
642         if (samples*25 > dec->Fs*3)
643                 return OPUS_CORRUPTED_DATA;
644         else
645                 return samples;
646 }
647