Move tf_select before the tf_res bits.
[opus.git] / libcelt / celt.c
index c034f24..dcce270 100644 (file)
@@ -53,7 +53,8 @@
 #include <stdarg.h>
 #include "plc.h"
 
-static const int trim_cdf[12] = {0, 2, 4, 9, 19, 41, 87, 109, 119, 124, 126, 128};
+static const unsigned trim_cdf[12] = {0, 2, 4, 9, 19, 41, 87, 109, 119, 124, 126, 128};
+static const unsigned spread_cdf[5] = {0, 7, 9, 30, 32};
 
 #define COMBFILTER_MAXPERIOD 1024
 #define COMBFILTER_MINPERIOD 16
@@ -77,7 +78,7 @@ struct CELTEncoder {
 #define ENCODER_RESET_START frame_max
 
    celt_word32 frame_max;
-   int fold_decision;
+   int spread_decision;
    int delayedIntra;
    int tonal_average;
    int lastCodedBands;
@@ -155,7 +156,7 @@ CELTEncoder *celt_encoder_init(CELTEncoder *st, const CELTMode *mode, int channe
    st->force_intra  = 0;
    st->delayedIntra = 1;
    st->tonal_average = 256;
-   st->fold_decision = 1;
+   st->spread_decision = SPREAD_NORMAL;
    st->complexity = 5;
 
    if (error)
@@ -386,7 +387,6 @@ static void deemphasis(celt_sig *in[], celt_word16 *pcm, int N, int _C, const ce
 }
 
 #ifdef ENABLE_POSTFILTER
-/* FIXME: Handle the case where T = maxperiod */
 static void comb_filter(celt_word32 *y, celt_word32 *x, int T0, int T1, int N,
       int C, celt_word16 g0, celt_word16 g1, const celt_word16 *window, int overlap)
 {
@@ -427,9 +427,9 @@ static void comb_filter(celt_word32 *y, celt_word32 *x, int T0, int T1, int N,
 
 static const signed char tf_select_table[4][8] = {
       {0, -1, 0, -1,    0,-1, 0,-1},
-      {0, -1, 0, -2,    1, 0, 1 -1},
-      {0, -2, 0, -3,    2, 0, 1 -1},
-      {0, -2, 0, -3,    2, 0, 1 -1},
+      {0, -1, 0, -2,    1, 0, 1,-1},
+      {0, -2, 0, -3,    2, 0, 1,-1},
+      {0, -2, 0, -3,    2, 0, 1,-1},
 };
 
 static celt_word32 l1_metric(const celt_norm *tmp, int N, int LM, int width)
@@ -472,7 +472,6 @@ static int tf_analysis(const CELTMode *m, celt_word16 *bandLogE, celt_word16 *ol
    int tf_select=0;
    SAVE_STACK;
 
-   /* FIXME: Should check number of bytes *left* */
    if (nbCompressedBytes<15*C)
    {
       *tf_sum = 0;
@@ -503,7 +502,7 @@ static int tf_analysis(const CELTMode *m, celt_word16 *bandLogE, celt_word16 *ol
       N = (m->eBands[i+1]-m->eBands[i])<<LM;
       for (j=0;j<N;j++)
          tmp[j] = X[j+(m->eBands[i]<<LM)];
-      /* FIXME: Do something with the right channel */
+      /* Just add the right channel if we're in stereo */
       if (C==2)
          for (j=0;j<N;j++)
             tmp[j] = ADD16(tmp[j],X[N0+j+(m->eBands[i]<<LM)]);
@@ -591,15 +590,15 @@ static int tf_analysis(const CELTMode *m, celt_word16 *bandLogE, celt_word16 *ol
 static void tf_encode(int start, int end, int isTransient, int *tf_res, int LM, int tf_select, ec_enc *enc)
 {
    int curr, i;
-   ec_enc_bit_prob(enc, tf_res[start], isTransient ? 16384 : 4096);
+   if (LM!=0)
+      ec_enc_bit_logp(enc, tf_select, 1);
+   ec_enc_bit_logp(enc, tf_res[start], isTransient ? 2 : 4);
    curr = tf_res[start];
    for (i=start+1;i<end;i++)
    {
-      ec_enc_bit_prob(enc, tf_res[i] ^ curr, isTransient ? 4096 : 2048);
+      ec_enc_bit_logp(enc, tf_res[i] ^ curr, isTransient ? 4 : 5);
       curr = tf_res[i];
    }
-   if (LM!=0)
-      ec_enc_bits(enc, tf_select, 1);
    for (i=start;i<end;i++)
       tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
    /*printf("%d %d ", isTransient, tf_select); for(i=0;i<end;i++)printf("%d ", tf_res[i]);printf("\n");*/
@@ -608,19 +607,17 @@ static void tf_encode(int start, int end, int isTransient, int *tf_res, int LM,
 static void tf_decode(int start, int end, int C, int isTransient, int *tf_res, int LM, ec_dec *dec)
 {
    int i, curr, tf_select;
-   tf_res[start] = ec_dec_bit_prob(dec, isTransient ? 16384 : 4096);
-   curr = tf_res[start];
-   for (i=start+1;i<end;i++)
-   {
-      tf_res[i] = ec_dec_bit_prob(dec, isTransient ? 4096 : 2048) ^ curr;
-      curr = tf_res[i];
-   }
    if (LM!=0)
-      tf_select = ec_dec_bits(dec, 1);
+      tf_select = ec_dec_bit_logp(dec, 1);
    else
       tf_select = 0;
-   for (i=start;i<end;i++)
-      tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+tf_res[i]];
+   curr = ec_dec_bit_logp(dec, isTransient ? 2 : 4);
+   tf_res[start] = tf_select_table[LM][4*isTransient+2*tf_select+curr];
+   for (i=start+1;i<end;i++)
+   {
+      curr = ec_dec_bit_logp(dec, isTransient ? 4 : 5) ^ curr;
+      tf_res[i] = tf_select_table[LM][4*isTransient+2*tf_select+curr];
+   }
 }
 
 static int alloc_trim_analysis(const CELTMode *m, const celt_norm *X,
@@ -719,7 +716,6 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
 #endif
    int i, c, N;
    int bits;
-   int has_fold=1;
    ec_byte_buffer buf;
    ec_enc         _enc;
    VARDECL(celt_sig, in);
@@ -752,6 +748,8 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    int intensity=0;
    int dual_stereo=0;
    int effectiveBytes;
+   celt_word16 pf_threshold;
+   int dynalloc_prob;
    SAVE_STACK;
 
    if (nbCompressedBytes<0 || pcm==NULL)
@@ -822,7 +820,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
       } while (++c<C);
 
 #ifdef ENABLE_POSTFILTER
-      if (LM != 0 && nbAvailableBytes>10)
+      if (nbAvailableBytes>12*C && st->start==0)
       {
          VARDECL(celt_word16, pitch_buf);
          ALLOC(pitch_buf, (COMBFILTER_MAXPERIOD+N)>>1, celt_word16);
@@ -841,26 +839,46 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
          if (pitch_index > COMBFILTER_MAXPERIOD)
             pitch_index = COMBFILTER_MAXPERIOD;
          gain1 = MULT16_16_Q15(QCONST16(.7f,15),gain1);
-         if (gain1 > QCONST16(.6f,15))
-            gain1 = QCONST16(.6f,15);
-         if (ABS16(gain1-st->prefilter_gain)<QCONST16(.1,15))
-            gain1=st->prefilter_gain;
       } else {
          gain1 = 0;
       }
-      if (gain1<QCONST16(.2f,15) || (nbAvailableBytes<30 && gain1<QCONST16(.4f,15)))
+
+      /* Gain threshold for enabling the prefilter/postfilter */
+      pf_threshold = QCONST16(.2f,15);
+
+      /* Adjusting the threshold based on rate and continuity */
+      if (abs(pitch_index-st->prefilter_period)*10>pitch_index)
+         pf_threshold += QCONST16(.2f,15);
+      if (nbAvailableBytes<25)
+         pf_threshold += QCONST16(.1f,15);
+      if (nbAvailableBytes<35)
+         pf_threshold += QCONST16(.1f,15);
+      if (st->prefilter_gain > QCONST16(.4f,15))
+         pf_threshold -= QCONST16(.1f,15);
+      if (st->prefilter_gain > QCONST16(.55f,15))
+         pf_threshold -= QCONST16(.1f,15);
+
+      /* Hard threshold at 0.2 */
+      pf_threshold = MAX16(pf_threshold, QCONST16(.2f,15));
+      if (gain1<pf_threshold)
       {
-         ec_enc_bit_prob(enc, 0, 32768);
+         ec_enc_bit_logp(enc, 0, 1);
          gain1 = 0;
       } else {
          int qg;
          int octave;
+
+         if (gain1 > QCONST16(.6f,15))
+            gain1 = QCONST16(.6f,15);
+         if (ABS16(gain1-st->prefilter_gain)<QCONST16(.1,15))
+            gain1=st->prefilter_gain;
+
 #ifdef FIXED_POINT
          qg = ((gain1+2048)>>12)-2;
 #else
          qg = floor(.5+gain1*8)-2;
 #endif
-         ec_enc_bit_prob(enc, 1, 32768);
+         ec_enc_bit_logp(enc, 1, 1);
          octave = EC_ILOG(pitch_index)-5;
          ec_enc_uint(enc, octave, 6);
          ec_enc_bits(enc, pitch_index-(16<<octave), 4+octave);
@@ -869,7 +887,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
       }
       /*printf("%d %f\n", pitch_index, gain1);*/
 #else /* ENABLE_POSTFILTER */
-      ec_enc_bit_prob(enc, 0, 32768);
+      ec_enc_bit_logp(enc, 0, 1);
 #endif /* ENABLE_POSTFILTER */
 
       c=0; do {
@@ -942,25 +960,24 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
          &st->delayedIntra, st->complexity >= 4);
 
    if (LM > 0)
-      ec_enc_bit_prob(enc, shortBlocks!=0, 8192);
+      ec_enc_bit_logp(enc, shortBlocks!=0, 3);
 
    tf_encode(st->start, st->end, isTransient, tf_res, LM, tf_select, enc);
 
-   if (shortBlocks || st->complexity < 3)
+   if (shortBlocks || st->complexity < 3 || nbAvailableBytes < 10*C)
    {
       if (st->complexity == 0)
       {
-         has_fold = 0;
-         st->fold_decision = 3;
+         st->spread_decision = SPREAD_NONE;
       } else {
-         has_fold = 1;
-         st->fold_decision = 1;
+         st->spread_decision = SPREAD_NORMAL;
       }
    } else {
-      has_fold = folding_decision(st->mode, X, &st->tonal_average, &st->fold_decision, effEnd, C, M);
+      st->spread_decision = spreading_decision(st->mode, X, &st->tonal_average, st->spread_decision, effEnd, C, M);
    }
-   ec_enc_bit_prob(enc, has_fold>>1, 8192);
-   ec_enc_bit_prob(enc, has_fold&1, (has_fold>>1) ? 32768 : 49152);
+   /* Probs: NONE: 21.875%, LIGHT: 6.25%, NORMAL: 65.625%, AGGRESSIVE: 6.25% */
+   ec_encode_bin(enc, spread_cdf[st->spread_decision],
+         spread_cdf[st->spread_decision+1], 5);
 
    ALLOC(offsets, st->mode->nbEBands, int);
 
@@ -992,17 +1009,25 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
             offsets[i] += 1;
       }
    }
+   dynalloc_prob = 6;
    for (i=0;i<st->mode->nbEBands;i++)
    {
       int j;
-      ec_enc_bit_prob(enc, offsets[i]!=0, 1024);
+      ec_enc_bit_logp(enc, offsets[i]!=0, dynalloc_prob);
       if (offsets[i]!=0)
       {
+         int width, quanta;
+         width = C*(st->mode->eBands[i+1]-st->mode->eBands[i])<<LM;
+         /* quanta is 6 bits, but no more than 1 bit/sample
+            and no less than 1/8 bit/sample */
+         quanta = IMIN(width<<BITRES, IMAX(6<<BITRES, width));
          for (j=0;j<offsets[i]-1;j++)
-            ec_enc_bit_prob(enc, 1, 32768);
-         ec_enc_bit_prob(enc, 0, 32768);
+            ec_enc_bit_logp(enc, 1, 1);
+         ec_enc_bit_logp(enc, 0, 1);
+         offsets[i] *= quanta;
+         /* Making dynalloc more likely */
+         dynalloc_prob = IMAX(2, dynalloc_prob-1);
       }
-      offsets[i] *= (6<<BITRES);
    }
    alloc_trim = alloc_trim_analysis(st->mode, X, bandLogE, st->mode->nbEBands, LM, C, N);
    ec_encode_bin(enc, trim_cdf[alloc_trim], trim_cdf[alloc_trim+1], 7);
@@ -1084,8 +1109,12 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
 
    if (C==2)
    {
-      dual_stereo = stereo_analysis(st->mode, X, st->mode->nbEBands, LM, C, N);
-      ec_enc_bit_prob(enc, dual_stereo, 32768);
+      /* Always use MS for 2.5 ms frames until we can do a better analysis */
+      if (LM==0)
+         dual_stereo = 0;
+      else
+         dual_stereo = stereo_analysis(st->mode, X, st->mode->nbEBands, LM, C, N);
+      ec_enc_bit_logp(enc, dual_stereo, 1);
    }
    if (C==2)
    {
@@ -1119,8 +1148,8 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    ALLOC(pulses, st->mode->nbEBands, int);
    ALLOC(fine_priority, st->mode->nbEBands, int);
 
-   /* bits =   packet size    -    where we are   - safety */
-   bits = nbCompressedBytes*8 - ec_enc_tell(enc, 0) - 1;
+   /* bits =   packet size        -       where we are           - safety */
+   bits = (nbCompressedBytes*8<<BITRES) - ec_enc_tell(enc, BITRES) - 1;
    codedBands = compute_allocation(st->mode, st->start, st->end, offsets,
          alloc_trim, bits, pulses, fine_quant, fine_priority, C, LM, enc, 1, st->lastCodedBands);
    st->lastCodedBands = codedBands;
@@ -1140,7 +1169,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
 
    /* Residual quantisation */
    quant_all_bands(1, st->mode, st->start, st->end, X, C==2 ? X+N : NULL,
-         bandE, pulses, shortBlocks, has_fold, dual_stereo, intensity, tf_res, resynth,
+         bandE, pulses, shortBlocks, st->spread_decision, dual_stereo, intensity, tf_res, resynth,
          nbCompressedBytes*8, enc, LM, codedBands);
 
    quant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
@@ -1384,7 +1413,7 @@ int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
                ((char*)&st->ENCODER_RESET_START - (char*)st));
          st->vbr_offset = 0;
          st->delayedIntra = 1;
-         st->fold_decision = 1;
+         st->spread_decision = SPREAD_NORMAL;
          st->tonal_average = QCONST16(1.f,8);
       }
       break;
@@ -1691,7 +1720,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
 {
 #endif
    int c, i, N;
-   int has_fold;
+   int spread_decision;
    int bits;
    ec_dec _dec;
    ec_byte_buffer buf;
@@ -1723,6 +1752,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    celt_word16 postfilter_gain;
    int intensity=0;
    int dual_stereo=0;
+   int dynalloc_prob;
    SAVE_STACK;
 
    if (pcm==NULL)
@@ -1783,7 +1813,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    }
    nbAvailableBytes = len-nbFilledBytes;
 
-   if (ec_dec_bit_prob(dec, 32768))
+   if (ec_dec_bit_logp(dec, 1))
    {
 #ifdef ENABLE_POSTFILTER
       int qg, octave;
@@ -1802,13 +1832,13 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    }
 
    /* Decode the global flags (first symbols in the stream) */
-   intra_ener = ec_dec_bit_prob(dec, 8192);
+   intra_ener = ec_dec_bit_logp(dec, 3);
    /* Get band energies */
    unquant_coarse_energy(st->mode, st->start, st->end, bandE, oldBandE,
          intra_ener, dec, C, LM);
 
    if (LM > 0)
-      isTransient = ec_dec_bit_prob(dec, 8192);
+      isTransient = ec_dec_bit_logp(dec, 3);
    else
       isTransient = 0;
 
@@ -1820,8 +1850,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    ALLOC(tf_res, st->mode->nbEBands, int);
    tf_decode(st->start, st->end, C, isTransient, tf_res, LM, dec);
 
-   has_fold = ec_dec_bit_prob(dec, 8192)<<1;
-   has_fold |= ec_dec_bit_prob(dec, (has_fold>>1) ? 32768 : 49152);
+   spread_decision = ec_dec_cdf(dec, spread_cdf, 5);
 
    ALLOC(pulses, st->mode->nbEBands, int);
    ALLOC(offsets, st->mode->nbEBands, int);
@@ -1829,34 +1858,35 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
 
    for (i=0;i<st->mode->nbEBands;i++)
       offsets[i] = 0;
+   dynalloc_prob = 6;
    for (i=0;i<st->mode->nbEBands;i++)
    {
-      if (ec_dec_bit_prob(dec, 1024))
+      if (ec_dec_bit_logp(dec, dynalloc_prob))
       {
-         while (ec_dec_bit_prob(dec, 32768))
+         int width, quanta;
+         width = C*(st->mode->eBands[i+1]-st->mode->eBands[i])<<LM;
+         /* quanta is 6 bits, but no more than 1 bit/sample
+            and no less than 1/8 bit/sample */
+         quanta = IMIN(width<<BITRES, IMAX(6<<BITRES, width));
+         while (ec_dec_bit_logp(dec, 1))
             offsets[i]++;
          offsets[i]++;
-         offsets[i] *= (6<<BITRES);
+         offsets[i] *= quanta;
+         /* Making dynalloc more likely */
+         dynalloc_prob = IMAX(2, dynalloc_prob-1);
       }
    }
 
    ALLOC(fine_quant, st->mode->nbEBands, int);
-   {
-      int fl;
-      alloc_trim = 0;
-      fl = ec_decode_bin(dec, 7);
-      while (trim_cdf[alloc_trim+1] <= fl)
-         alloc_trim++;
-      ec_dec_update(dec, trim_cdf[alloc_trim], trim_cdf[alloc_trim+1], 128);
-   }
+   alloc_trim = ec_dec_cdf(dec, trim_cdf, 7);
 
    if (C==2)
    {
-      dual_stereo = ec_dec_bit_prob(dec, 32768);
+      dual_stereo = ec_dec_bit_logp(dec, 1);
       intensity = ec_dec_uint(dec, 1+st->end-st->start);
    }
 
-   bits = len*8 - ec_dec_tell(dec, 0) - 1;
+   bits = (len*8<<BITRES) - ec_dec_tell(dec, BITRES) - 1;
    codedBands = compute_allocation(st->mode, st->start, st->end, offsets,
          alloc_trim, bits, pulses, fine_quant, fine_priority, C, LM, dec, 0, 0);
    
@@ -1864,7 +1894,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
 
    /* Decode fixed codebook */
    quant_all_bands(0, st->mode, st->start, st->end, X, C==2 ? X+N : NULL,
-         NULL, pulses, shortBlocks, has_fold, dual_stereo, intensity, tf_res, 1,
+         NULL, pulses, shortBlocks, spread_decision, dual_stereo, intensity, tf_res, 1,
          len*8, dec, LM, codedBands);
 
    unquant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE,
@@ -1919,7 +1949,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    deemphasis(out_syn, pcm, N, C, st->mode->preemph, st->preemph_memD);
    st->loss_count = 0;
    RESTORE_STACK;
-   if (ec_dec_get_error(dec))
+   if (ec_dec_tell(dec,0) > 8*len || ec_dec_get_error(dec))
       return CELT_CORRUPTED_DATA;
    else
       return CELT_OK;