Fixes an int overflow in the VBR code
[opus.git] / libcelt / celt.c
index 6f93bcf..8c30af9 100644 (file)
 
 #define CELT_C
 
+/* Always enable postfilter for Opus */
+#if defined(OPUS_BUILD) && !defined(ENABLE_POSTFILTER)
+#define ENABLE_POSTFILTER
+#endif
+
 #include "os_support.h"
 #include "mdct.h"
 #include <math.h>
@@ -99,6 +104,7 @@ struct CELTEncoder {
    int stream_channels;
    
    int force_intra;
+   int disable_pf;
    int complexity;
    int upsample;
    int start, end;
@@ -685,8 +691,8 @@ static void tf_encode(int start, int end, int isTransient, int *tf_res, int LM,
    int logp;
    ec_uint32 budget;
    ec_uint32 tell;
-   budget = enc->buf->storage*8;
-   tell = ec_enc_tell(enc, 0);
+   budget = enc->storage*8;
+   tell = ec_tell(enc);
    logp = isTransient ? 2 : 4;
    /* Reserve space to code the tf_select decision. */
    tf_select_rsv = LM>0 && tell+logp+1 <= budget;
@@ -697,7 +703,7 @@ static void tf_encode(int start, int end, int isTransient, int *tf_res, int LM,
       if (tell+logp<=budget)
       {
          ec_enc_bit_logp(enc, tf_res[i] ^ curr, logp);
-         tell = ec_enc_tell(enc, 0);
+         tell = ec_tell(enc);
          curr = tf_res[i];
          tf_changed |= curr;
       }
@@ -726,8 +732,8 @@ static void tf_decode(int start, int end, int isTransient, int *tf_res, int LM,
    ec_uint32 budget;
    ec_uint32 tell;
 
-   budget = dec->buf->storage*8;
-   tell = ec_dec_tell(dec, 0);
+   budget = dec->storage*8;
+   tell = ec_tell(dec);
    logp = isTransient ? 2 : 4;
    tf_select_rsv = LM>0 && tell+logp+1<=budget;
    budget -= tf_select_rsv;
@@ -737,7 +743,7 @@ static void tf_decode(int start, int end, int isTransient, int *tf_res, int LM,
       if (tell+logp<=budget)
       {
          curr ^= ec_dec_bit_logp(dec, logp);
-         tell = ec_dec_tell(dec, 0);
+         tell = ec_tell(dec);
          tf_changed |= curr;
       }
       tf_res[i] = curr;
@@ -855,16 +861,17 @@ static int stereo_analysis(const CELTMode *m, const celt_norm *X,
 }
 
 #ifdef FIXED_POINT
+CELT_STATIC
 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
 {
 #else
+CELT_STATIC
 int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
 {
 #endif
    int i, c, N;
    int bits;
-   ec_byte_buffer buf;
-   ec_enc         _enc;
+   ec_enc _enc;
    VARDECL(celt_sig, in);
    VARDECL(celt_sig, freq);
    VARDECL(celt_norm, X);
@@ -916,10 +923,10 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
      return CELT_BAD_ARG;
 
    frame_size *= st->upsample;
-   for (LM=0;LM<4;LM++)
+   for (LM=0;LM<=st->mode->maxLM;LM++)
       if (st->mode->shortMdctSize<<LM==frame_size)
          break;
-   if (LM>=MAX_CONFIG_SIZES)
+   if (LM>st->mode->maxLM)
       return CELT_BAD_ARG;
    M=1<<LM;
    N = M*st->mode->shortMdctSize;
@@ -936,14 +943,15 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
       tell=1;
       nbFilledBytes=0;
    } else {
-      tell=ec_enc_tell(enc, 0);
+      tell=ec_tell(enc);
       nbFilledBytes=(tell+4)>>3;
    }
    nbAvailableBytes = nbCompressedBytes - nbFilledBytes;
 
    if (st->vbr)
    {
-      vbr_rate = ((2*st->bitrate*frame_size<<BITRES)+st->mode->Fs)/(2*st->mode->Fs);
+      celt_int32 den=st->mode->Fs>>BITRES;
+      vbr_rate=(st->bitrate*frame_size+(den>>1))/den;
       effectiveBytes = vbr_rate>>3;
    } else {
       celt_int32 tmp;
@@ -958,8 +966,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
 
    if (enc==NULL)
    {
-      ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
-      ec_enc_init(&_enc,&buf);
+      ec_enc_init(&_enc, compressed, nbCompressedBytes);
       enc = &_enc;
    }
 
@@ -985,7 +992,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
          {
             nbCompressedBytes = nbFilledBytes+max_allowed;
             nbAvailableBytes = max_allowed;
-            ec_byte_shrink(enc->buf, nbCompressedBytes);
+            ec_enc_shrink(enc, nbCompressedBytes);
          }
       }
    }
@@ -1050,15 +1057,15 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
             effectiveBytes=nbCompressedBytes=IMIN(nbCompressedBytes, nbFilledBytes+2);
             total_bits=nbCompressedBytes*8;
             nbAvailableBytes=2;
-            ec_byte_shrink(enc->buf, nbCompressedBytes);
+            ec_enc_shrink(enc, nbCompressedBytes);
          }
          /* Pretend we've filled all the remaining bits with zeros
             (that's what the initialiser did anyway) */
          tell = nbCompressedBytes*8;
-         enc->nbits_total+=tell-ec_enc_tell(enc,0);
+         enc->nbits_total+=tell-ec_tell(enc);
       }
 #ifdef ENABLE_POSTFILTER
-      if (nbAvailableBytes>12*C && st->start==0 && !silence)
+      if (nbAvailableBytes>12*C && st->start==0 && !silence && !st->disable_pf && st->complexity >= 5)
       {
          VARDECL(celt_word16, pitch_buf);
          ALLOC(pitch_buf, (COMBFILTER_MAXPERIOD+N)>>1, celt_word16);
@@ -1097,38 +1104,42 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
       pf_threshold = MAX16(pf_threshold, QCONST16(.2f,15));
       if (gain1<pf_threshold)
       {
-         if(st->start==0 && tell+17<=total_bits)
+         if(st->start==0 && tell+16<=total_bits)
             ec_enc_bit_logp(enc, 0, 1);
          gain1 = 0;
          pf_on = 0;
       } else {
+         /*This block is not gated by a total bits check only because
+           of the nbAvailableBytes check above.*/
          int qg;
          int octave;
 
-         if (gain1 > QCONST16(.6f,15))
-            gain1 = QCONST16(.6f,15);
          if (ABS16(gain1-st->prefilter_gain)<QCONST16(.1f,15))
             gain1=st->prefilter_gain;
 
 #ifdef FIXED_POINT
-         qg = ((gain1+2048)>>12)-2;
+         qg = ((gain1+1536)>>10)/3-1;
 #else
-         qg = floor(.5+gain1*8)-2;
+         qg = floor(.5+gain1*32/3)-1;
 #endif
+         qg = IMAX(0, IMIN(7, qg));
          ec_enc_bit_logp(enc, 1, 1);
          pitch_index += 1;
          octave = EC_ILOG(pitch_index)-5;
          ec_enc_uint(enc, octave, 6);
          ec_enc_bits(enc, pitch_index-(16<<octave), 4+octave);
          pitch_index -= 1;
-         ec_enc_bits(enc, qg, 2);
-         gain1 = QCONST16(.125f,15)*(qg+2);
-         ec_enc_icdf(enc, prefilter_tapset, tapset_icdf, 2);
+         ec_enc_bits(enc, qg, 3);
+         if (ec_tell(enc)+2<=total_bits)
+            ec_enc_icdf(enc, prefilter_tapset, tapset_icdf, 2);
+         else
+           prefilter_tapset = 0;
+         gain1 = QCONST16(0.09375f,15)*(qg+1);
          pf_on = 1;
       }
       /*printf("%d %f\n", pitch_index, gain1);*/
 #else /* ENABLE_POSTFILTER */
-      if(st->start==0 && tell+17<=total_bits)
+      if(st->start==0 && tell+16<=total_bits)
          ec_enc_bit_logp(enc, 0, 1);
       pf_on = 0;
 #endif /* ENABLE_POSTFILTER */
@@ -1165,7 +1176,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
 
    isTransient = 0;
    shortBlocks = 0;
-   if (LM>0 && ec_enc_tell(enc, 0)+3<=total_bits)
+   if (LM>0 && ec_tell(enc)+3<=total_bits)
    {
       if (st->complexity > 1)
       {
@@ -1223,7 +1234,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    tf_encode(st->start, st->end, isTransient, tf_res, LM, tf_select, enc);
 
    st->spread_decision = SPREAD_NORMAL;
-   if (ec_enc_tell(enc, 0)+4<=total_bits)
+   if (ec_tell(enc)+4<=total_bits)
    {
       if (shortBlocks || st->complexity < 3 || nbAvailableBytes < 10*C)
       {
@@ -1272,7 +1283,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    dynalloc_logp = 6;
    total_bits<<=BITRES;
    total_boost = 0;
-   tell = ec_enc_tell(enc, BITRES);
+   tell = ec_tell_frac(enc);
    for (i=st->start;i<st->end;i++)
    {
       int width, quanta;
@@ -1291,7 +1302,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
          int flag;
          flag = j<offsets[i];
          ec_enc_bit_logp(enc, flag, dynalloc_loop_logp);
-         tell = ec_enc_tell(enc, BITRES);
+         tell = ec_tell_frac(enc);
          if (!flag)
             break;
          boost += quanta;
@@ -1309,7 +1320,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
       alloc_trim = alloc_trim_analysis(st->mode, X, bandLogE,
             st->mode->nbEBands, LM, C, N);
       ec_enc_icdf(enc, alloc_trim, trim_icdf, 7);
-      tell = ec_enc_tell(enc, BITRES);
+      tell = ec_tell_frac(enc);
    }
 
    /* Variable bitrate */
@@ -1384,7 +1395,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
      }
      nbCompressedBytes = IMIN(nbCompressedBytes,nbAvailableBytes+nbFilledBytes);
      /* This moves the raw bits to take into account the new compressed size */
-     ec_byte_shrink(enc->buf, nbCompressedBytes);
+     ec_enc_shrink(enc, nbCompressedBytes);
    }
    if (C==2)
    {
@@ -1422,7 +1433,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    ALLOC(fine_priority, st->mode->nbEBands, int);
 
    /* bits =   packet size        -       where we are         - safety*/
-   bits = (nbCompressedBytes*8<<BITRES) - ec_enc_tell(enc, BITRES) - 1;
+   bits = (nbCompressedBytes*8<<BITRES) - ec_tell_frac(enc) - 1;
    anti_collapse_rsv = isTransient&&LM>=2&&bits>=(LM+2<<BITRES) ? (1<<BITRES) : 0;
    bits -= anti_collapse_rsv;
    codedBands = compute_allocation(st->mode, st->start, st->end, offsets, cap,
@@ -1454,7 +1465,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
       anti_collapse_on = st->consec_transient<2;
       ec_enc_bits(enc, anti_collapse_on, 1);
    }
-   quant_energy_finalise(st->mode, st->start, st->end, oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
+   quant_energy_finalise(st->mode, st->start, st->end, oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_tell(enc), enc, C);
 
    if (silence)
    {
@@ -1582,7 +1593,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    ec_enc_done(enc);
    
    RESTORE_STACK;
-   if (ec_enc_get_error(enc))
+   if (ec_get_error(enc))
       return CELT_CORRUPTED_DATA;
    else
       return nbCompressedBytes;
@@ -1590,6 +1601,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
 
 #ifdef FIXED_POINT
 #ifndef DISABLE_FLOAT_API
+CELT_STATIC
 int celt_encode_with_ec_float(CELTEncoder * restrict st, const float * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
 {
    int j, ret, C, N;
@@ -1618,6 +1630,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const float * pcm, int
 }
 #endif /*DISABLE_FLOAT_API*/
 #else
+CELT_STATIC
 int celt_encode_with_ec(CELTEncoder * restrict st, const celt_int16 * pcm, int frame_size, unsigned char *compressed, int nbCompressedBytes, ec_enc *enc)
 {
    int j, ret, C, N;
@@ -1701,14 +1714,8 @@ int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
          int value = va_arg(ap, celt_int32);
          if (value<0 || value>2)
             goto bad_arg;
-         if (value==0)
-         {
-            st->force_intra   = 1;
-         } else if (value==1) {
-            st->force_intra   = 0;
-         } else {
-            st->force_intra   = 0;
-         }   
+         st->disable_pf = value<=1;
+         st->force_intra = value==0;
       }
       break;
       case CELT_SET_VBR_CONSTRAINT_REQUEST:
@@ -2115,9 +2122,11 @@ static void celt_decode_lost(CELTDecoder * restrict st, celt_word16 * restrict p
 }
 
 #ifdef FIXED_POINT
+CELT_STATIC
 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
 {
 #else
+CELT_STATIC
 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig * restrict pcm, int frame_size, ec_dec *dec)
 {
 #endif
@@ -2125,7 +2134,6 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    int spread_decision;
    int bits;
    ec_dec _dec;
-   ec_byte_buffer buf;
    VARDECL(celt_sig, freq);
    VARDECL(celt_norm, X);
    VARDECL(celt_ener, bandE);
@@ -2171,10 +2179,10 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
       return CELT_BAD_ARG;
 
    frame_size *= st->downsample;
-   for (LM=0;LM<4;LM++)
+   for (LM=0;LM<=st->mode->maxLM;LM++)
       if (st->mode->shortMdctSize<<LM==frame_size)
          break;
-   if (LM>=MAX_CONFIG_SIZES)
+   if (LM>st->mode->maxLM)
       return CELT_BAD_ARG;
    M=1<<LM;
 
@@ -2220,8 +2228,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    
    if (dec == NULL)
    {
-      ec_byte_readinit(&buf,(unsigned char*)data,len);
-      ec_dec_init(&_dec,&buf);
+      ec_dec_init(&_dec,(unsigned char*)data,len);
       dec = &_dec;
    }
 
@@ -2236,7 +2243,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    }
 
    total_bits = len*8;
-   tell = ec_dec_tell(dec, 0);
+   tell = ec_tell(dec);
 
    if (tell==1)
       silence = ec_dec_bit_logp(dec, 15);
@@ -2246,13 +2253,13 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    {
       /* Pretend we've read all the remaining bits */
       tell = len*8;
-      dec->nbits_total+=tell-ec_dec_tell(dec,0);
+      dec->nbits_total+=tell-ec_tell(dec);
    }
 
    postfilter_gain = 0;
    postfilter_pitch = 0;
    postfilter_tapset = 0;
-   if (st->start==0 && tell+17 <= total_bits)
+   if (st->start==0 && tell+16 <= total_bits)
    {
       if(ec_dec_bit_logp(dec, 1))
       {
@@ -2260,21 +2267,22 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
          int qg, octave;
          octave = ec_dec_uint(dec, 6);
          postfilter_pitch = (16<<octave)+ec_dec_bits(dec, 4+octave)-1;
-         qg = ec_dec_bits(dec, 2);
-         postfilter_tapset = ec_dec_icdf(dec, tapset_icdf, 2);
-         postfilter_gain = QCONST16(.125f,15)*(qg+2);
+         qg = ec_dec_bits(dec, 3);
+         if (ec_tell(dec)+2<=total_bits)
+            postfilter_tapset = ec_dec_icdf(dec, tapset_icdf, 2);
+         postfilter_gain = QCONST16(.09375f,15)*(qg+1);
 #else /* ENABLE_POSTFILTER */
          RESTORE_STACK;
          return CELT_CORRUPTED_DATA;
 #endif /* ENABLE_POSTFILTER */
       }
-      tell = ec_dec_tell(dec, 0);
+      tell = ec_tell(dec);
    }
 
    if (LM > 0 && tell+3 <= total_bits)
    {
       isTransient = ec_dec_bit_logp(dec, 3);
-      tell = ec_dec_tell(dec, 0);
+      tell = ec_tell(dec);
    }
    else
       isTransient = 0;
@@ -2293,7 +2301,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    ALLOC(tf_res, st->mode->nbEBands, int);
    tf_decode(st->start, st->end, isTransient, tf_res, LM, dec);
 
-   tell = ec_dec_tell(dec, 0);
+   tell = ec_tell(dec);
    spread_decision = SPREAD_NORMAL;
    if (tell+4 <= total_bits)
       spread_decision = ec_dec_icdf(dec, spread_icdf, 5);
@@ -2307,7 +2315,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
 
    dynalloc_logp = 6;
    total_bits<<=BITRES;
-   tell = ec_dec_tell(dec, BITRES);
+   tell = ec_tell_frac(dec);
    for (i=st->start;i<st->end;i++)
    {
       int width, quanta;
@@ -2323,7 +2331,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
       {
          int flag;
          flag = ec_dec_bit_logp(dec, dynalloc_loop_logp);
-         tell = ec_dec_tell(dec, BITRES);
+         tell = ec_tell_frac(dec);
          if (!flag)
             break;
          boost += quanta;
@@ -2340,7 +2348,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    alloc_trim = tell+(6<<BITRES) <= total_bits ?
          ec_dec_icdf(dec, trim_icdf, 7) : 5;
 
-   bits = (len*8<<BITRES) - ec_dec_tell(dec, BITRES) - 1;
+   bits = (len*8<<BITRES) - ec_tell_frac(dec) - 1;
    anti_collapse_rsv = isTransient&&LM>=2&&bits>=(LM+2<<BITRES) ? (1<<BITRES) : 0;
    bits -= anti_collapse_rsv;
    codedBands = compute_allocation(st->mode, st->start, st->end, offsets, cap,
@@ -2361,7 +2369,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    }
 
    unquant_energy_finalise(st->mode, st->start, st->end, oldBandE,
-         fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
+         fine_quant, fine_priority, len*8-ec_tell(dec), dec, C);
 
    if (anti_collapse_on)
       anti_collapse(st->mode, X, collapse_masks, LM, C, CC, N,
@@ -2392,7 +2400,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
       int bound = M*st->mode->eBands[effEnd];
       if (st->downsample!=1)
          bound = IMIN(bound, N/st->downsample);
-      for (i=M*st->mode->eBands[effEnd];i<N;i++)
+      for (i=bound;i<N;i++)
          freq[c*N+i] = 0;
    } while (++c<C);
 
@@ -2465,7 +2473,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    deemphasis(out_syn, pcm, N, CC, st->downsample, st->mode->preemph, st->preemph_memD);
    st->loss_count = 0;
    RESTORE_STACK;
-   if (ec_dec_tell(dec,0) > 8*len || ec_dec_get_error(dec))
+   if (ec_tell(dec) > 8*len || ec_get_error(dec))
       return CELT_CORRUPTED_DATA;
    else
       return CELT_OK;
@@ -2473,6 +2481,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
 
 #ifdef FIXED_POINT
 #ifndef DISABLE_FLOAT_API
+CELT_STATIC
 int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size, ec_dec *dec)
 {
    int j, ret, C, N;
@@ -2497,6 +2506,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
 }
 #endif /*DISABLE_FLOAT_API*/
 #else
+CELT_STATIC
 int celt_decode_with_ec(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16 * restrict pcm, int frame_size, ec_dec *dec)
 {
    int j, ret, C, N;