Adds an anti-collapse mechanism for transients
authorJean-Marc Valin <jean-marc.valin@usherbrooke.ca>
Tue, 18 Jan 2011 19:44:04 +0000 (14:44 -0500)
committerJean-Marc Valin <jean-marc.valin@usherbrooke.ca>
Tue, 18 Jan 2011 19:44:04 +0000 (14:44 -0500)
This looks for bands in each short block that have no energy. For
each of these "collapsed" bands, noise is injected to have an
energy equal to the minimum of the two previous frames for that band.
The mechanism can be used whenever there are 4 or more MDCTs (otherwise
no complete collapse is possible) and is signalled with one bit just
before the final fine energy bits.

libcelt/bands.c
libcelt/bands.h
libcelt/celt.c

index 6e3e8f9..0951c54 100644 (file)
 #include "mathops.h"
 #include "rate.h"
 
 #include "mathops.h"
 #include "rate.h"
 
+static celt_uint32 lcg_rand(celt_uint32 seed)
+{
+   return 1664525 * seed + 1013904223;
+}
+
 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
    with this approximation is important because it has an impact on the bit allocation */
 static celt_int16 bitexact_cos(celt_int16 x)
 /* This is a cos() approximation designed to be bit-exact on any platform. Bit exactness
    with this approximation is important because it has an impact on the bit allocation */
 static celt_int16 bitexact_cos(celt_int16 x)
@@ -206,6 +211,74 @@ void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig
    } while (++c<C);
 }
 
    } while (++c<C);
 }
 
+/* This prevents energy collapse for transients with multiple short MDCTs */
+void anti_collapse(const CELTMode *m, celt_norm *_X, int LM, int C, int size,
+      int start, int end, celt_word16 *logE, celt_word16 *prev1logE,
+      celt_word16 *prev2logE, int *pulses, celt_uint32 seed)
+{
+   int c, i, j, k;
+   c=0; do
+   {
+      for (i=start;i<end;i++)
+      {
+         celt_norm *X;
+         int N0;
+         celt_word16 Ediff;
+         celt_word16 r;
+         celt_word16 thresh;
+         int depth;
+
+         N0 = m->eBands[i+1]-m->eBands[i];
+         Ediff = logE[c*m->nbEBands+i]-MIN16(prev1logE[c*m->nbEBands+i],prev2logE[c*m->nbEBands+i]);
+         Ediff = MAX16(0, Ediff);
+         depth = (1+(pulses[i]>>BITRES))/(m->eBands[i+1]-m->eBands[i]<<LM);
+
+#ifdef FIXED_POINT
+         thresh = MULT16_32_Q15(QCONST16(0.3f, 15), MIN32(32767,SHR32(celt_exp2(-SHL16(depth, 11)),1) ));
+         if (Ediff < 16384)
+            r = 2*MIN16(16383,SHR32(celt_exp2(-SHL16(Ediff, 11-DB_SHIFT)),1));
+         else
+            r = 0;
+         r = SHR16(MIN16(thresh, r),1);
+         {
+            int shift;
+            celt_word32 t;
+            t = N0<<LM;
+            shift = celt_ilog2(t)>>1;
+            t = SHL32(t, (7-shift)<<1);
+            r = SHR32(MULT16_16_Q15(celt_rsqrt_norm(t), r),shift);
+         }
+#else
+         thresh = .3f*celt_exp2(-depth);
+         r = 2.f*celt_exp2(-Ediff);
+         r = MIN16(thresh, r);
+         r = r*celt_rsqrt(N0<<LM);
+#endif
+         X = _X+c*size+(m->eBands[i]<<LM);
+         for (k=0;k<1<<LM;k++)
+         {
+            celt_word32 sum=0;
+            /* Detect collapse */
+            for (j=0;j<N0;j++)
+               sum += ABS16(X[(j<<LM)+k]);
+            if (sum<QCONST16(1e-4, 14))
+            {
+               /* Fill with noise */
+               for (j=0;j<N0;j++)
+               {
+                  seed = lcg_rand(seed);
+                  X[(j<<LM)+k] = (seed&0x8000 ? r : -r);
+               }
+            }
+         }
+         /* We just added some energy, so we need to renormalise */
+         renormalise_vector(X, N0<<LM, Q15ONE);
+      }
+   } while (++c<C);
+
+}
+
+
 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
 {
    int i = bandID;
 static void intensity_stereo(const CELTMode *m, celt_norm *X, celt_norm *Y, const celt_ener *bank, int bandID, int N)
 {
    int i = bandID;
@@ -528,11 +601,6 @@ static int compute_qn(int N, int b, int offset, int stereo)
    return qn;
 }
 
    return qn;
 }
 
