Minor code simplifications
[opus.git] / src / opus_decoder.c
1 /* Copyright (c) 2010 Xiph.Org Foundation, Skype Limited
2    Written by Jean-Marc Valin and Koen Vos */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #ifdef HAVE_CONFIG_H
29 #include "config.h"
30 #endif
31
32 #include <stdlib.h>
33 #include <stdio.h>
34 #include <stdarg.h>
35 #include "celt.h"
36 #include "opus_decoder.h"
37 #include "entdec.h"
38 #include "modes.h"
39 #include "SKP_Silk_SDK_API.h"
40
41
42 OpusDecoder *opus_decoder_create(int Fs, int channels)
43 {
44     char *raw_state;
45         int ret, silkDecSizeBytes, celtDecSizeBytes;
46         OpusDecoder *st;
47
48         /* Initialize SILK encoder */
49     ret = SKP_Silk_SDK_Get_Decoder_Size( &silkDecSizeBytes );
50     if( ret ) {
51         /* Handle error */
52     }
53     celtDecSizeBytes = celt_decoder_get_size(channels);
54     raw_state = calloc(sizeof(OpusDecoder)+silkDecSizeBytes+celtDecSizeBytes, 1);
55     st = (OpusDecoder*)raw_state;
56     st->silk_dec = (void*)(raw_state+sizeof(OpusDecoder));
57     st->celt_dec = (CELTDecoder*)(raw_state+sizeof(OpusDecoder)+silkDecSizeBytes);
58     st->stream_channels = st->channels = channels;
59
60     st->Fs = Fs;
61
62     /* Reset decoder */
63     ret = SKP_Silk_SDK_InitDecoder( st->silk_dec );
64     if( ret ) {
65         /* Handle error */
66     }
67
68         /* Initialize CELT decoder */
69         st->celt_dec = celt_decoder_init(st->celt_dec, Fs, channels, NULL);
70     celt_decoder_ctl(st->celt_dec, CELT_SET_SIGNALLING(0));
71
72         st->prev_mode = 0;
73         return st;
74 }
75
76 static void smooth_fade(const short *in1, const short *in2, short *out, int overlap, int channels)
77 {
78         int i, c;
79         for (c=0;c<channels;c++)
80         {
81                 /* FIXME: Make this 16-bit safe, remove division */
82                 for (i=0;i<overlap;i++)
83                         out[i*channels+c] = (i*in2[i*channels+c] + (overlap-i)*in1[i*channels+c])/overlap;
84         }
85 }
86
87 static int opus_packet_get_mode(const unsigned char *data)
88 {
89         int mode;
90     if (data[0]&0x80)
91     {
92         mode = MODE_CELT_ONLY;
93     } else if ((data[0]&0x60) == 0x60)
94     {
95         mode = MODE_HYBRID;
96     } else {
97
98         mode = MODE_SILK_ONLY;
99     }
100     return mode;
101 }
102
103 static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
104                 int len, short *pcm, int frame_size, int decode_fec)
105 {
106         int i, silk_ret=0, celt_ret=0;
107         ec_dec dec;
108     SKP_SILK_SDK_DecControlStruct DecControl;
109     SKP_int32 silk_frame_size;
110     short pcm_celt[960*2];
111     short pcm_transition[960*2];
112     int audiosize;
113     int mode;
114     int transition=0;
115     int start_band;
116     int redundancy=0;
117     int redundancy_bytes = 0;
118     int celt_to_silk=0;
119     short redundant_audio[240*2];
120     int c;
121     int F2_5, F5, F10;
122
123     F10 = st->Fs/100;
124     F5 = F10>>1;
125     F2_5 = F5>>1;
126     /* Payloads of 1 (2 including ToC) or 0 trigger the PLC/DTX */
127     if (len<=1)
128         data = NULL;
129
130         audiosize = st->frame_size;
131     if (data != NULL)
132     {
133         mode = st->mode;
134         ec_dec_init(&dec,(unsigned char*)data,len);
135     } else {
136         mode = st->prev_mode;
137     }
138
139     if (st->stream_channels > st->channels)
140         return OPUS_CORRUPTED_DATA;
141
142     /* FIXME: Remove this when we add SILK stereo support */
143     if (st->stream_channels == 2 && mode != MODE_CELT_ONLY)
144         return OPUS_UNIMPLEMENTED;
145
146     if (data!=NULL && !st->prev_redundancy && mode != st->prev_mode && st->prev_mode > 0
147                 && !(mode == MODE_SILK_ONLY && st->prev_mode == MODE_HYBRID)
148                 && !(mode == MODE_HYBRID && st->prev_mode == MODE_SILK_ONLY))
149     {
150         transition = 1;
151         if (mode == MODE_CELT_ONLY)
152             opus_decode_frame(st, NULL, 0, pcm_transition, IMAX(F10, audiosize), 0);
153     }
154     if (audiosize > frame_size)
155     {
156         fprintf(stderr, "PCM buffer too small: %d vs %d (mode = %d)\n", audiosize, frame_size, mode);
157         return OPUS_BAD_ARG;
158     } else {
159         frame_size = audiosize;
160     }
161
162     /* SILK processing */
163     if (mode != MODE_CELT_ONLY)
164     {
165         int lost_flag, decoded_samples;
166         SKP_int16 *pcm_ptr = pcm;
167
168         if (st->prev_mode==MODE_CELT_ONLY)
169                 SKP_Silk_SDK_InitDecoder( st->silk_dec );
170
171         DecControl.API_sampleRate = st->Fs;
172         DecControl.payloadSize_ms = 1000 * audiosize / st->Fs;
173         if( mode == MODE_SILK_ONLY ) {
174             if( st->bandwidth == BANDWIDTH_NARROWBAND ) {
175                 DecControl.internalSampleRate = 8000;
176             } else if( st->bandwidth == BANDWIDTH_MEDIUMBAND ) {
177                 DecControl.internalSampleRate = 12000;
178             } else if( st->bandwidth == BANDWIDTH_WIDEBAND ) {
179                 DecControl.internalSampleRate = 16000;
180             } else {
181                 DecControl.internalSampleRate = 16000;
182                 SKP_assert( 0 );
183             }
184         } else {
185             /* Hybrid mode */
186             DecControl.internalSampleRate = 16000;
187         }
188
189         lost_flag = data == NULL ? 1 : 2 * decode_fec;
190         decoded_samples = 0;
191         do {
192             /* Call SILK decoder */
193             int first_frame = decoded_samples == 0;
194             silk_ret = SKP_Silk_SDK_Decode( st->silk_dec, &DecControl, 
195                 lost_flag, first_frame, &dec, len, pcm_ptr, &silk_frame_size );
196             if( silk_ret ) {
197                 fprintf (stderr, "SILK decode error\n");
198                 /* Handle error */
199             }
200             pcm_ptr += silk_frame_size;
201             decoded_samples += silk_frame_size;
202         } while( decoded_samples < frame_size );
203     } else {
204         for (i=0;i<frame_size*st->channels;i++)
205             pcm[i] = 0;
206     }
207
208     start_band = 0;
209     if (mode != MODE_CELT_ONLY && data != NULL)
210     {
211         /* Check if we have a redundant 0-8 kHz band */
212         redundancy = ec_dec_bit_logp(&dec, 12);
213         if (redundancy)
214         {
215             celt_to_silk = ec_dec_bit_logp(&dec, 1);
216             if (mode == MODE_HYBRID)
217                 redundancy_bytes = 2 + ec_dec_uint(&dec, 256);
218             else
219                 redundancy_bytes = len - ((ec_tell(&dec)+7)>>3);
220             len -= redundancy_bytes;
221             if (len<0)
222                 return CELT_CORRUPTED_DATA;
223             /* Shrink decoder because of raw bits */
224             dec.storage -= redundancy_bytes;
225         }
226     }
227     if (mode != MODE_CELT_ONLY)
228         start_band = 17;
229
230     if (mode != MODE_SILK_ONLY)
231     {
232         int endband=21;
233
234         switch(st->bandwidth)
235         {
236         case BANDWIDTH_NARROWBAND:
237             endband = 13;
238             break;
239         case BANDWIDTH_WIDEBAND:
240             endband = 17;
241             break;
242         case BANDWIDTH_SUPERWIDEBAND:
243             endband = 19;
244             break;
245         case BANDWIDTH_FULLBAND:
246             endband = 21;
247             break;
248         }
249         celt_decoder_ctl(st->celt_dec, CELT_SET_END_BAND(endband));
250         celt_decoder_ctl(st->celt_dec, CELT_SET_CHANNELS(st->stream_channels));
251     }
252
253     if (redundancy)
254         transition = 0;
255
256     if (transition && mode != MODE_CELT_ONLY)
257         opus_decode_frame(st, NULL, 0, pcm_transition, IMAX(F10, audiosize), 0);
258
259     /* 5 ms redundant frame for CELT->SILK*/
260     if (redundancy && celt_to_silk)
261     {
262         celt_decode(st->celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
263         celt_decoder_ctl(st->celt_dec, CELT_RESET_STATE);
264     }
265
266     /* MUST be after PLC */
267     celt_decoder_ctl(st->celt_dec, CELT_SET_START_BAND(start_band));
268
269     if (transition)
270         celt_decoder_ctl(st->celt_dec, CELT_RESET_STATE);
271
272     if (mode != MODE_SILK_ONLY)
273     {
274         /* Decode CELT */
275         celt_ret = celt_decode_with_ec(st->celt_dec, decode_fec?NULL:data, len, pcm_celt, frame_size, &dec);
276         for (i=0;i<frame_size*st->channels;i++)
277             pcm[i] = ADD_SAT16(pcm[i], pcm_celt[i]);
278     }
279
280     /* 5 ms redundant frame for SILK->CELT */
281     if (redundancy && !celt_to_silk)
282     {
283         celt_decoder_ctl(st->celt_dec, CELT_RESET_STATE);
284         celt_decoder_ctl(st->celt_dec, CELT_SET_START_BAND(0));
285
286         celt_decode(st->celt_dec, data+len, redundancy_bytes, redundant_audio, F5);
287         smooth_fade(pcm+st->channels*(frame_size-F2_5), redundant_audio+st->channels*F2_5,
288                         pcm+st->channels*(frame_size-F2_5), F2_5, st->channels);
289     }
290     if (redundancy && celt_to_silk)
291     {
292         for (c=0;c<st->channels;c++)
293         {
294             for (i=0;i<F2_5;i++)
295                 pcm[st->channels*i+c] = redundant_audio[st->channels*i];
296         }
297         smooth_fade(redundant_audio+st->channels*F2_5, pcm+st->channels*F2_5, pcm+st->channels*F2_5, F2_5, st->channels);
298     }
299     if (transition)
300     {
301         int plc_length, overlap;
302         plc_length = IMIN(audiosize, 10+F2_5);
303         for (i=0;i<plc_length;i++)
304                 pcm[i] = pcm_transition[i];
305
306         overlap = IMIN(F2_5, IMAX(0, audiosize-plc_length));
307         smooth_fade(pcm_transition+plc_length, pcm+plc_length, pcm+plc_length, overlap, st->channels);
308     }
309 #if OPUS_TEST_RANGE_CODER_STATE
310     st->rangeFinal = dec.rng;
311 #endif
312
313     st->prev_mode = mode;
314     st->prev_redundancy = redundancy;
315         return celt_ret<0 ? celt_ret : audiosize;
316
317 }
318
319 static int parse_size(const unsigned char *data, int len, short *size)
320 {
321         if (len<1)
322         {
323                 *size = -1;
324                 return -1;
325         } else if (data[0]<252)
326         {
327                 *size = data[0];
328                 return 1;
329         } else if (len<2)
330         {
331                 *size = -1;
332                 return -1;
333         } else {
334                 *size = 4*data[1] + data[0];
335                 return 2;
336         }
337 }
338
339 int opus_decode(OpusDecoder *st, const unsigned char *data,
340                 int len, short *pcm, int frame_size, int decode_fec)
341 {
342         int i, bytes, nb_samples;
343         int count;
344         unsigned char ch, toc;
345         /* 48 x 2.5 ms = 120 ms */
346         short size[48];
347         if (len==0 || data==NULL)
348             return opus_decode_frame(st, NULL, 0, pcm, frame_size, 0);
349         else if (len<0)
350                 return CELT_BAD_ARG;
351         st->mode = opus_packet_get_mode(data);
352         st->bandwidth = opus_packet_get_bandwidth(data);
353         st->frame_size = opus_packet_get_samples_per_frame(data, st->Fs);
354         st->stream_channels = opus_packet_get_nb_channels(data);
355         toc = *data++;
356         len--;
357         switch (toc&0x3)
358         {
359         /* One frame */
360         case 0:
361                 count=1;
362                 size[0] = len;
363                 break;
364                 /* Two CBR frames */
365         case 1:
366                 count=2;
367                 if (len&0x1)
368                         return OPUS_CORRUPTED_DATA;
369                 size[0] = size[1] = len/2;
370                 break;
371                 /* Two VBR frames */
372         case 2:
373                 count = 2;
374                 bytes = parse_size(data, len, size);
375                 len -= bytes;
376                 if (size[0]<0 || size[0] > len)
377                         return OPUS_CORRUPTED_DATA;
378                 data += bytes;
379                 size[1] = len-size[0];
380                 break;
381                 /* Multiple CBR/VBR frames (from 0 to 120 ms) */
382         case 3:
383                 if (len<1)
384                         return OPUS_CORRUPTED_DATA;
385                 /* Number of frames encoded in bits 0 to 5 */
386                 ch = *data++;
387                 count = ch&0x3F;
388                 if (st->frame_size*count*25 > 3*st->Fs)
389                     return OPUS_CORRUPTED_DATA;
390                 len--;
391                 /* Bit 7 is VBR flag (bit 6 is ignored) */
392                 if (ch&0x80)
393                 {
394                         /* VBR case */
395                         int last_size=len;
396                         for (i=0;i<count-1;i++)
397                         {
398                                 bytes = parse_size(data, len, size+i);
399                                 len -= bytes;
400                                 if (size[i]<0 || size[i] > len)
401                                         return OPUS_CORRUPTED_DATA;
402                                 data += bytes;
403                                 last_size -= bytes+size[i];
404                         }
405                         if (last_size<0)
406                                 return OPUS_CORRUPTED_DATA;
407                         if (count)
408                                 size[count-1]=last_size;
409                 } else {
410                         /* CBR case */
411                         int sz = count != 0 ? len/count : 0;
412                         if (sz*count!=len)
413                                 return OPUS_CORRUPTED_DATA;
414                         for (i=0;i<count;i++)
415                                 size[i] = sz;
416                 }
417                 break;
418         }
419         if (count*st->frame_size > frame_size)
420                 return OPUS_BAD_ARG;
421         nb_samples=0;
422         for (i=0;i<count;i++)
423         {
424                 int ret;
425                 ret = opus_decode_frame(st, data, len, pcm, frame_size-nb_samples, decode_fec);
426                 if (ret<0)
427                         return ret;
428                 data += size[i];
429                 pcm += ret;
430                 nb_samples += ret;
431         }
432         return nb_samples;
433 }
434 int opus_decoder_ctl(OpusDecoder *st, int request, ...)
435 {
436     va_list ap;
437
438     va_start(ap, request);
439
440     switch (request)
441     {
442         case OPUS_GET_MODE_REQUEST:
443         {
444             int *value = va_arg(ap, int*);
445             *value = st->prev_mode;
446         }
447         break;
448         case OPUS_SET_BANDWIDTH_REQUEST:
449         {
450             int value = va_arg(ap, int);
451             st->bandwidth = value;
452         }
453         break;
454         case OPUS_GET_BANDWIDTH_REQUEST:
455         {
456             int *value = va_arg(ap, int*);
457             *value = st->bandwidth;
458         }
459         break;
460         default:
461             fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);
462             break;
463     }
464
465     va_end(ap);
466     return OPUS_OK;
467 }
468
469 void opus_decoder_destroy(OpusDecoder *st)
470 {
471         free(st);
472 }
473
474 #if OPUS_TEST_RANGE_CODER_STATE
475 int opus_decoder_get_final_range(OpusDecoder *st)
476 {
477     return st->rangeFinal;
478 }
479 #endif
480
481
482 int opus_packet_get_bandwidth(const unsigned char *data)
483 {
484         int bandwidth;
485     if (data[0]&0x80)
486     {
487         bandwidth = BANDWIDTH_MEDIUMBAND + ((data[0]>>5)&0x3);
488         if (bandwidth == BANDWIDTH_MEDIUMBAND)
489             bandwidth = BANDWIDTH_NARROWBAND;
490     } else if ((data[0]&0x60) == 0x60)
491     {
492         bandwidth = (data[0]&0x10) ? BANDWIDTH_FULLBAND : BANDWIDTH_SUPERWIDEBAND;
493     } else {
494
495         bandwidth = BANDWIDTH_NARROWBAND + ((data[0]>>5)&0x3);
496     }
497     return bandwidth;
498 }
499
500 int opus_packet_get_samples_per_frame(const unsigned char *data, int Fs)
501 {
502         int audiosize;
503     if (data[0]&0x80)
504     {
505         audiosize = ((data[0]>>3)&0x3);
506         audiosize = (Fs<<audiosize)/400;
507     } else if ((data[0]&0x60) == 0x60)
508     {
509         audiosize = (data[0]&0x08) ? Fs/50 : Fs/100;
510     } else {
511
512         audiosize = ((data[0]>>3)&0x3);
513         if (audiosize == 3)
514             audiosize = Fs*60/1000;
515         else
516             audiosize = (Fs<<audiosize)/100;
517     }
518     return audiosize;
519 }
520
521 int opus_packet_get_nb_channels(const unsigned char *data)
522 {
523     return (data[0]&0x4) ? 2 : 1;
524 }
525
526 int opus_packet_get_nb_frames(const unsigned char packet[], int len)
527 {
528         int count;
529         if (len<1)
530                 return OPUS_BAD_ARG;
531         count = packet[0]&0x3;
532         if (count==0)
533                 return 1;
534         else if (count!=3)
535                 return 2;
536         else if (len<2)
537                 return OPUS_CORRUPTED_DATA;
538         else
539                 return packet[1]&0x3F;
540 }
541
542 int opus_decoder_get_nb_samples(const OpusDecoder *dec, const unsigned char packet[], int len)
543 {
544         int samples;
545         int count = opus_packet_get_nb_frames(packet, len);
546         samples = count*opus_packet_get_samples_per_frame(packet, dec->Fs);
547         /* Can't have more than 120 ms */
548         if (samples*25 > dec->Fs*3)
549                 return OPUS_CORRUPTED_DATA;
550         else
551                 return samples;
552 }
553