Some missing checks
[opus.git] / src / opus_multistream_decoder.c
index 47f87b1..f767ea0 100644 (file)
 #include "float_cast.h"
 #include "os_support.h"
 
-struct OpusMSDecoder {
-   ChannelLayout layout;
-   /* Decoder states go here */
-};
-
+/* DECODER */
 
+#if defined(ENABLE_HARDENING) || defined(ENABLE_ASSERTIONS)
+static void validate_ms_decoder(OpusMSDecoder *st)
+{
+   validate_layout(&st->layout);
+#ifdef OPUS_ARCHMASK
+   celt_assert(st->arch >= 0);
+   celt_assert(st->arch <= OPUS_ARCHMASK);
+#endif
+}
+#define VALIDATE_MS_DECODER(st) validate_ms_decoder(st)
+#else
+#define VALIDATE_MS_DECODER(st)
+#endif
 
 
-/* DECODER */
-
 opus_int32 opus_multistream_decoder_get_size(int nb_streams, int nb_coupled_streams)
 {
    int coupled_size;
@@ -75,7 +82,7 @@ int opus_multistream_decoder_init(
    char *ptr;
 
    if ((channels>255) || (channels<1) || (coupled_streams>streams) ||
-       (coupled_streams+streams>255) || (streams<1) || (coupled_streams<0))
+       (streams<1) || (coupled_streams<0) || (streams>255-coupled_streams))
       return OPUS_BAD_ARG;
 
    st->layout.nb_channels = channels;
@@ -119,7 +126,7 @@ OpusMSDecoder *opus_multistream_decoder_create(
    int ret;
    OpusMSDecoder *st;
    if ((channels>255) || (channels<1) || (coupled_streams>streams) ||
-       (coupled_streams+streams>255) || (streams<1) || (coupled_streams<0))
+       (streams<1) || (coupled_streams<0) || (streams>255-coupled_streams))
    {
       if (error)
          *error = OPUS_BAD_ARG;
@@ -143,16 +150,36 @@ OpusMSDecoder *opus_multistream_decoder_create(
    return st;
 }
 
-typedef void (*opus_copy_channel_out_func)(
-  void *dst,
-  int dst_stride,
-  int dst_channel,
-  const opus_val16 *src,
-  int src_stride,
-  int frame_size
-);
+static int opus_multistream_packet_validate(const unsigned char *data,
+      opus_int32 len, int nb_streams, opus_int32 Fs)
+{
+   int s;
+   int count;
+   unsigned char toc;
+   opus_int16 size[48];
+   int samples=0;
+   opus_int32 packet_offset;
+
+   for (s=0;s<nb_streams;s++)
+   {
+      int tmp_samples;
+      if (len<=0)
+         return OPUS_INVALID_PACKET;
+      count = opus_packet_parse_impl(data, len, s!=nb_streams-1, &toc, NULL,
+                                     size, NULL, &packet_offset);
+      if (count<0)
+         return count;
+      tmp_samples = opus_packet_get_nb_samples(data, packet_offset, Fs);
+      if (s!=0 && samples != tmp_samples)
+         return OPUS_INVALID_PACKET;
+      samples = tmp_samples;
+      data += packet_offset;
+      len -= packet_offset;
+   }
+   return samples;
+}
 
-static int opus_multistream_decode_native(
+int opus_multistream_decode_native(
       OpusMSDecoder *st,
       const unsigned char *data,
       opus_int32 len,
@@ -160,7 +187,8 @@ static int opus_multistream_decode_native(
       opus_copy_channel_out_func copy_channel_out,
       int frame_size,
       int decode_fec,
-      int soft_clip
+      int soft_clip,
+      void *user_data
 )
 {
    opus_int32 Fs;
@@ -172,8 +200,9 @@ static int opus_multistream_decode_native(
    VARDECL(opus_val16, buf);
    ALLOC_STACK;
 
+   VALIDATE_MS_DECODER(st);
    /* Limit frame_size to avoid excessive stack allocations. */
-   opus_multistream_decoder_ctl(st, OPUS_GET_SAMPLE_RATE(&Fs));
+   MUST_SUCCEED(opus_multistream_decoder_ctl(st, OPUS_GET_SAMPLE_RATE(&Fs)));
    frame_size = IMIN(frame_size, Fs/25*3);
    ALLOC(buf, 2*frame_size, opus_val16);
    ptr = (char*)st + align(sizeof(OpusMSDecoder));
@@ -183,13 +212,33 @@ static int opus_multistream_decode_native(
    if (len==0)
       do_plc = 1;
    if (len < 0)
+   {
+      RESTORE_STACK;
       return OPUS_BAD_ARG;
+   }
    if (!do_plc && len < 2*st->layout.nb_streams-1)
+   {
+      RESTORE_STACK;
       return OPUS_INVALID_PACKET;
+   }
+   if (!do_plc)
+   {
+      int ret = opus_multistream_packet_validate(data, len, st->layout.nb_streams, Fs);
+      if (ret < 0)
+      {
+         RESTORE_STACK;
+         return ret;
+      } else if (ret > frame_size)
+      {
+         RESTORE_STACK;
+         return OPUS_BUFFER_TOO_SMALL;
+      }
+   }
    for (s=0;s<st->layout.nb_streams;s++)
    {
       OpusDecoder *dec;
-      int packet_offset, ret;
+      opus_int32 packet_offset;
+      int ret;
 
       dec = (OpusDecoder*)ptr;
       ptr += (s < st->layout.nb_coupled_streams) ? align(coupled_size) : align(mono_size);
@@ -197,22 +246,12 @@ static int opus_multistream_decode_native(
       if (!do_plc && len<=0)
       {
          RESTORE_STACK;
-         return OPUS_INVALID_PACKET;
+         return OPUS_INTERNAL_ERROR;
       }
       packet_offset = 0;
       ret = opus_decode_native(dec, data, len, buf, frame_size, decode_fec, s!=st->layout.nb_streams-1, &packet_offset, soft_clip);
       data += packet_offset;
       len -= packet_offset;
-      if (ret > frame_size)
-      {
-         RESTORE_STACK;
-         return OPUS_BUFFER_TOO_SMALL;
-      }
-      if (s>0 && ret != frame_size)
-      {
-         RESTORE_STACK;
-         return OPUS_INVALID_PACKET;
-      }
       if (ret <= 0)
       {
          RESTORE_STACK;
@@ -227,7 +266,7 @@ static int opus_multistream_decode_native(
          while ( (chan = get_left_channel(&st->layout, s, prev)) != -1)
          {
             (*copy_channel_out)(pcm, st->layout.nb_channels, chan,
-               buf, 2, frame_size);
+               buf, 2, frame_size, user_data);
             prev = chan;
          }
          prev = -1;
@@ -235,7 +274,7 @@ static int opus_multistream_decode_native(
          while ( (chan = get_right_channel(&st->layout, s, prev)) != -1)
          {
             (*copy_channel_out)(pcm, st->layout.nb_channels, chan,
-               buf+1, 2, frame_size);
+               buf+1, 2, frame_size, user_data);
             prev = chan;
          }
       } else {
@@ -245,7 +284,7 @@ static int opus_multistream_decode_native(
          while ( (chan = get_mono_channel(&st->layout, s, prev)) != -1)
          {
             (*copy_channel_out)(pcm, st->layout.nb_channels, chan,
-               buf, 1, frame_size);
+               buf, 1, frame_size, user_data);
             prev = chan;
          }
       }
@@ -256,7 +295,7 @@ static int opus_multistream_decode_native(
       if (st->layout.mapping[c] == 255)
       {
          (*copy_channel_out)(pcm, st->layout.nb_channels, c,
-            NULL, 0, frame_size);
+            NULL, 0, frame_size, user_data);
       }
    }
    RESTORE_STACK;
@@ -270,11 +309,13 @@ static void opus_copy_channel_out_float(
   int dst_channel,
   const opus_val16 *src,
   int src_stride,
-  int frame_size
+  int frame_size,
+  void *user_data
 )
 {
    float *float_dst;
-   int i;
+   opus_int32 i;
+   (void)user_data;
    float_dst = (float*)dst;
    if (src != NULL)
    {
@@ -299,11 +340,13 @@ static void opus_copy_channel_out_short(
   int dst_channel,
   const opus_val16 *src,
   int src_stride,
-  int frame_size
+  int frame_size,
+  void *user_data
 )
 {
    opus_int16 *short_dst;
-   int i;
+   opus_int32 i;
+   (void)user_data;
    short_dst = (opus_int16*)dst;
    if (src != NULL)
    {
@@ -334,7 +377,7 @@ int opus_multistream_decode(
 )
 {
    return opus_multistream_decode_native(st, data, len,
-       pcm, opus_copy_channel_out_short, frame_size, decode_fec, 0);
+       pcm, opus_copy_channel_out_short, frame_size, decode_fec, 0, NULL);
 }
 
 #ifndef DISABLE_FLOAT_API
@@ -342,7 +385,7 @@ int opus_multistream_decode_float(OpusMSDecoder *st, const unsigned char *data,
       opus_int32 len, float *pcm, int frame_size, int decode_fec)
 {
    return opus_multistream_decode_native(st, data, len,
-       pcm, opus_copy_channel_out_float, frame_size, decode_fec, 0);
+       pcm, opus_copy_channel_out_float, frame_size, decode_fec, 0, NULL);
 }
 #endif
 
@@ -352,32 +395,30 @@ int opus_multistream_decode(OpusMSDecoder *st, const unsigned char *data,
       opus_int32 len, opus_int16 *pcm, int frame_size, int decode_fec)
 {
    return opus_multistream_decode_native(st, data, len,
-       pcm, opus_copy_channel_out_short, frame_size, decode_fec, 1);
+       pcm, opus_copy_channel_out_short, frame_size, decode_fec, 1, NULL);
 }
 
 int opus_multistream_decode_float(
       OpusMSDecoder *st,
       const unsigned char *data,
       opus_int32 len,
-      float *pcm,
+      opus_val16 *pcm,
       int frame_size,
       int decode_fec
 )
 {
    return opus_multistream_decode_native(st, data, len,
-       pcm, opus_copy_channel_out_float, frame_size, decode_fec, 0);
+       pcm, opus_copy_channel_out_float, frame_size, decode_fec, 0, NULL);
 }
 #endif
 
-int opus_multistream_decoder_ctl(OpusMSDecoder *st, int request, ...)
+int opus_multistream_decoder_ctl_va_list(OpusMSDecoder *st, int request,
+                                         va_list ap)
 {
-   va_list ap;
    int coupled_size, mono_size;
    char *ptr;
    int ret = OPUS_OK;
 
-   va_start(ap, request);
-
    coupled_size = opus_decoder_get_size(2);
    mono_size = opus_decoder_get_size(1);
    ptr = (char*)st + align(sizeof(OpusMSDecoder));
@@ -387,6 +428,7 @@ int opus_multistream_decoder_ctl(OpusMSDecoder *st, int request, ...)
        case OPUS_GET_SAMPLE_RATE_REQUEST:
        case OPUS_GET_GAIN_REQUEST:
        case OPUS_GET_LAST_PACKET_DURATION_REQUEST:
+       case OPUS_GET_PHASE_INVERSION_DISABLED_REQUEST:
        {
           OpusDecoder *dec;
           /* For int32* GET params, just query the first stream */
@@ -400,6 +442,10 @@ int opus_multistream_decoder_ctl(OpusMSDecoder *st, int request, ...)
           int s;
           opus_uint32 *value = va_arg(ap, opus_uint32*);
           opus_uint32 tmp;
+          if (!value)
+          {
+             goto bad_arg;
+          }
           *value = 0;
           for (s=0;s<st->layout.nb_streams;s++)
           {
@@ -442,6 +488,10 @@ int opus_multistream_decoder_ctl(OpusMSDecoder *st, int request, ...)
           if (stream_id<0 || stream_id >= st->layout.nb_streams)
              ret = OPUS_BAD_ARG;
           value = va_arg(ap, OpusDecoder**);
+          if (!value)
+          {
+             goto bad_arg;
+          }
           for (s=0;s<stream_id;s++)
           {
              if (s < st->layout.nb_coupled_streams)
@@ -453,6 +503,7 @@ int opus_multistream_decoder_ctl(OpusMSDecoder *st, int request, ...)
        }
        break;
        case OPUS_SET_GAIN_REQUEST:
+       case OPUS_SET_PHASE_INVERSION_DISABLED_REQUEST:
        {
           int s;
           /* This works for int32 params */
@@ -476,12 +527,21 @@ int opus_multistream_decoder_ctl(OpusMSDecoder *st, int request, ...)
           ret = OPUS_UNIMPLEMENTED;
        break;
    }
+   return ret;
+bad_arg:
+   return OPUS_BAD_ARG;
+}
 
+int opus_multistream_decoder_ctl(OpusMSDecoder *st, int request, ...)
+{
+   int ret;
+   va_list ap;
+   va_start(ap, request);
+   ret = opus_multistream_decoder_ctl_va_list(st, request, ap);
    va_end(ap);
    return ret;
 }
 
-
 void opus_multistream_decoder_destroy(OpusMSDecoder *st)
 {
     opus_free(st);