-static celt_uint32 lcg_rand(celt_uint32 seed)
-{
-   return 1664525 * seed + 1013904223;
-}
-
 /* This function is responsible for encoding and decoding a band for both
    the mono and stereo case. Even in the mono case, it can split the band
    in two and transmit the energy difference with the two half-bands. It
 /* This function is responsible for encoding and decoding a band for both
    the mono and stereo case. Even in the mono case, it can split the band
    in two and transmit the energy difference with the two half-bands. It
@@ -1054,7 +1122,7 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
       /* Compute how many bits we want to allocate to this band */
       if (i != start)
          balance -= tell;
       /* Compute how many bits we want to allocate to this band */
       if (i != start)
          balance -= tell;
-      remaining_bits = ((celt_int32)total_bits<<BITRES)-tell-1;
+      remaining_bits = ((celt_int32)total_bits<<BITRES)-tell-1- (shortBlocks&&LM>=2 ? (1<<BITRES) : 0);
       if (i <= codedBands-1)
       {
          curr_balance = balance / IMIN(3, codedBands-i);
       if (i <= codedBands-1)
       {
          curr_balance = balance / IMIN(3, codedBands-i);
index aa3138f..2bad608 100644 (file)
@@ -93,4 +93,8 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
 
 void stereo_decision(const CELTMode *m, celt_norm * restrict X, int *stereo_mode, int len, int M);
 
 
 void stereo_decision(const CELTMode *m, celt_norm * restrict X, int *stereo_mode, int len, int M);
 
+void anti_collapse(const CELTMode *m, celt_norm *_X, int LM, int C, int size,
+      int start, int end, celt_word16 *logE, celt_word16 *prev1logE,
+      celt_word16 *prev2logE, int *pulses, celt_uint32 seed);
+
 #endif /* BANDS_H */
 #endif /* BANDS_H */
index 6af821c..6ea83eb 100644 (file)
@@ -96,6 +96,7 @@ struct CELTEncoder {
    celt_word16 prefilter_gain_old;
    int prefilter_tapset_old;
 #endif
    celt_word16 prefilter_gain_old;
    int prefilter_tapset_old;
 #endif
+   int consec_transient;
 
    /* VBR-related parameters */
    celt_int32 vbr_reservoir;
 
    /* VBR-related parameters */
    celt_int32 vbr_reservoir;
@@ -113,7 +114,7 @@ struct CELTEncoder {
    celt_sig in_mem[1]; /* Size = channels*mode->overlap */
    /* celt_sig prefilter_mem[],  Size = channels*COMBFILTER_PERIOD */
    /* celt_sig overlap_mem[],  Size = channels*mode->overlap */
    celt_sig in_mem[1]; /* Size = channels*mode->overlap */
    /* celt_sig prefilter_mem[],  Size = channels*COMBFILTER_PERIOD */
    /* celt_sig overlap_mem[],  Size = channels*mode->overlap */
-   /* celt_word16 oldEBands[], Size = channels*mode->nbEBands */
+   /* celt_word16 oldEBands[], Size = 2*channels*mode->nbEBands */
 };
 
 int celt_encoder_get_size(const CELTMode *mode, int channels)
 };
 
 int celt_encoder_get_size(const CELTMode *mode, int channels)
@@ -121,7 +122,7 @@ int celt_encoder_get_size(const CELTMode *mode, int channels)
    int size = sizeof(struct CELTEncoder)
          + (2*channels*mode->overlap-1)*sizeof(celt_sig)
          + channels*COMBFILTER_MAXPERIOD*sizeof(celt_sig)
    int size = sizeof(struct CELTEncoder)
          + (2*channels*mode->overlap-1)*sizeof(celt_sig)
          + channels*COMBFILTER_MAXPERIOD*sizeof(celt_sig)
-         + channels*mode->nbEBands*sizeof(celt_word16);
+         + 2*channels*mode->nbEBands*sizeof(celt_word16);
    return size;
 }
 
    return size;
 }
 
@@ -784,6 +785,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    VARDECL(celt_norm, X);
    VARDECL(celt_ener, bandE);
    VARDECL(celt_word16, bandLogE);
    VARDECL(celt_norm, X);
    VARDECL(celt_ener, bandE);
    VARDECL(celt_word16, bandLogE);
