Retrain RNN classifier weights to include reverberated speech
[opus.git] / src / analysis.c
index 1d6dd82..b192ae4 100644 (file)
@@ -242,6 +242,7 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int
    float prob_min, prob_max;
    float vad_prob;
    int mpos, vpos;
+   int bandwidth_span;
 
    pos = tonal->read_pos;
    curr_lookahead = tonal->write_pos-tonal->read_pos;
@@ -263,6 +264,8 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int
    OPUS_COPY(info_out, &tonal->info[pos], 1);
    tonality_max = tonality_avg = info_out->tonality;
    tonality_count = 1;
+   /* Look at the neighbouring frames and pick largest bandwidth found (to be safe). */
+   bandwidth_span = 6;
    /* If possible, look ahead for a tone to compensate for the delay in the tone detector. */
    for (i=0;i<3;i++)
    {
@@ -274,6 +277,19 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int
       tonality_max = MAX32(tonality_max, tonal->info[pos].tonality);
       tonality_avg += tonal->info[pos].tonality;
       tonality_count++;
+      info_out->bandwidth = IMAX(info_out->bandwidth, tonal->info[pos].bandwidth);
+      bandwidth_span--;
+   }
+   pos = pos0;
+   /* Look back in time to see if any has a wider bandwidth than the current frame. */
+   for (i=0;i<bandwidth_span;i++)
+   {
+      pos--;
+      if (pos < 0)
+         pos = DETECT_SIZE-1;
+      if (pos == tonal->write_pos)
+         break;
+      info_out->bandwidth = IMAX(info_out->bandwidth, tonal->info[pos].bandwidth);
    }
    info_out->tonality = MAX32(tonality_avg/tonality_count, tonality_max-.2f);
 
@@ -432,6 +448,7 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
     float alpha, alphaE, alphaE2;
     float frame_loudness;
     float bandwidth_mask;
+    int is_masked[NB_TBANDS+1];
     int bandwidth=0;
     float maxE = 0;
     float noise_floor;
@@ -445,11 +462,15 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
     float leakage_from[NB_TBANDS+1];
     float leakage_to[NB_TBANDS+1];
     float layer_out[MAX_NEURONS];
+    float below_max_pitch;
+    float above_max_pitch;
     SAVE_STACK;
 
     alpha = 1.f/IMIN(10, 1+tonal->count);
     alphaE = 1.f/IMIN(25, 1+tonal->count);
