Handle malloc failure in mode_create gracefully without leaking memory.
authorGregory Maxwell <greg@xiph.org>
Thu, 4 Jun 2009 21:17:35 +0000 (17:17 -0400)
committerJean-Marc Valin <jean-marc.valin@usherbrooke.ca>
Fri, 5 Jun 2009 01:51:12 +0000 (21:51 -0400)
libcelt/celt.c
libcelt/mdct.c
libcelt/modes.c
libcelt/psy.c
libcelt/quant_bands.c
libcelt/rate.c

index 86e5c96..ec4d4bd 100644 (file)
@@ -179,8 +179,8 @@ void celt_encoder_destroy(CELTEncoder *st)
 
    if (st->marker == ENCODERFREED)
    {
-      return;
       celt_warning("Freeing an encoder which has already been freed"); 
+      return;
    }
 
    if (st->marker != ENCODERVALID && st->marker != ENCODERPARTIAL)
index 8b3a1fd..80448a9 100644 (file)
@@ -64,7 +64,11 @@ void mdct_init(mdct_lookup *l,int N)
    l->n = N;
    N2 = N>>1;
    l->kfft = cpx32_fft_alloc(N>>2);
+   if (l->kfft==NULL)
+     return;
    l->trig = (kiss_twiddle_scalar*)celt_alloc(N2*sizeof(kiss_twiddle_scalar));
+   if (l->trig==NULL)
+     return;
    /* We have enough points that sine isn't necessary */
 #if defined(FIXED_POINT)
 #if defined(DOUBLE_PRECISION) & !defined(MIXED_PRECISION)
index 85f0470..b7c39d7 100644 (file)
@@ -44,8 +44,9 @@
 #include "static_modes.c"
 #endif
 
-#define MODEVALID 0xa110ca7e
-#define MODEFREED 0xb10cf8ee
+#define MODEVALID   0xa110ca7e
+#define MODEPARTIAL 0x7eca10a1
+#define MODEFREED   0xb10cf8ee
 
 #ifndef M_PI
 #define M_PI 3.141592653