+   VARDECL(celt_word16, oldLogE);
    VARDECL(int, fine_quant);
    VARDECL(celt_word16, error);
    VARDECL(int, pulses);
    VARDECL(int, fine_quant);
    VARDECL(celt_word16, error);
    VARDECL(int, pulses);
@@ -792,7 +794,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    VARDECL(int, tf_res);
    celt_sig *_overlap_mem;
    celt_sig *prefilter_mem;
    VARDECL(int, tf_res);
    celt_sig *_overlap_mem;
    celt_sig *prefilter_mem;
-   celt_word16 *oldBandE;
+   celt_word16 *oldBandE, *oldLogE2;
    int shortBlocks=0;
    int isTransient=0;
    int resynth;
    int shortBlocks=0;
    int isTransient=0;
    int resynth;
@@ -817,6 +819,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    celt_int32 tell;
    int prefilter_tapset=0;
    int pf_on;
    celt_int32 tell;
    int prefilter_tapset=0;
    int pf_on;
+   int anti_collapse_on=0;
    SAVE_STACK;
 
    if (nbCompressedBytes<0 || pcm==NULL)
    SAVE_STACK;
 
    if (nbCompressedBytes<0 || pcm==NULL)
@@ -833,6 +836,7 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    _overlap_mem = prefilter_mem+C*COMBFILTER_MAXPERIOD;
    /*_overlap_mem = st->in_mem+C*(st->overlap);*/
    oldBandE = (celt_word16*)(st->in_mem+C*(2*st->overlap+COMBFILTER_MAXPERIOD));
    _overlap_mem = prefilter_mem+C*COMBFILTER_MAXPERIOD;
    /*_overlap_mem = st->in_mem+C*(st->overlap);*/
    oldBandE = (celt_word16*)(st->in_mem+C*(2*st->overlap+COMBFILTER_MAXPERIOD));
