Retrain RNN classifier weights to include reverberated speech
[opus.git] / src / analysis.c
index 4ad44d3..b192ae4 100644 (file)
@@ -50,6 +50,8 @@
 
 #ifndef DISABLE_FLOAT_API
 
+#define TRANSITION_PENALTY 10
+
 static const float dct_table[128] = {
         0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f,
         0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f,
@@ -224,19 +226,23 @@ void tonality_analysis_reset(TonalityAnalysisState *tonal)
   /* Clear non-reusable fields. */
   char *start = (char*)&tonal->TONALITY_ANALYSIS_RESET_START;
   OPUS_CLEAR(start, sizeof(TonalityAnalysisState) - (start - (char*)tonal));
-  tonal->music_confidence = .9f;
-  tonal->speech_confidence = .1f;
 }
 
 void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int len)
 {
    int pos;
    int curr_lookahead;
-   float psum;
    float tonality_max;
    float tonality_avg;
    int tonality_count;
    int i;
+   int pos0;
+   float prob_avg;
+   float prob_count;
+   float prob_min, prob_max;
+   float vad_prob;
+   int mpos, vpos;
+   int bandwidth_span;
 
    pos = tonal->read_pos;
    curr_lookahead = tonal->write_pos-tonal->read_pos;
@@ -254,9 +260,12 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int
       pos--;
    if (pos<0)
       pos = DETECT_SIZE-1;
+   pos0 = pos;
    OPUS_COPY(info_out, &tonal->info[pos], 1);
    tonality_max = tonality_avg = info_out->tonality;
    tonality_count = 1;
+   /* Look at the neighbouring frames and pick largest bandwidth found (to be safe). */
+   bandwidth_span = 6;
    /* If possible, look ahead for a tone to compensate for the delay in the tone detector. */
    for (i=0;i<3;i++)
    {
@@ -268,8 +277,122 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int
       tonality_max = MAX32(tonality_max, tonal->info[pos].tonality);
       tonality_avg += tonal->info[pos].tonality;
       tonality_count++;
+      info_out->bandwidth = IMAX(info_out->bandwidth, tonal->info[pos].bandwidth);
+      bandwidth_span--;
+   }
+   pos = pos0;
+   /* Look back in time to see if any has a wider bandwidth than the current frame. */
+   for (i=0;i<bandwidth_span;i++)
+   {
+      pos--;
+      if (pos < 0)
+         pos = DETECT_SIZE-1;
+      if (pos == tonal->write_pos)
+         break;
+      info_out->bandwidth = IMAX(info_out->bandwidth, tonal->info[pos].bandwidth);
+   }
+   info_out->tonality = MAX32(tonality_avg/tonality_count, tonality_max-.2f);
+
+   mpos = vpos = pos0;
+   /* If we have enough look-ahead, compensate for the ~5-frame delay in the music prob and
+      ~1 frame delay in the VAD prob. */
+   if (curr_lookahead > 15)
+   {
+      mpos += 5;
+      if (mpos>=DETECT_SIZE)
+         mpos -= DETECT_SIZE;
+      vpos += 1;
+      if (vpos>=DETECT_SIZE)
+         vpos -= DETECT_SIZE;
+   }
+
+   /* The following calculations attempt to minimize a "badness function"
+      for the transition. When switching from speech to music, the badness
+      of switching at frame k is
+      b_k = S*v_k + \sum_{i=0}^{k-1} v_i*(p_i - T)
+      where
+      v_i is the activity probability (VAD) at frame i,
+      p_i is the music probability at frame i
+      T is the probability threshold for switching
+      S is the penalty for switching during active audio rather than silence
+      the current frame has index i=0
+
+      Rather than apply badness to directly decide when to switch, what we compute
+      instead is the threshold for which the optimal switching point is now. When
+      considering whether to switch now (frame 0) or at frame k, we have:
+      S*v_0 = S*v_k + \sum_{i=0}^{k-1} v_i*(p_i - T)
+      which gives us:
+      T = ( \sum_{i=0}^{k-1} v_i*p_i + S*(v_k-v_0) ) / ( \sum_{i=0}^{k-1} v_i )
+      We take the min threshold across all positive values of k (up to the maximum
+      amount of lookahead we have) to give us the threshold for which the current
+      frame is the optimal switch point.
+
+      The last step is that we need to consider whether we want to switch at all.
+      For that we use the average of the music probability over the entire window.
+      If the threshold is higher than that average we're not going to
+      switch, so we compute a min with the average as well. The result of all these
+      min operations is music_prob_min, which gives the threshold for switching to music
+      if we're currently encoding for speech.
+
+      We do the exact opposite to compute music_prob_max which is used for switching
+      from music to speech.
+    */
+   prob_min = 1.f;
+   prob_max = 0.f;
+   vad_prob = tonal->info[vpos].activity_probability;
+   prob_count = MAX16(.1f, vad_prob);
+   prob_avg = MAX16(.1f, vad_prob)*tonal->info[mpos].music_prob;
+   while (1)
+   {
+      float pos_vad;
+      mpos++;
+      if (mpos==DETECT_SIZE)
+         mpos = 0;
+      if (mpos == tonal->write_pos)
+         break;
+      vpos++;
+      if (vpos==DETECT_SIZE)
+         vpos = 0;
+      if (vpos == tonal->write_pos)
+         break;
+      pos_vad = tonal->info[vpos].activity_probability;
+      prob_min = MIN16((prob_avg - TRANSITION_PENALTY*(vad_prob - pos_vad))/prob_count, prob_min);
+      prob_max = MAX16((prob_avg + TRANSITION_PENALTY*(vad_prob - pos_vad))/prob_count, prob_max);
+      prob_count += MAX16(.1f, pos_vad);
+      prob_avg += MAX16(.1f, pos_vad)*tonal->info[mpos].music_prob;
+   }
+   info_out->music_prob = prob_avg/prob_count;
+   prob_min = MIN16(prob_avg/prob_count, prob_min);
+   prob_max = MAX16(prob_avg/prob_count, prob_max);
+   prob_min = MAX16(prob_min, 0.f);
+   prob_max = MIN16(prob_max, 1.f);
+
+   /* If we don't have enough look-ahead, do our best to make a decent decision. */
+   if (curr_lookahead < 10)
+   {
+      float pmin, pmax;
+      pmin = prob_min;
+      pmax = prob_max;
+      pos = pos0;
+      /* Look for min/max in the past. */
+      for (i=0;i<IMIN(tonal->count-1, 15);i++)
+      {
+         pos--;
+         if (pos < 0)
+            pos = DETECT_SIZE-1;
+         pmin = MIN16(pmin, tonal->info[pos].music_prob);
+         pmax = MAX16(pmax, tonal->info[pos].music_prob);
+      }
+      /* Bias against switching on active audio. */
+      pmin = MAX16(0.f, pmin - .1f*vad_prob);
+      pmax = MIN16(1.f, pmax + .1f*vad_prob);
+      prob_min += (1.f-.1f*curr_lookahead)*(pmin - prob_min);
+      prob_max += (1.f-.1f*curr_lookahead)*(pmax - prob_max);
    }
-   info_out->tonality = MAX32(tonality_avg/tonality_count, tonality_max-.2);
+   info_out->music_prob_min = prob_min;
+   info_out->music_prob_max = prob_max;
+
+   /* printf("%f %f %f %f %f\n", prob_min, prob_max, prob_avg/prob_count, vad_prob, info_out->music_prob); */
    tonal->read_subframe += len/(tonal->Fs/400);
    while (tonal->read_subframe>=8)
    {
@@ -278,21 +401,6 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int
    }
    if (tonal->read_pos>=DETECT_SIZE)
       tonal->read_pos-=DETECT_SIZE;
-
-   /* The -1 is to compensate for the delay in the features themselves. */
-   curr_lookahead = IMAX(curr_lookahead-1, 0);
-
-   psum=0;
-   /* Summing the probability of transition patterns that involve music at
-      time (DETECT_SIZE-curr_lookahead-1) */
-   for (i=0;i<DETECT_SIZE-curr_lookahead;i++)
-      psum += tonal->pmusic[i];
-   for (;i<DETECT_SIZE;i++)
-      psum += tonal->pspeech[i];
-   psum = psum*tonal->music_confidence + (1-psum)*tonal->speech_confidence;
-   /*printf("%f %f %f %f %f\n", psum, info_out->music_prob, info_out->vad_prob, info_out->activity_probability, info_out->tonality);*/
-
-   info_out->music_prob = psum;
 }
 
 static const float std_feature_bias[9] = {
@@ -340,6 +448,7 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
     float alpha, alphaE, alphaE2;
     float frame_loudness;
     float bandwidth_mask;
+    int is_masked[NB_TBANDS+1];
     int bandwidth=0;
     float maxE = 0;
     float noise_floor;
@@ -352,11 +461,16 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
     float band_log2[NB_TBANDS+1];
     float leakage_from[NB_TBANDS+1];
     float leakage_to[NB_TBANDS+1];
+    float layer_out[MAX_NEURONS];
+    float below_max_pitch;
+    float above_max_pitch;
     SAVE_STACK;
 
     alpha = 1.f/IMIN(10, 1+tonal->count);
     alphaE = 1.f/IMIN(25, 1+tonal->count);
-    alphaE2 = 1.f/IMIN(500, 1+tonal->count);
+    /* Noise floor related decay for bandwidth detection: -2.2 dB/second */
+    alphaE2 = 1.f/IMIN(100, 1+tonal->count);
+    if (tonal->count <= 1) alphaE2 = 1;
 
     if (tonal->Fs == 48000)
     {
@@ -368,12 +482,6 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
        offset = 3*offset/2;
     }
 
-    if (tonal->count<4) {
-       if (tonal->application == OPUS_APPLICATION_VOIP)
-          tonal->music_prob = .1;
-       else
-          tonal->music_prob = .625;
-    }
     kfft = celt_mode->mdct.kfft[0];
     if (tonal->count==0)
        tonal->mem_fill = 240;
@@ -632,9 +740,12 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
     maxE = 0;
     noise_floor = 5.7e-4f/(1<<(IMAX(0,lsb_depth-8)));
     noise_floor *= noise_floor;
+    below_max_pitch=0;
+    above_max_pitch=0;
     for (b=0;b<NB_TBANDS;b++)
     {
        float E=0;
+       float Em;
        int band_start, band_end;
        /* Keep a margin of 300 Hz for aliasing */
        band_start = tbands[b];
@@ -647,36 +758,59 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
        }
        E = SCALE_ENER(E);
        maxE = MAX32(maxE, E);
+       if (band_start < 64)
+       {
+          below_max_pitch += E;
+       } else {
+          above_max_pitch += E;
+       }
        tonal->meanE[b] = MAX32((1-alphaE2)*tonal->meanE[b], E);
-       E = MAX32(E, tonal->meanE[b]);
-       /* Use a simple follower with 13 dB/Bark slope for spreading function */
-       bandwidth_mask = MAX32(.05f*bandwidth_mask, E);
+       Em = MAX32(E, tonal->meanE[b]);
        /* Consider the band "active" only if all these conditions are met:
-          1) less than 10 dB below the simple follower
-          2) less than 90 dB below the peak band (maximal masking possible considering
+          1) less than 90 dB below the peak band (maximal masking possible considering
              both the ATH and the loudness-dependent slope of the spreading function)
-          3) above the PCM quantization noise floor
+          2) above the PCM quantization noise floor
           We use b+1 because the first CELT band isn't included in tbands[]
        */
-       if (E>.1*bandwidth_mask && E*1e9f > maxE && E > noise_floor*(band_end-band_start))
+       if (E*1e9f > maxE && (Em > 3*noise_floor*(band_end-band_start) || E > noise_floor*(band_end-band_start)))
           bandwidth = b+1;
+       /* Check if the band is masked (see below). */
+       is_masked[b] = E < (tonal->prev_bandwidth >= b+1  ? .01f : .05f)*bandwidth_mask;
+       /* Use a simple follower with 13 dB/Bark slope for spreading function. */
+       bandwidth_mask = MAX32(.05f*bandwidth_mask, E);
     }
     /* Special case for the last two bands, for which we don't have spectrum but only
-       the energy above 12 kHz. */
-    {
-       float E = hp_ener*(1.f/(240*240));
+       the energy above 12 kHz. The difficulty here is that the high-pass we use
+       leaks some LF energy, so we need to increase the threshold without accidentally cutting
+       off the band. */
+    if (tonal->Fs == 48000) {
+       float noise_ratio;
+       float Em;
+       float E = hp_ener*(1.f/(60*60));
+       noise_ratio = tonal->prev_bandwidth==20 ? 10.f : 30.f;
+
 #ifdef FIXED_POINT
        /* silk_resampler_down2_hp() shifted right by an extra 8 bits. */
        E *= 256.f*(1.f/Q15ONE)*(1.f/Q15ONE);
 #endif
-       maxE = MAX32(maxE, E);
+       above_max_pitch += E;
        tonal->meanE[b] = MAX32((1-alphaE2)*tonal->meanE[b], E);
-       E = MAX32(E, tonal->meanE[b]);
-       /* Use a simple follower with 13 dB/Bark slope for spreading function */
-       bandwidth_mask = MAX32(.05f*bandwidth_mask, E);
-       if (E>.1*bandwidth_mask && E*1e9f > maxE && E > noise_floor*160)
+       Em = MAX32(E, tonal->meanE[b]);
+       if (Em > 3*noise_ratio*noise_floor*160 || E > noise_ratio*noise_floor*160)
           bandwidth = 20;
+       /* Check if the band is masked (see below). */
+       is_masked[b] = E < (tonal->prev_bandwidth == 20  ? .01f : .05f)*bandwidth_mask;
     }
+    if (above_max_pitch > below_max_pitch)
+       info->max_pitch_ratio = below_max_pitch/above_max_pitch;
+    else
+       info->max_pitch_ratio = 1;
+    /* In some cases, resampling aliasing can create a small amount of energy in the first band
+       being cut. So if the last band is masked, we don't include it.  */
+    if (bandwidth == 20 && is_masked[NB_TBANDS])
+       bandwidth-=2;
+    else if (bandwidth > 0 && bandwidth <= NB_TBANDS && is_masked[bandwidth-1])
+       bandwidth--;
     if (tonal->count<=2)
        bandwidth = 20;
     frame_loudness = 20*(float)log10(frame_loudness);
@@ -703,7 +837,7 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
     frame_stationarity /= NB_TBANDS;
     relativeE /= NB_TBANDS;
     if (tonal->count<10)
-       relativeE = .5;
+       relativeE = .5f;
     frame_noisiness /= NB_TBANDS;
 #if 1
     info->activity = frame_noisiness + (1-frame_noisiness)*relativeE;
@@ -756,139 +890,15 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
     features[23] = info->tonality_slope + 0.069216f;
     features[24] = tonal->lowECount - 0.067930f;
 
-    mlp_process(&net, features, frame_probs);
-    frame_probs[0] = .5f*(frame_probs[0]+1);
-    /* Curve fitting between the MLP probability and the actual probability */
-    /*frame_probs[0] = .01f + 1.21f*frame_probs[0]*frame_probs[0] - .23f*(float)pow(frame_probs[0], 10);*/
-    /* Probability of active audio (as opposed to silence) */
-    frame_probs[1] = .5f*frame_probs[1]+.5f;
-    frame_probs[1] *= frame_probs[1];
+    compute_dense(&layer0, layer_out, features);
+    compute_gru(&layer1, tonal->rnn_state, layer_out);
+    compute_dense(&layer2, frame_probs, tonal->rnn_state);
 
     /* Probability of speech or music vs noise */
     info->activity_probability = frame_probs[1];
+    info->music_prob = frame_probs[0];
 
-    /*printf("%f %f\n", frame_probs[0], frame_probs[1]);*/
-    {
-       /* Probability of state transition */
-       float tau;
-       /* Represents independence of the MLP probabilities, where
-          beta=1 means fully independent. */
-       float beta;
-       /* Denormalized probability of speech (p0) and music (p1) after update */
-       float p0, p1;
-       /* Probabilities for "all speech" and "all music" */
-       float s0, m0;
-       /* Probability sum for renormalisation */
-       float psum;
-       /* Instantaneous probability of speech and music, with beta pre-applied. */
-       float speech0;
-       float music0;
-       float p, q;
-
-       /* More silence transitions for speech than for music. */
-       tau = .001f*tonal->music_prob + .01f*(1-tonal->music_prob);
-       p = MAX16(.05f,MIN16(.95f,frame_probs[1]));
-       q = MAX16(.05f,MIN16(.95f,tonal->vad_prob));
-       beta = .02f+.05f*ABS16(p-q)/(p*(1-q)+q*(1-p));
-       /* p0 and p1 are the probabilities of speech and music at this frame
-          using only information from previous frame and applying the
-          state transition model */
-       p0 = (1-tonal->vad_prob)*(1-tau) +    tonal->vad_prob *tau;
-       p1 =    tonal->vad_prob *(1-tau) + (1-tonal->vad_prob)*tau;
-       /* We apply the current probability with exponent beta to work around
-          the fact that the probability estimates aren't independent. */
-       p0 *= (float)pow(1-frame_probs[1], beta);
-       p1 *= (float)pow(frame_probs[1], beta);
-       /* Normalise the probabilities to get the Marokv probability of music. */
-       tonal->vad_prob = p1/(p0+p1);
-       info->vad_prob = tonal->vad_prob;
-       /* Consider that silence has a 50-50 probability of being speech or music. */
-       frame_probs[0] = tonal->vad_prob*frame_probs[0] + (1-tonal->vad_prob)*.5f;
-
-       /* One transition every 3 minutes of active audio */
-       tau = .0001f;
-       /* Adapt beta based on how "unexpected" the new prob is */
-       p = MAX16(.05f,MIN16(.95f,frame_probs[0]));
-       q = MAX16(.05f,MIN16(.95f,tonal->music_prob));
-       beta = .02f+.05f*ABS16(p-q)/(p*(1-q)+q*(1-p));
-       /* p0 and p1 are the probabilities of speech and music at this frame
-          using only information from previous frame and applying the
-          state transition model */
-       p0 = (1-tonal->music_prob)*(1-tau) +    tonal->music_prob *tau;
-       p1 =    tonal->music_prob *(1-tau) + (1-tonal->music_prob)*tau;
-       /* We apply the current probability with exponent beta to work around
-          the fact that the probability estimates aren't independent. */
-       p0 *= (float)pow(1-frame_probs[0], beta);
-       p1 *= (float)pow(frame_probs[0], beta);
-       /* Normalise the probabilities to get the Marokv probability of music. */
-       tonal->music_prob = p1/(p0+p1);
-       info->music_prob = tonal->music_prob;
-
-       /*printf("%f %f %f %f\n", frame_probs[0], frame_probs[1], tonal->music_prob, tonal->vad_prob);*/
-       /* This chunk of code deals with delayed decision. */
-       psum=1e-20f;
-       /* Instantaneous probability of speech and music, with beta pre-applied. */
-       speech0 = (float)pow(1-frame_probs[0], beta);
-       music0  = (float)pow(frame_probs[0], beta);
-       if (tonal->count==1)
-       {
-          if (tonal->application == OPUS_APPLICATION_VOIP)
-             tonal->pmusic[0] = .1;
-          else
-             tonal->pmusic[0] = .625;
-          tonal->pspeech[0] = 1-tonal->pmusic[0];
-       }
-       /* Updated probability of having only speech (s0) or only music (m0),
-          before considering the new observation. */
-       s0 = tonal->pspeech[0] + tonal->pspeech[1];
-       m0 = tonal->pmusic [0] + tonal->pmusic [1];
-       /* Updates s0 and m0 with instantaneous probability. */
-       tonal->pspeech[0] = s0*(1-tau)*speech0;
-       tonal->pmusic [0] = m0*(1-tau)*music0;
-       /* Propagate the transition probabilities */
-       for (i=1;i<DETECT_SIZE-1;i++)
-       {
-          tonal->pspeech[i] = tonal->pspeech[i+1]*speech0;
-          tonal->pmusic [i] = tonal->pmusic [i+1]*music0;
-       }
-       /* Probability that the latest frame is speech, when all the previous ones were music. */
-       tonal->pspeech[DETECT_SIZE-1] = m0*tau*speech0;
-       /* Probability that the latest frame is music, when all the previous ones were speech. */
-       tonal->pmusic [DETECT_SIZE-1] = s0*tau*music0;
-
-       /* Renormalise probabilities to 1 */
-       for (i=0;i<DETECT_SIZE;i++)
-          psum += tonal->pspeech[i] + tonal->pmusic[i];
-       psum = 1.f/psum;
-       for (i=0;i<DETECT_SIZE;i++)
-       {
-          tonal->pspeech[i] *= psum;
-          tonal->pmusic [i] *= psum;
-       }
-       psum = tonal->pmusic[0];
-       for (i=1;i<DETECT_SIZE;i++)
-          psum += tonal->pspeech[i];
-
-       /* Estimate our confidence in the speech/music decisions */
-       if (frame_probs[1]>.75)
-       {
-          if (tonal->music_prob>.9)
-          {
-             float adapt;
-             adapt = 1.f/(++tonal->music_confidence_count);
-             tonal->music_confidence_count = IMIN(tonal->music_confidence_count, 500);
-             tonal->music_confidence += adapt*MAX16(-.2f,frame_probs[0]-tonal->music_confidence);
-          }
-          if (tonal->music_prob<.1)
-          {
-             float adapt;
-             adapt = 1.f/(++tonal->speech_confidence_count);
-             tonal->speech_confidence_count = IMIN(tonal->speech_confidence_count, 500);
-             tonal->speech_confidence += adapt*MIN16(.2f,frame_probs[0]-tonal->speech_confidence);
-          }
-       }
-    }
-    tonal->last_music = tonal->music_prob>.5f;
+    /*printf("%f %f %f\n", frame_probs[0], frame_probs[1], info->music_prob);*/
 #ifdef MLP_TRAINING
     for (i=0;i<25;i++)
        printf("%f ", features[i]);
@@ -896,6 +906,7 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
 #endif
 
     info->bandwidth = bandwidth;
+    tonal->prev_bandwidth = bandwidth;
     /*printf("%d %d\n", info->bandwidth, info->opus_bandwidth);*/
     info->noisiness = frame_noisiness;
     info->valid = 1;