Improved transient_analysis() by adding one frame of memory.
[opus.git] / libcelt / celt.c
index eeb3418..90c484d 100644 (file)
@@ -95,6 +95,7 @@ struct CELTEncoder {
    celt_word16 tonal_average;
    int fold_decision;
    celt_word16 gain_prod;
+   celt_word32 frame_max;
 
    /* VBR-related parameters */
    celt_int32 vbr_reservoir;
@@ -182,7 +183,7 @@ CELTEncoder *celt_encoder_create(const CELTMode *mode, int channels, int *error)
    st->pitch_available = 1;
    st->force_intra  = 0;
    st->delayedIntra = 1;
-   st->tonal_average = QCONST16(1.,8);
+   st->tonal_average = QCONST16(1.f,8);
    st->fold_decision = 1;
 
    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig));
@@ -266,48 +267,61 @@ static inline celt_word16 SIG2WORD16(celt_sig x)
 #endif
 }
 
-static int transient_analysis(celt_word32 *in, int len, int C, int *transient_time, int *transient_shift)
+static int transient_analysis(const celt_word32 * restrict in, int len, int C,
+                              int *transient_time, int *transient_shift,
+                              celt_word32 *frame_max)
 {
-   int c, i, n;
+   int i, n;
    celt_word32 ratio;
+   celt_word32 threshold;
    VARDECL(celt_word32, begin);
    SAVE_STACK;
-   ALLOC(begin, len, celt_word32);
-   for (i=0;i<len;i++)
-      begin[i] = ABS32(SHR32(in[C*i],SIG_SHIFT));
-   for (c=1;c<C;c++)
+   ALLOC(begin, len+1, celt_word32);
+   begin[0] = 0;
+   if (C==1)
    {
       for (i=0;i<len;i++)
-         begin[i] = MAX32(begin[i], ABS32(SHR32(in[C*i+c],SIG_SHIFT)));
+         begin[i+1] = MAX32(begin[i], ABS32(in[i]));
+   } else {
+      for (i=0;i<len;i++)
+         begin[i+1] = MAX32(begin[i], MAX32(ABS32(in[C*i]),
+                                            ABS32(in[C*i+1])));
    }
-   for (i=1;i<len;i++)
-      begin[i] = MAX32(begin[i-1],begin[i]);
    n = -1;
-   for (i=8;i<len-8;i++)
+
+   threshold = MULT16_32_Q15(QCONST16(.2f,15),begin[len]);
+   /* If the following condition isn't met, there's just no way
+      we'll have a transient*/
+   if (*frame_max < threshold)
    {
-      if (begin[i] < MULT16_32_Q15(QCONST16(.2f,15),begin[len-1]))
-         n=i;
+      /* It's likely we have a transient, now find it */
+      for (i=8;i<len-8;i++)
+      {
+         if (begin[i+1] < threshold)
+            n=i;
+      }
    }
    if (n<32)
    {
       n = -1;
       ratio = 0;
    } else {
-      ratio = DIV32(begin[len-1],1+begin[n-16]);
+      ratio = DIV32(begin[len],1+MAX32(*frame_max, begin[n-16]));
    }
    if (ratio < 0)
       ratio = 0;
    if (ratio > 1000)
       ratio = 1000;
    ratio *= ratio;
-   
+
    if (ratio > 2048)
       *transient_shift = 3;
    else
       *transient_shift = 0;
    
    *transient_time = n;
-   
+   *frame_max = begin[len];
+
    RESTORE_STACK;
    return ratio > 20;
 }
@@ -491,22 +505,28 @@ static void decode_flags(ec_dec *dec, int *intra_ener, int *has_pitch, int *shor
    /*printf ("dec %d: %d %d %d %d\n", flag_bits, *intra_ener, *has_pitch, *shortBlocks, *has_fold);*/
 }
 