+   oldLogE2 = oldBandE + C*st->mode->nbEBands;
 
    if (enc==NULL)
    {
 
    if (enc==NULL)
    {
@@ -1059,6 +1063,9 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    for (i=effEnd;i<st->end;i++)
       tf_res[i] = tf_res[effEnd-1];
 
    for (i=effEnd;i<st->end;i++)
       tf_res[i] = tf_res[effEnd-1];
 
+   ALLOC(oldLogE, C*st->mode->nbEBands, celt_word16);
+   for (i=0;i<C*st->mode->nbEBands;i++)
+      oldLogE[i] = oldBandE[i];
    ALLOC(error, C*st->mode->nbEBands, celt_word16);
    quant_coarse_energy(st->mode, st->start, st->end, effEnd, bandLogE,
          oldBandE, total_bits, error, enc,
    ALLOC(error, C*st->mode->nbEBands, celt_word16);
    quant_coarse_energy(st->mode, st->start, st->end, effEnd, bandLogE,
          oldBandE, total_bits, error, enc,
@@ -1258,8 +1265,8 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
    ALLOC(pulses, st->mode->nbEBands, int);
    ALLOC(fine_priority, st->mode->nbEBands, int);
 
    ALLOC(pulses, st->mode->nbEBands, int);
    ALLOC(fine_priority, st->mode->nbEBands, int);
 
-   /* bits =   packet size        -       where we are           - safety */
-   bits = (nbCompressedBytes*8<<BITRES) - ec_enc_tell(enc, BITRES) - 1;
+   /* bits =   packet size        -       where we are         - safety -  anti-collapse*/
+   bits = (nbCompressedBytes*8<<BITRES) - ec_enc_tell(enc, BITRES) - 1 - (isTransient&&LM>=2 ? (1<<BITRES) : 0);
    codedBands = compute_allocation(st->mode, st->start, st->end, offsets,
          alloc_trim, &intensity, &dual_stereo, bits, pulses, fine_quant,
          fine_priority, C, LM, enc, 1, st->lastCodedBands);
    codedBands = compute_allocation(st->mode, st->start, st->end, offsets,
          alloc_trim, &intensity, &dual_stereo, bits, pulses, fine_quant,
          fine_priority, C, LM, enc, 1, st->lastCodedBands);
@@ -1283,6 +1290,11 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
          bandE, pulses, shortBlocks, st->spread_decision, dual_stereo, intensity, tf_res, resynth,
          nbCompressedBytes*8, enc, LM, codedBands);
 
          bandE, pulses, shortBlocks, st->spread_decision, dual_stereo, intensity, tf_res, resynth,
          nbCompressedBytes*8, enc, LM, codedBands);
 
+   if (isTransient && LM>=2)
+   {
+      anti_collapse_on = st->consec_transient<2;
+      ec_enc_bits(enc, anti_collapse_on, 1);
+   }
    quant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
 
 #ifdef RESYNTH
    quant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(enc, 0), enc, C);
 
 #ifdef RESYNTH
@@ -1297,6 +1309,11 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
 #ifdef MEASURE_NORM_MSE
       measure_norm_mse(st->mode, X, X0, bandE, bandE0, M, N, C);
 #endif
 #ifdef MEASURE_NORM_MSE
       measure_norm_mse(st->mode, X, X0, bandE, bandE0, M, N, C);
 #endif
+      if (anti_collapse_on)
+      {
+         anti_collapse(st->mode, X, LM, C, N,
+               st->start, st->end, oldBandE, oldLogE, oldLogE2, pulses, enc->rng);
+      }
 
       /* Synthesis */
       denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
 
       /* Synthesis */
       denormalise_bands(st->mode, X, freq, bandE, effEnd, C, M);
@@ -1360,6 +1377,12 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, i
       oldBandE[i]=0;
    for (i=st->end;i<st->mode->nbEBands;i++)
       oldBandE[i]=0;
       oldBandE[i]=0;
    for (i=st->end;i<st->mode->nbEBands;i++)
       oldBandE[i]=0;
+   for (i=0;i<C*st->mode->nbEBands;i++)
+      oldLogE2[i] = oldLogE[i];
+   if (isTransient)
+      st->consec_transient++;
+   else
+      st->consec_transient=0;
 
    /* If there's any room left (can only happen for very high rates),
       fill it with zeros */
 
    /* If there's any room left (can only happen for very high rates),
       fill it with zeros */
@@ -1585,7 +1608,7 @@ struct CELTDecoder {
    
    celt_sig _decode_mem[1]; /* Size = channels*(DECODE_BUFFER_SIZE+mode->overlap) */
    /* celt_word16 lpc[],  Size = channels*LPC_ORDER */
    
    celt_sig _decode_mem[1]; /* Size = channels*(DECODE_BUFFER_SIZE+mode->overlap) */
    /* celt_word16 lpc[],  Size = channels*LPC_ORDER */
-   /* celt_word16 oldEBands[], Size = channels*mode->nbEBands */
+   /* celt_word16 oldEBands[], Size = 2*channels*mode->nbEBands */
 };
 
 int celt_decoder_get_size(const CELTMode *mode, int channels)
 };
 
 int celt_decoder_get_size(const CELTMode *mode, int channels)
@@ -1593,7 +1616,7 @@ int celt_decoder_get_size(const CELTMode *mode, int channels)
    int size = sizeof(struct CELTDecoder)
             + (channels*(DECODE_BUFFER_SIZE+mode->overlap)-1)*sizeof(celt_sig)
             + channels*LPC_ORDER*sizeof(celt_word16)
    int size = sizeof(struct CELTDecoder)
             + (channels*(DECODE_BUFFER_SIZE+mode->overlap)-1)*sizeof(celt_sig)
             + channels*LPC_ORDER*sizeof(celt_word16)
-            + channels*mode->nbEBands*sizeof(celt_word16);
+            + 2*channels*mode->nbEBands*sizeof(celt_word16);
    return size;
 }
 
    return size;
 }
 
@@ -1853,6 +1876,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    VARDECL(celt_sig, freq);
    VARDECL(celt_norm, X);
    VARDECL(celt_ener, bandE);
    VARDECL(celt_sig, freq);
    VARDECL(celt_norm, X);
    VARDECL(celt_ener, bandE);
+   VARDECL(celt_word16, oldLogE);
    VARDECL(int, fine_quant);
    VARDECL(int, pulses);
    VARDECL(int, offsets);
    VARDECL(int, fine_quant);
    VARDECL(int, pulses);
    VARDECL(int, offsets);
@@ -1863,7 +1887,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    celt_sig *overlap_mem[2];
    celt_sig *out_syn[2];
    celt_word16 *lpc;
    celt_sig *overlap_mem[2];
    celt_sig *out_syn[2];
    celt_word16 *lpc;
-   celt_word16 *oldBandE;
+   celt_word16 *oldBandE, *oldLogE2;
 
    int shortBlocks;
    int isTransient;
 
    int shortBlocks;
    int isTransient;
