Add paranoid checking for the validity of the encoder and the decoder
authorGregory Maxwell <greg@xiph.org>
Thu, 4 Jun 2009 19:15:34 +0000 (15:15 -0400)
committerJean-Marc Valin <jean-marc.valin@usherbrooke.ca>
Fri, 5 Jun 2009 01:42:53 +0000 (21:42 -0400)
state before using it. Handle malloc failures for the encoder and
decoder setup gracefully and without leaks.

libcelt/celt.c
libcelt/celt.h

index 72045d9..86e5c96 100644 (file)
@@ -64,11 +64,15 @@ static const float transientWindow[16] = {
    0.5461342, 0.6368315, 0.7228692, 0.8013173, 0.8695045, 0.9251086, 0.9662361, 0.9914865};
 #endif
 
+#define ENCODERVALID   0x4c434554
+#define ENCODERPARTIAL 0x5445434c
+#define ENCODERFREED   0x4c004500
    
 /** Encoder state 
  @brief Encoder state
  */
 struct CELTEncoder {
+   celt_uint32_t marker;
    const CELTMode *mode;     /**< Mode used by the encoder */
    int frame_size;
    int block_size;
@@ -95,6 +99,22 @@ struct CELTEncoder {
 #endif
 };
 
+int check_encoder(const CELTEncoder *st) 
+{
+   if (st==NULL)
+   {
+      celt_warning("NULL passed as an encoder structure");  
+      return CELT_INVALID_STATE;
+   }
+   if (st->marker == ENCODERVALID)
+      return CELT_OK;
+   if (st->marker == ENCODERFREED)
+      celt_warning("Referencing an encoder that has already been freed");
+   else
+      celt_warning("This is not a valid CELT encoder structure");
+   return CELT_INVALID_STATE;
+}
+
 CELTEncoder *celt_encoder_create(const CELTMode *mode)
 {
    int N, C;
@@ -107,6 +127,9 @@ CELTEncoder *celt_encoder_create(const CELTMode *mode)
    C = mode->nbChannels;
    st = celt_alloc(sizeof(CELTEncoder));
    
+   if (st==NULL) 
+      return NULL;   
+   st->marker = ENCODERPARTIAL;
    st->mode = mode;
    st->frame_size = N;
    st->block_size = N;
@@ -132,7 +155,18 @@ CELTEncoder *celt_encoder_create(const CELTMode *mode)
    psydecay_init(&st->psy, MAX_PERIOD/2, st->mode->Fs);
 #endif
 
-   return st;
+   if ((st->in_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL) 
+#ifdef EXP_PSY
+       && (st->psy_mem!=NULL) 
+#endif   
+       && (st->preemph_memE!=NULL) && (st->preemph_memD!=NULL))
+   {
+      st->marker   = ENCODERVALID;
+      return st;
+   }
+   /* If the setup fails for some reason deallocate it. */
+   celt_encoder_destroy(st);  
+   return NULL;
 }
 
 void celt_encoder_destroy(CELTEncoder *st)
@@ -142,9 +176,23 @@ void celt_encoder_destroy(CELTEncoder *st)
       celt_warning("NULL passed to celt_encoder_destroy");
       return;
    }
-   if (check_mode(st->mode) != CELT_OK)
+
+   if (st->marker == ENCODERFREED)
+   {
       return;
+      celt_warning("Freeing an encoder which has already been freed"); 
+   }
 
+   if (st->marker != ENCODERVALID && st->marker != ENCODERPARTIAL)
+   {
+      celt_warning("This is not a valid CELT encoder structure");
+      return;
+   }
+   /*Check_mode is non-fatal here because we can still free
+    the encoder memory even if the mode is bad, although calling
+    the free functions in this order is a violation of the API.*/
+   check_mode(st->mode);
+   
    celt_free(st->in_mem);
    celt_free(st->out_mem);
    
@@ -157,6 +205,7 @@ void celt_encoder_destroy(CELTEncoder *st)
    celt_free (st->psy_mem);
    psydecay_clear(&st->psy);
 #endif
+   st->marker = ENCODERFREED;
    
    celt_free(st);
 }
@@ -472,6 +521,9 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    int mdct_weight_pos=0;
    SAVE_STACK;
 
+   if (check_encoder(st) != CELT_OK)
+      return CELT_INVALID_STATE;
+
    if (check_mode(st->mode) != CELT_OK)
       return CELT_INVALID_MODE;
 