@@ -54,6 +55,8 @@
 
 int celt_mode_info(const CELTMode *mode, int request, celt_int32_t *value)
 {
+   if (check_mode(mode) != CELT_OK)
+      return CELT_INVALID_MODE;
    switch (request)
    {
       case CELT_GET_FRAME_SIZE:
@@ -148,6 +151,9 @@ static celt_int16_t *compute_ebands(celt_int32_t Fs, int frame_size, int *nbEBan
    *nbEBands = low+high;
    eBands = celt_alloc(sizeof(celt_int16_t)*(*nbEBands+2));
    
+   if (eBands==NULL)
+      return NULL;
+   
    /* Linear spacing (min_width) */
    for (i=0;i<low;i++)
       eBands[i] = MIN_BINS*i;
@@ -176,6 +182,9 @@ static void compute_pbands(CELTMode *mode, int res)
    int i;
    celt_int16_t *pBands;
    pBands=celt_alloc(sizeof(celt_int16_t)*(PBANDS+2));
+   mode->pBands = pBands;
+   if (pBands==NULL)
+     return;
    mode->nbPBands = PBANDS;
    for (i=0;i<PBANDS+1;i++)
    {
@@ -203,7 +212,6 @@ static void compute_pbands(CELTMode *mode, int res)
    /*for (i=0;i<mode->nbPBands+2;i++)
       printf("%d ", pBands[i]);
    printf ("\n");*/
-   mode->pBands = pBands;
    mode->pitchEnd = pBands[PBANDS];
 }
 
@@ -220,6 +228,8 @@ static void compute_allocation_table(CELTMode *mode, int res)
 
    mode->nbAllocVectors = BITALLOC_SIZE;
    allocVectors = celt_alloc(sizeof(celt_int16_t)*(BITALLOC_SIZE*mode->nbEBands));
+   if (allocVectors==NULL)
+      return;
    /* Compute per-codec-band allocation from per-critical-band matrix */
    for (i=0;i<BITALLOC_SIZE;i++)
    {
@@ -269,6 +279,13 @@ CELTMode *celt_mode_create(celt_int32_t Fs, int channels, int frame_size, int *e
    const CELTMode *m = NULL;
    CELTMode *mode=NULL;
    ALLOC_STACK;
+#if !defined(VAR_ARRAYS) && !defined(USE_ALLOCA)
+   if (global_stack==NULL)
+   {
+      celt_free(global_stack);
+      goto failure;
+   }
+#endif 
    for (i=0;i<TOTAL_MODES;i++)
    {
       if (Fs == static_mode_list[i]->Fs &&
@@ -287,12 +304,22 @@ CELTMode *celt_mode_create(celt_int32_t Fs, int channels, int frame_size, int *e
       return NULL;
    }
    mode = (CELTMode*)celt_alloc(sizeof(CELTMode));
+   if (mode==NULL)
+      goto failure;
    CELT_COPY(mode, m, 1);
+   mode->marker_start = MODEPARTIAL;
 #else
    int res;
    CELTMode *mode;
    celt_word16_t *window;
    ALLOC_STACK;
+#if !defined(VAR_ARRAYS) && !defined(USE_ALLOCA)
+   if (global_stack==NULL)
+   {
+      celt_free(global_stack);
+      goto failure;
+   }
+#endif 
 
    /* The good thing here is that permutation of the arguments will automatically be invalid */
    
@@ -320,11 +347,18 @@ CELTMode *celt_mode_create(celt_int32_t Fs, int channels, int frame_size, int *e
    res = (Fs+frame_size)/(2*frame_size);
    
    mode = celt_alloc(sizeof(CELTMode));
+   if (mode==NULL)
+      goto failure;
+   mode->marker_start = MODEPARTIAL;
    mode->Fs = Fs;
    mode->mdctSize = frame_size;
    mode->nbChannels = channels;
    mode->eBands = compute_ebands(Fs, frame_size, &mode->nbEBands);
+   if (mode->eBands==NULL)
+      goto failure;
    compute_pbands(mode, res);
+   if (mode->pBands==NULL)
+      goto failure;
    mode->ePredCoef = QCONST16(.8f,15);
 
    if (frame_size > 384 && (frame_size%8)==0)
@@ -356,9 +390,13 @@ CELTMode *celt_mode_create(celt_int32_t Fs, int channels, int frame_size, int *e
       mode->overlap = (frame_size>>3)<<2;
 
    compute_allocation_table(mode, res);
+   if (mode->allocVectors==NULL)
+      goto failure;
    /*printf ("%d bands\n", mode->nbEBands);*/
    
    window = (celt_word16_t*)celt_alloc(mode->overlap*sizeof(celt_word16_t));
+   if (window==NULL)
+      goto failure;
 
 #ifndef FIXED_POINT
    for (i=0;i<mode->overlap;i++)
@@ -370,13 +408,15 @@ CELTMode *celt_mode_create(celt_int32_t Fs, int channels, int frame_size, int *e
    mode->window = window;
 
    mode->bits = (const celt_int16_t **)compute_alloc_cache(mode, 1);
+   if (mode->bits==NULL)
+      goto failure;
 
 #ifndef SHORTCUTS
    psydecay_init(&mode->psy, MAX_PERIOD/2, mode->Fs);
+   if (mode->psy.decayR==NULL)
+      goto failure;
 #endif
    
-   mode->marker_start = MODEVALID;
-   mode->marker_end = MODEVALID;
 #endif /* !STATIC_MODES */
    mdct_init(&mode->mdct, 2*mode->mdctSize);
    mode->fft = pitch_state_alloc(MAX_PERIOD);
@@ -384,38 +424,65 @@ CELTMode *celt_mode_create(celt_int32_t Fs, int channels, int frame_size, int *e
    mode->shortMdctSize = mode->mdctSize/mode->nbShortMdcts;
    mdct_init(&mode->shortMdct, 2*mode->shortMdctSize);
    mode->shortWindow = mode->window;
-
    mode->prob = quant_prob_alloc(mode);
+   if ((mode->mdct.trig==NULL) || (mode->mdct.kfft==NULL) || (mode->fft==NULL) ||
+       (mode->shortMdct.trig==NULL) || (mode->shortMdct.kfft==NULL) || (mode->prob==NULL))
+     goto failure;
 
+   mode->marker_start = MODEVALID;
+   mode->marker_end   = MODEVALID;
    if (error)
       *error = CELT_OK;
    return mode;
+failure: 
+   if (error)
+      *error = CELT_INVALID_MODE;
+   if (mode!=NULL)
+      celt_mode_destroy(mode);
+   return NULL;
 }
 
 void celt_mode_destroy(CELTMode *mode)
 {
+   if (mode == NULL)
+   {
+      celt_warning("NULL passed to celt_mode_destroy");
+      return;
+   }
+
+   if (mode->marker_start == MODEFREED || mode->marker_end == MODEFREED)
+   {
+      celt_warning("Freeing a mode which has already been freed"); 
+      return;
+   }
+
+   if (mode->marker_start != MODEVALID && mode->marker_start != MODEPARTIAL)
+   {
+      celt_warning("This is not a valid CELT mode structure");
+      return;  
+   }
+   mode->marker_start = MODEFREED;
 #ifndef STATIC_MODES
    int i;
    const celt_int16_t *prevPtr = NULL;
-   for (i=0;i<mode->nbEBands;i++)
+   if (mode->bits!=NULL)
    {
-      if (mode->bits[i] != prevPtr)
+      for (i=0;i<mode->nbEBands;i++)
       {
-         prevPtr = mode->bits[i];
-         celt_free((int*)mode->bits[i]);
+         if (mode->bits[i] != prevPtr)
+         {
+            prevPtr = mode->bits[i];
+            celt_free((int*)mode->bits[i]);
+          }
       }
-   }
+   }   
    celt_free((int**)mode->bits);
-   if (check_mode(mode) != CELT_OK)
-      return;
    celt_free((int*)mode->eBands);
    celt_free((int*)mode->pBands);
    celt_free((int*)mode->allocVectors);
-
+   
    celt_free((celt_word16_t*)mode->window);
 
-   mode->marker_start = MODEFREED;
-   mode->marker_end = MODEFREED;
 #ifndef SHORTCUTS
    psydecay_clear(&mode->psy);
 #endif
@@ -424,11 +491,14 @@ void celt_mode_destroy(CELTMode *mode)
    mdct_clear(&mode->shortMdct);
    pitch_state_free(mode->fft);
    quant_prob_free(mode->prob);
+   mode->marker_end = MODEFREED;
    celt_free((CELTMode *)mode);
 }
 
 int check_mode(const CELTMode *mode)
 {
+   if (mode==NULL)
+      return CELT_INVALID_MODE;
    if (mode->marker_start == MODEVALID && mode->marker_end == MODEVALID)
       return CELT_OK;
    if (mode->marker_start == MODEFREED || mode->marker_end == MODEFREED)
index e2c8f44..7661e0e 100644 (file)
@@ -53,7 +53,9 @@ void psydecay_init(struct PsyDecay *decay, int len, celt_int32_t Fs)
 {
    int i;
    celt_word16_t *decayR = (celt_word16_t*)celt_alloc(sizeof(celt_word16_t)*len);
-   /*decay->decayL = celt_alloc(sizeof(celt_word16_t)*len);*/
+   decay->decayR = decayR;
+   if (decayR==NULL)
+     return;
    for (i=0;i<len;i++)
    {
       float f;
@@ -70,7 +72,6 @@ void psydecay_init(struct PsyDecay *decay, int len, celt_int32_t Fs)
       /*decay->decayL[i] = Q15ONE*pow(0.0031623f, deriv);*/
       /*printf ("%f %f\n", decayL[i], decayR[i]);*/
    }
-   decay->decayR = decayR;
 }
 
 void psydecay_clear(struct PsyDecay *decay)
index 03db96f..c8fd5a1 100644 (file)
@@ -104,6 +104,8 @@ int *quant_prob_alloc(const CELTMode *m)
    int i;
    int *prob;
    prob = celt_alloc(4*m->nbEBands*sizeof(int));
+   if (prob==NULL)
+     return NULL;
    for (i=0;i<m->nbEBands;i++)
    {
       prob[2*i] = 6000-i*200;
index eac7c71..a85984e 100644 (file)
 celt_int16_t **compute_alloc_cache(CELTMode *m, int C)
 {
    int i, prevN;
+   int error = 0;
    celt_int16_t **bits;
    const celt_int16_t *eBands = m->eBands;
 
    bits = celt_alloc(m->nbEBands*sizeof(celt_int16_t*));
-   
+   if (bits==NULL)
+     return NULL;
+        
    prevN = -1;
    for (i=0;i<m->nbEBands;i++)
    {
@@ -62,10 +65,31 @@ celt_int16_t **compute_alloc_cache(CELTMode *m, int C)
          bits[i] = bits[i-1];
       } else {
          bits[i] = celt_alloc(MAX_PULSES*sizeof(celt_int16_t));
-         get_required_bits(bits[i], N, MAX_PULSES, BITRES);
+         if (bits[i]!=NULL) {
+           get_required_bits(bits[i], N, MAX_PULSES, BITRES);
+         } else {
+            error=1;
+         }
          prevN = N;
       }
    }
+   if (error)
+   {
+      const celt_int16_t *prevPtr = NULL;
+      if (bits!=NULL)
+      {
+         for (i=0;i<m->nbEBands;i++)
+         {
+            if (bits[i] != prevPtr)
+            {
+               prevPtr = bits[i];
+               celt_free((int*)bits[i]);
+            }
+         }
+      free(bits);
+      bits=NULL;
+      }   
+   }
    return bits;
 }