@@ -1881,6 +1905,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    celt_int32 tell;
    int dynalloc_logp;
    int postfilter_tapset;
    celt_int32 tell;
    int dynalloc_logp;
    int postfilter_tapset;
+   int anti_collapse_on=0;
    SAVE_STACK;
 
    if (pcm==NULL)
    SAVE_STACK;
 
    if (pcm==NULL)
@@ -1900,6 +1925,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    } while (++c<C);
    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
    oldBandE = lpc+C*LPC_ORDER;
    } while (++c<C);
    lpc = (celt_word16*)(st->_decode_mem+(DECODE_BUFFER_SIZE+st->overlap)*C);
    oldBandE = lpc+C*LPC_ORDER;
+   oldLogE2 = oldBandE + C*st->mode->nbEBands;
 
    N = M*st->mode->shortMdctSize;
 
 
    N = M*st->mode->shortMdctSize;
 
@@ -1976,6 +2002,10 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    else
       shortBlocks = 0;
 
    else
       shortBlocks = 0;
 
+   ALLOC(oldLogE, C*st->mode->nbEBands, celt_word16);
+   for (i=0;i<C*st->mode->nbEBands;i++)
+      oldLogE[i] = oldBandE[i];
+
    /* Decode the global flags (first symbols in the stream) */
    intra_ener = tell+3<=total_bits ? ec_dec_bit_logp(dec, 3) : 0;
    /* Get band energies */
    /* Decode the global flags (first symbols in the stream) */
    intra_ener = tell+3<=total_bits ? ec_dec_bit_logp(dec, 3) : 0;
    /* Get band energies */
@@ -2030,7 +2060,7 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
    alloc_trim = tell+(6<<BITRES) <= total_bits ?
          ec_dec_icdf(dec, trim_icdf, 7) : 5;
 
    alloc_trim = tell+(6<<BITRES) <= total_bits ?
          ec_dec_icdf(dec, trim_icdf, 7) : 5;
 
-   bits = (len*8<<BITRES) - ec_dec_tell(dec, BITRES) - 1;
+   bits = (len*8<<BITRES) - ec_dec_tell(dec, BITRES) - 1 - (isTransient&&LM>=2 ? (1<<BITRES) : 0);
    codedBands = compute_allocation(st->mode, st->start, st->end, offsets,
          alloc_trim, &intensity, &dual_stereo, bits, pulses, fine_quant,
          fine_priority, C, LM, dec, 0, 0);
    codedBands = compute_allocation(st->mode, st->start, st->end, offsets,
          alloc_trim, &intensity, &dual_stereo, bits, pulses, fine_quant,
          fine_priority, C, LM, dec, 0, 0);
@@ -2042,9 +2072,18 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
          NULL, pulses, shortBlocks, spread_decision, dual_stereo, intensity, tf_res, 1,
          len*8, dec, LM, codedBands);
 
          NULL, pulses, shortBlocks, spread_decision, dual_stereo, intensity, tf_res, 1,
          len*8, dec, LM, codedBands);
 
+   if (isTransient && LM>=2)
+   {
+      anti_collapse_on = ec_dec_bits(dec, 1);
+   }
+
    unquant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE,
          fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
 
    unquant_energy_finalise(st->mode, st->start, st->end, bandE, oldBandE,
          fine_quant, fine_priority, len*8-ec_dec_tell(dec, 0), dec, C);
 
+   if (anti_collapse_on)
+      anti_collapse(st->mode, X, LM, C, N,
+            st->start, st->end, oldBandE, oldLogE, oldLogE2, pulses, dec->rng);
+
    log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
 
    /* Synthesis */
    log2Amp(st->mode, st->start, st->end, bandE, oldBandE, C);
 
    /* Synthesis */
@@ -2101,6 +2140,8 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da
       oldBandE[i]=0;
    for (i=st->end;i<st->mode->nbEBands;i++)
       oldBandE[i]=0;
       oldBandE[i]=0;
    for (i=st->end;i<st->mode->nbEBands;i++)
       oldBandE[i]=0;
+   for (i=0;i<C*st->mode->nbEBands;i++)
+      oldLogE2[i] = oldLogE[i];
 
    deemphasis(out_syn, pcm, N, C, st->mode->preemph, st->preemph_memD);
    st->loss_count = 0;
 
    deemphasis(out_syn, pcm, N, C, st->mode->preemph, st->preemph_memD);
    st->loss_count = 0;