-static void deemphasis(celt_sig *in, celt_word16 *pcm, int N, int _C, celt_word16 coef, celt_sig *mem)
+void deemphasis(celt_sig *in, celt_word16 *pcm, int N, int _C, celt_word16 coef, celt_sig *mem)
 {
    const int C = CHANNELS(_C);
    int c;
    for (c=0;c<C;c++)
    {
       int j;
+      celt_sig * restrict x;
+      celt_word16  * restrict y;
+      celt_sig m = mem[c];
+      x = &in[C*(MAX_PERIOD-N)+c];
+      y = pcm+c;
       for (j=0;j<N;j++)
       {
-         celt_sig tmp = MAC16_32_Q15(in[C*(MAX_PERIOD-N)+C*j+c],
-                                       coef,mem[c]);
-         mem[c] = tmp;
-         pcm[C*j+c] = SCALEOUT(SIG2WORD16(tmp));
+         celt_sig tmp = MAC16_32_Q15(*x, coef,m);
+         m = tmp;
+         *y = SCALEOUT(SIG2WORD16(tmp));
+         x+=C;
+         y+=C;
       }
+      mem[c] = m;
    }
-
 }
 
 static void mdct_shape(const CELTMode *mode, celt_norm *X, int start, int end, int N, int nbShortMdcts, int mdct_weight_shift, int _C)
@@ -555,6 +575,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig
    int shortBlocks=0;
    int transient_time;
    int transient_shift;
+   int resynth;
    const int C = CHANNELS(st->channels);
    int mdct_weight_shift = 0;
    int mdct_weight_pos=0;
@@ -604,7 +625,9 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig
    transient_shift = 0;
    shortBlocks = 0;
 
-   if (st->mode->nbShortMdcts > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
+   resynth = st->pitch_available>0 || optional_synthesis!=NULL;
+
+   if (st->mode->nbShortMdcts > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift, &st->frame_max))
    {
 #ifndef FIXED_POINT
       float gain_1;
@@ -848,16 +871,16 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig
 
    /* Residual quantisation */
    if (C==1)
-      quant_bands(st->mode, start, X, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, 1, &enc);
+      quant_bands(st->mode, start, X, bandE, pulses, shortBlocks, has_fold, resynth, nbCompressedBytes*8, 1, &enc);
 #ifndef DISABLE_STEREO
    else
-      quant_bands_stereo(st->mode, start, X, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
+      quant_bands_stereo(st->mode, start, X, bandE, pulses, shortBlocks, has_fold, resynth, nbCompressedBytes*8, &enc);
 #endif
 
    quant_energy_finalise(st->mode, start, bandE, st->oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(&enc, 0), &enc, C);
 
    /* Re-synthesis of the coded audio if required */
-   if (st->pitch_available>0 || optional_synthesis!=NULL)
+   if (resynth)
    {
       if (st->pitch_available>0 && st->pitch_available<MAX_PERIOD)
         st->pitch_available+=st->frame_size;
@@ -1043,13 +1066,14 @@ int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
          st->delayedIntra = 1;
 
          st->fold_decision = 1;
-         st->tonal_average = QCONST16(1.,8);
+         st->tonal_average = QCONST16(1.f,8);
          st->gain_prod = 0;
          st->vbr_reservoir = 0;
          st->vbr_drift = 0;
          st->vbr_offset = 0;
          st->vbr_count = 0;
          st->xmem = 0;
+         st->frame_max = 0;
          CELT_MEMSET(st->pitch_buf, 0, (MAX_PERIOD>>1)+2);
       }
       break;
@@ -1521,7 +1545,7 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int
 
    /* Decode fixed codebook and merge with pitch */
    if (C==1)
-      quant_bands(st->mode, start, X, bandE, pulses, shortBlocks, has_fold, len*8, 0, &dec);
+      quant_bands(st->mode, start, X, bandE, pulses, shortBlocks, has_fold, 1, len*8, 0, &dec);
 #ifndef DISABLE_STEREO
    else
       unquant_bands_stereo(st->mode, start, X, bandE, pulses, shortBlocks, has_fold, len*8, &dec);