@@ -865,11 +917,18 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
 #ifndef DISABLE_FLOAT_API
 int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
 {
-   int j, ret;
-   const int C = CHANNELS(st->mode);
-   const int N = st->block_size;
+   int j, ret, C, N;
    VARDECL(celt_int16_t, in);
+
+   if (check_encoder(st) != CELT_OK)
+      return CELT_INVALID_STATE;
+
+   if (check_mode(st->mode) != CELT_OK)
+      return CELT_INVALID_MODE;
+
    SAVE_STACK;
+   C = CHANNELS(st->mode);
+   N = st->block_size;
    ALLOC(in, C*N, celt_int16_t);
 
    for (j=0;j<C*N;j++)
@@ -890,11 +949,18 @@ int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * opti
 #else
 int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_t * optional_synthesis, unsigned char *compressed, int nbCompressedBytes)
 {
-   int j, ret;
+   int j, ret, C, N;
    VARDECL(celt_sig_t, in);
-   const int C = CHANNELS(st->mode);
-   const int N = st->block_size;
+
+   if (check_encoder(st) != CELT_OK)
+      return CELT_INVALID_STATE;
+
+   if (check_mode(st->mode) != CELT_OK)
+      return CELT_INVALID_MODE;
+
    SAVE_STACK;
+   C=CHANNELS(st->mode);
+   N=st->block_size;
    ALLOC(in, C*N, celt_sig_t);
    for (j=0;j<C*N;j++) {
      in[j] = SCALEOUT(pcm[j]);
@@ -915,7 +981,13 @@ int celt_encode(CELTEncoder * restrict st, const celt_int16_t * pcm, celt_int16_
 int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
 {
    va_list ap;
+   
+   if (check_encoder(st) != CELT_OK)
+      return CELT_INVALID_STATE;
+
    va_start(ap, request);
+   if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
+     goto bad_mode;
    switch (request)
    {
       case CELT_GET_MODE_REQUEST:
@@ -985,6 +1057,9 @@ int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
    }
    va_end(ap);
    return CELT_OK;
+bad_mode:
+  va_end(ap);
+  return CELT_INVALID_MODE;
 bad_arg:
    va_end(ap);
    return CELT_BAD_ARG;
@@ -1004,10 +1079,15 @@ bad_request:
 #define DECODE_BUFFER_SIZE MAX_PERIOD
 #endif
 
+#define DECODERVALID   0x4c434454
+#define DECODERPARTIAL 0x5444434c
+#define DECODERFREED   0x4c004400
+
 /** Decoder state 
  @brief Decoder state
  */
 struct CELTDecoder {
+   celt_uint32_t marker;
    const CELTMode *mode;
    int frame_size;
    int block_size;
@@ -1026,6 +1106,22 @@ struct CELTDecoder {
    int last_pitch_index;
 };
 
+int check_decoder(const CELTDecoder *st) 
+{
+   if (st==NULL)
+   {
+      celt_warning("NULL passed a decoder structure");  
+      return CELT_INVALID_STATE;
+   }
+   if (st->marker == DECODERVALID)
+      return CELT_OK;
+   if (st->marker == DECODERFREED)
+      celt_warning("Referencing a decoder that has already been freed");
+   else
+      celt_warning("This is not a valid CELT decoder structure");
+   return CELT_INVALID_STATE;
+}
+
 CELTDecoder *celt_decoder_create(const CELTMode *mode)
 {
    int N, C;
@@ -1037,7 +1133,11 @@ CELTDecoder *celt_decoder_create(const CELTMode *mode)
    N = mode->mdctSize;
    C = CHANNELS(mode);
    st = celt_alloc(sizeof(CELTDecoder));
+
+   if (st==NULL)
+      return NULL;
    
+   st->marker = DECODERPARTIAL;
    st->mode = mode;
    st->frame_size = N;
    st->block_size = N;
@@ -1047,30 +1147,53 @@ CELTDecoder *celt_decoder_create(const CELTMode *mode)
    st->out_mem = st->decode_mem+DECODE_BUFFER_SIZE-MAX_PERIOD;
    
    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
-
+   
    st->preemph_memD = (celt_sig_t*)celt_alloc(C*sizeof(celt_sig_t));
 
    st->last_pitch_index = 0;
-   return st;
+
+   if ((st->decode_mem!=NULL) && (st->out_mem!=NULL) && (st->oldBandE!=NULL) &&
+       (st->preemph_memD!=NULL))
+   {
+      st->marker = DECODERVALID;
+      return st;
+   }
+   /* If the setup fails for some reason deallocate it. */
+   celt_decoder_destroy(st);
+   return NULL;
 }
 
 void celt_decoder_destroy(CELTDecoder *st)
 {
    if (st == NULL)
    {
-      celt_warning("NULL passed to celt_encoder_destroy");
+      celt_warning("NULL passed to celt_decoder_destroy");
       return;
    }
-   if (check_mode(st->mode) != CELT_OK)
-      return;
 
-
-   celt_free(st->decode_mem);
+   if (st->marker == DECODERFREED) 
+   {
+      celt_warning("Freeing a decoder which has already been freed"); 
+      return;
+   }
    
-   celt_free(st->oldBandE);
+   if (st->marker != DECODERVALID && st->marker != DECODERPARTIAL)
+   {
+      celt_warning("This is not a valid CELT decoder structure");
+      return;
+   }
+   
+   /*Check_mode is non-fatal here because we can still free
+     the encoder memory even if the mode is bad, although calling
+     the free functions in this order is a violation of the API.*/
+   check_mode(st->mode);
    
+   celt_free(st->decode_mem);
+   celt_free(st->oldBandE);
    celt_free(st->preemph_memD);
-
+   
+   st->marker = DECODERFREED;
+   
    celt_free(st);
 }
 
@@ -1161,6 +1284,9 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int
    int mdct_weight_pos=0;
    SAVE_STACK;
 
+   if (check_decoder(st) != CELT_OK)
+      return CELT_INVALID_STATE;
+
    if (check_mode(st->mode) != CELT_OK)
       return CELT_INVALID_MODE;
 
@@ -1173,11 +1299,6 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int
    ALLOC(bandE, st->mode->nbEBands*C, celt_ener_t);
    ALLOC(gains, st->mode->nbPBands, celt_pgain_t);
    
-   if (check_mode(st->mode) != CELT_OK)
-   {
-      RESTORE_STACK;
-      return CELT_INVALID_MODE;
-   }
    if (data == NULL)
    {
       celt_decode_lost(st, pcm);
@@ -1318,11 +1439,18 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int
 #ifndef DISABLE_FLOAT_API
 int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm)
 {
-   int j, ret;
-   const int C = CHANNELS(st->mode);
-   const int N = st->block_size;
+   int j, ret, C, N;
    VARDECL(celt_int16_t, out);
+
+   if (check_decoder(st) != CELT_OK)
+      return CELT_INVALID_STATE;
+
+   if (check_mode(st->mode) != CELT_OK)
+      return CELT_INVALID_MODE;
+
    SAVE_STACK;
+   C = CHANNELS(st->mode);
+   N = st->block_size;
    ALLOC(out, C*N, celt_int16_t);
 
    ret=celt_decode(st, data, len, out);
@@ -1336,11 +1464,18 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int
 #else
 int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
 {
-   int j, ret;
+   int j, ret, C, N;
    VARDECL(celt_sig_t, out);
-   const int C = CHANNELS(st->mode);
-   const int N = st->block_size;
+
+   if (check_decoder(st) != CELT_OK)
+      return CELT_INVALID_STATE;
+
+   if (check_mode(st->mode) != CELT_OK)
+      return CELT_INVALID_MODE;
+
    SAVE_STACK;
+   C = CHANNELS(st->mode);
+   N = st->block_size;
    ALLOC(out, C*N, celt_sig_t);
 
    ret=celt_decode_float(st, data, len, out);
@@ -1356,7 +1491,13 @@ int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, c
 int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
 {
    va_list ap;
+
+   if (check_decoder(st) != CELT_OK)
+      return CELT_INVALID_STATE;
+
    va_start(ap, request);
+   if ((request!=CELT_GET_MODE_REQUEST) && (check_mode(st->mode) != CELT_OK))
+     goto bad_mode;
    switch (request)
    {
       case CELT_GET_MODE_REQUEST:
@@ -1385,6 +1526,9 @@ int celt_decoder_ctl(CELTDecoder * restrict st, int request, ...)
    }
    va_end(ap);
    return CELT_OK;
+bad_mode:
+  va_end(ap);
+  return CELT_INVALID_MODE;
 bad_arg:
    va_end(ap);
    return CELT_BAD_ARG;
index 8ed67f3..9fc366a 100644 (file)
@@ -67,6 +67,8 @@ extern "C" {
 #define CELT_CORRUPTED_DATA   -4
 /** Invalid/unsupported request number */
 #define CELT_UNIMPLEMENTED    -5
+/** An encoder or decoder structure passed is invalid or already freed */
+#define CELT_INVALID_STATE    -6
 
 /* Requests */
 #define CELT_GET_MODE_REQUEST    1