-    alphaE2 = 1.f/IMIN(500, 1+tonal->count);
+    /* Noise floor related decay for bandwidth detection: -2.2 dB/second */
+    alphaE2 = 1.f/IMIN(100, 1+tonal->count);
+    if (tonal->count <= 1) alphaE2 = 1;
 
     if (tonal->Fs == 48000)
     {
@@ -719,9 +740,12 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
     maxE = 0;
     noise_floor = 5.7e-4f/(1<<(IMAX(0,lsb_depth-8)));
     noise_floor *= noise_floor;
+    below_max_pitch=0;
+    above_max_pitch=0;
     for (b=0;b<NB_TBANDS;b++)
     {
        float E=0;
+       float Em;
        int band_start, band_end;
        /* Keep a margin of 300 Hz for aliasing */
        band_start = tbands[b];
@@ -734,41 +758,59 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
        }
        E = SCALE_ENER(E);
        maxE = MAX32(maxE, E);
+       if (band_start < 64)
+       {
+          below_max_pitch += E;
+       } else {
+          above_max_pitch += E;
+       }
        tonal->meanE[b] = MAX32((1-alphaE2)*tonal->meanE[b], E);
-       E = MAX32(E, tonal->meanE[b]);
-       /* Use a simple follower with 13 dB/Bark slope for spreading function */
-       bandwidth_mask = MAX32(.05f*bandwidth_mask, E);
+       Em = MAX32(E, tonal->meanE[b]);
        /* Consider the band "active" only if all these conditions are met:
-          1) less than 10 dB below the simple follower
-          2) less than 90 dB below the peak band (maximal masking possible considering
+          1) less than 90 dB below the peak band (maximal masking possible considering
              both the ATH and the loudness-dependent slope of the spreading function)
-          3) above the PCM quantization noise floor
+          2) above the PCM quantization noise floor
           We use b+1 because the first CELT band isn't included in tbands[]
        */
-       if (E>.1*bandwidth_mask && E*1e9f > maxE && E > noise_floor*(band_end-band_start))
+       if (E*1e9f > maxE && (Em > 3*noise_floor*(band_end-band_start) || E > noise_floor*(band_end-band_start)))
           bandwidth = b+1;
+       /* Check if the band is masked (see below). */
+       is_masked[b] = E < (tonal->prev_bandwidth >= b+1  ? .01f : .05f)*bandwidth_mask;
+       /* Use a simple follower with 13 dB/Bark slope for spreading function. */
+       bandwidth_mask = MAX32(.05f*bandwidth_mask, E);
     }
     /* Special case for the last two bands, for which we don't have spectrum but only
-       the energy above 12 kHz. */
+       the energy above 12 kHz. The difficulty here is that the high-pass we use
+       leaks some LF energy, so we need to increase the threshold without accidentally cutting
+       off the band. */
     if (tonal->Fs == 48000) {
-       float ratio;
-       float E = hp_ener*(1.f/(240*240));
-       ratio = tonal->prev_bandwidth==20 ? 0.03f : 0.07f;
+       float noise_ratio;
+       float Em;
+       float E = hp_ener*(1.f/(60*60));
+       noise_ratio = tonal->prev_bandwidth==20 ? 10.f : 30.f;
+
 #ifdef FIXED_POINT
        /* silk_resampler_down2_hp() shifted right by an extra 8 bits. */
        E *= 256.f*(1.f/Q15ONE)*(1.f/Q15ONE);
 #endif
-       maxE = MAX32(maxE, E);
+       above_max_pitch += E;
        tonal->meanE[b] = MAX32((1-alphaE2)*tonal->meanE[b], E);
-       E = MAX32(E, tonal->meanE[b]);
-       /* Use a simple follower with 13 dB/Bark slope for spreading function */
-       bandwidth_mask = MAX32(.05f*bandwidth_mask, E);
-       if (E>ratio*bandwidth_mask && E*1e9f > maxE && E > noise_floor*160)
-          bandwidth = 20;
-       /* This detector is unreliable, so if the bandwidth is close to SWB, assume it's FB. */
-       if (bandwidth >= 17)
+       Em = MAX32(E, tonal->meanE[b]);
+       if (Em > 3*noise_ratio*noise_floor*160 || E > noise_ratio*noise_floor*160)
           bandwidth = 20;
+       /* Check if the band is masked (see below). */
+       is_masked[b] = E < (tonal->prev_bandwidth == 20  ? .01f : .05f)*bandwidth_mask;
     }
+    if (above_max_pitch > below_max_pitch)
+       info->max_pitch_ratio = below_max_pitch/above_max_pitch;
+    else
+       info->max_pitch_ratio = 1;
+    /* In some cases, resampling aliasing can create a small amount of energy in the first band
+       being cut. So if the last band is masked, we don't include it.  */
+    if (bandwidth == 20 && is_masked[NB_TBANDS])
+       bandwidth-=2;
+    else if (bandwidth > 0 && bandwidth <= NB_TBANDS && is_masked[bandwidth-1])
+       bandwidth--;
     if (tonal->count<=2)
        bandwidth = 20;
     frame_loudness = 20*(float)log10(frame_loudness);
@@ -854,9 +896,7 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
 
     /* Probability of speech or music vs noise */
     info->activity_probability = frame_probs[1];
-    /* It seems like the RNN tends to have a bias towards speech and this
-       warping of the probabilities compensates for it. */
-    info->music_prob = frame_probs[0] * (2 - frame_probs[0]);
+    info->music_prob = frame_probs[0];
 
     /*printf("%f %f %f\n", frame_probs[0], frame_probs[1], info->music_prob);*/
 #ifdef MLP_TRAINING