Add RNN for VAD and speech/music classification
authorJean-Marc Valin <jmvalin@jmvalin.ca>
Wed, 12 Jul 2017 20:55:28 +0000 (16:55 -0400)
committerJean-Marc Valin <jmvalin@jmvalin.ca>
Thu, 5 Oct 2017 21:40:27 +0000 (17:40 -0400)
Based on two dense layers with a GRU layer in the middle

celt/celt.h
scripts/dump_rnn.py [new file with mode: 0755]
scripts/rnn_train.py [new file with mode: 0755]
src/analysis.c
src/analysis.h
src/mlp.c
src/mlp.h
src/mlp_data.c
src/mlp_train.c [deleted file]
src/mlp_train.h [deleted file]
src/opus_encoder.c

index 7017530..f73f29d 100644 (file)
@@ -59,7 +59,8 @@ typedef struct {
    float noisiness;
    float activity;
    float music_prob;
-   float vad_prob;
+   float music_prob_min;
+   float music_prob_max;
    int   bandwidth;
    float activity_probability;
    /* Store as Q6 char to save space. */
diff --git a/scripts/dump_rnn.py b/scripts/dump_rnn.py
new file mode 100755 (executable)
index 0000000..dd66403
--- /dev/null
@@ -0,0 +1,57 @@
+#!/usr/bin/python
+
+from __future__ import print_function
+
+from keras.models import Sequential
+from keras.layers import Dense
+from keras.layers import LSTM
+from keras.layers import GRU
+from keras.models import load_model
+from keras import backend as K
+
+import numpy as np
+
+def printVector(f, vector, name):
+    v = np.reshape(vector, (-1));
+    #print('static const float ', name, '[', len(v), '] = \n', file=f)
+    f.write('static const opus_int16 {}[{}] = {{\n   '.format(name, len(v)))
+    for i in range(0, len(v)):
+        f.write('{}'.format(int(round(8192*v[i]))))
+        if (i!=len(v)-1):
+            f.write(',')
+        else:
+            break;
+        if (i%8==7):
+            f.write("\n   ")
+        else:
+            f.write(" ")
+    #print(v, file=f)
+    f.write('\n};\n\n')
+    return;
+
+def binary_crossentrop2(y_true, y_pred):
+        return K.mean(2*K.abs(y_true-0.5) * K.binary_crossentropy(y_pred, y_true), axis=-1)
+
+
+model = load_model("weights.hdf5", custom_objects={'binary_crossentrop2': binary_crossentrop2})
+
+weights = model.get_weights()
+
+f = open('rnn_weights.c', 'w')
+
+f.write('/*This file is automatically generated from a Keras model*/\n\n')
+f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "mlp.h"\n\n')
+
+printVector(f, weights[0], 'layer0_weights')
+printVector(f, weights[1], 'layer0_bias')
+printVector(f, weights[2], 'layer1_weights')
+printVector(f, weights[3], 'layer1_recur_weights')
+printVector(f, weights[4], 'layer1_bias')
+printVector(f, weights[5], 'layer2_weights')
+printVector(f, weights[6], 'layer2_bias')
+
+f.write('const DenseLayer layer0 = {\n   layer0_bias,\n   layer0_weights,\n   25, 16, 0\n};\n\n')
+f.write('const GRULayer layer1 = {\n   layer1_bias,\n   layer1_weights,\n   layer1_recur_weights,\n   16, 12\n};\n\n')
+f.write('const DenseLayer layer2 = {\n   layer2_bias,\n   layer2_weights,\n   12, 2, 1\n};\n\n')
+
+f.close()
diff --git a/scripts/rnn_train.py b/scripts/rnn_train.py
new file mode 100755 (executable)
index 0000000..ffdaa1e
--- /dev/null
@@ -0,0 +1,67 @@
+#!/usr/bin/python
+
+from __future__ import print_function
+
+from keras.models import Sequential
+from keras.models import Model
+from keras.layers import Input
+from keras.layers import Dense
+from keras.layers import LSTM
+from keras.layers import GRU
+from keras.layers import SimpleRNN
+from keras.layers import Dropout
+from keras import losses
+import h5py
+
+from keras import backend as K
+import numpy as np
+
+def binary_crossentrop2(y_true, y_pred):
+    return K.mean(2*K.abs(y_true-0.5) * K.binary_crossentropy(y_pred, y_true), axis=-1)
+
+print('Build model...')
+#model = Sequential()
+#model.add(Dense(16, activation='tanh', input_shape=(None, 25)))
+#model.add(GRU(12, dropout=0.0, recurrent_dropout=0.0, activation='tanh', recurrent_activation='sigmoid', return_sequences=True))
+#model.add(Dense(2, activation='sigmoid'))
+
+main_input = Input(shape=(None, 25), name='main_input')
+x = Dense(16, activation='tanh')(main_input)
+x = GRU(12, dropout=0.1, recurrent_dropout=0.1, activation='tanh', recurrent_activation='sigmoid', return_sequences=True)(x)
+x = Dense(2, activation='sigmoid')(x)
+model = Model(inputs=main_input, outputs=x)
+
+batch_size = 64
+
+print('Loading data...')
+with h5py.File('features.h5', 'r') as hf:
+    all_data = hf['features'][:]
+print('done.')
+
+window_size = 1500
+
+nb_sequences = len(all_data)/window_size
+print(nb_sequences, ' sequences')
+x_train = all_data[:nb_sequences*window_size, :-2]
+x_train = np.reshape(x_train, (nb_sequences, window_size, 25))
+
+y_train = np.copy(all_data[:nb_sequences*window_size, -2:])
+y_train = np.reshape(y_train, (nb_sequences, window_size, 2))
+
+all_data = 0;
+x_train = x_train.astype('float32')
+y_train = y_train.astype('float32')
+
+print(len(x_train), 'train sequences. x shape =', x_train.shape, 'y shape = ', y_train.shape)
+
+# try using different optimizers and different optimizer configs
+model.compile(loss=binary_crossentrop2,
+              optimizer='adam',
+              metrics=['binary_accuracy'])
+
+print('Train...')
+model.fit(x_train, y_train,
+          batch_size=batch_size,
+          epochs=200,
+          validation_data=(x_train, y_train))
+model.save("newweights.hdf5")
index f4160e4..1d6dd82 100644 (file)
@@ -50,6 +50,8 @@
 
 #ifndef DISABLE_FLOAT_API
 
+#define TRANSITION_PENALTY 10
+
 static const float dct_table[128] = {
         0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f,
         0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f,
@@ -224,19 +226,22 @@ void tonality_analysis_reset(TonalityAnalysisState *tonal)
   /* Clear non-reusable fields. */
   char *start = (char*)&tonal->TONALITY_ANALYSIS_RESET_START;
   OPUS_CLEAR(start, sizeof(TonalityAnalysisState) - (start - (char*)tonal));
-  tonal->music_confidence = .9f;
-  tonal->speech_confidence = .1f;
 }
 
 void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int len)
 {
    int pos;
    int curr_lookahead;
-   float psum;
    float tonality_max;
    float tonality_avg;
    int tonality_count;
    int i;
+   int pos0;
+   float prob_avg;
+   float prob_count;
+   float prob_min, prob_max;
+   float vad_prob;
+   int mpos, vpos;
 
    pos = tonal->read_pos;
    curr_lookahead = tonal->write_pos-tonal->read_pos;
@@ -254,6 +259,7 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int
       pos--;
    if (pos<0)
       pos = DETECT_SIZE-1;
+   pos0 = pos;
    OPUS_COPY(info_out, &tonal->info[pos], 1);
    tonality_max = tonality_avg = info_out->tonality;
    tonality_count = 1;
@@ -270,6 +276,107 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int
       tonality_count++;
    }
    info_out->tonality = MAX32(tonality_avg/tonality_count, tonality_max-.2f);
+
+   mpos = vpos = pos0;
+   /* If we have enough look-ahead, compensate for the ~5-frame delay in the music prob and
+      ~1 frame delay in the VAD prob. */
+   if (curr_lookahead > 15)
+   {
+      mpos += 5;
+      if (mpos>=DETECT_SIZE)
+         mpos -= DETECT_SIZE;
+      vpos += 1;
+      if (vpos>=DETECT_SIZE)
+         vpos -= DETECT_SIZE;
+   }
+
+   /* The following calculations attempt to minimize a "badness function"
+      for the transition. When switching from speech to music, the badness
+      of switching at frame k is
+      b_k = S*v_k + \sum_{i=0}^{k-1} v_i*(p_i - T)
+      where
+      v_i is the activity probability (VAD) at frame i,
+      p_i is the music probability at frame i
+      T is the probability threshold for switching
+      S is the penalty for switching during active audio rather than silence
+      the current frame has index i=0
+
+      Rather than apply badness to directly decide when to switch, what we compute
+      instead is the threshold for which the optimal switching point is now. When
+      considering whether to switch now (frame 0) or at frame k, we have:
+      S*v_0 = S*v_k + \sum_{i=0}^{k-1} v_i*(p_i - T)
+      which gives us:
+      T = ( \sum_{i=0}^{k-1} v_i*p_i + S*(v_k-v_0) ) / ( \sum_{i=0}^{k-1} v_i )
+      We take the min threshold across all positive values of k (up to the maximum
+      amount of lookahead we have) to give us the threshold for which the current
+      frame is the optimal switch point.
+
+      The last step is that we need to consider whether we want to switch at all.
+      For that we use the average of the music probability over the entire window.
+      If the threshold is higher than that average we're not going to
+      switch, so we compute a min with the average as well. The result of all these
+      min operations is music_prob_min, which gives the threshold for switching to music
+      if we're currently encoding for speech.
+
+      We do the exact opposite to compute music_prob_max which is used for switching
+      from music to speech.
+    */
+   prob_min = 1.f;
+   prob_max = 0.f;
+   vad_prob = tonal->info[vpos].activity_probability;
+   prob_count = MAX16(.1f, vad_prob);
+   prob_avg = MAX16(.1f, vad_prob)*tonal->info[mpos].music_prob;
+   while (1)
+   {
+      float pos_vad;
+      mpos++;
+      if (mpos==DETECT_SIZE)
+         mpos = 0;
+      if (mpos == tonal->write_pos)
+         break;
+      vpos++;
+      if (vpos==DETECT_SIZE)
+         vpos = 0;
+      if (vpos == tonal->write_pos)
+         break;
+      pos_vad = tonal->info[vpos].activity_probability;
+      prob_min = MIN16((prob_avg - TRANSITION_PENALTY*(vad_prob - pos_vad))/prob_count, prob_min);
+      prob_max = MAX16((prob_avg + TRANSITION_PENALTY*(vad_prob - pos_vad))/prob_count, prob_max);
+      prob_count += MAX16(.1f, pos_vad);
+      prob_avg += MAX16(.1f, pos_vad)*tonal->info[mpos].music_prob;
+   }
+   info_out->music_prob = prob_avg/prob_count;
+   prob_min = MIN16(prob_avg/prob_count, prob_min);
+   prob_max = MAX16(prob_avg/prob_count, prob_max);
+   prob_min = MAX16(prob_min, 0.f);
+   prob_max = MIN16(prob_max, 1.f);
+
+   /* If we don't have enough look-ahead, do our best to make a decent decision. */
+   if (curr_lookahead < 10)
+   {
+      float pmin, pmax;
+      pmin = prob_min;
+      pmax = prob_max;
+      pos = pos0;
+      /* Look for min/max in the past. */
+      for (i=0;i<IMIN(tonal->count-1, 15);i++)
+      {
+         pos--;
+         if (pos < 0)
+            pos = DETECT_SIZE-1;
+         pmin = MIN16(pmin, tonal->info[pos].music_prob);
+         pmax = MAX16(pmax, tonal->info[pos].music_prob);
+      }
+      /* Bias against switching on active audio. */
+      pmin = MAX16(0.f, pmin - .1f*vad_prob);
+      pmax = MIN16(1.f, pmax + .1f*vad_prob);
+      prob_min += (1.f-.1f*curr_lookahead)*(pmin - prob_min);
+      prob_max += (1.f-.1f*curr_lookahead)*(pmax - prob_max);
+   }
+   info_out->music_prob_min = prob_min;
+   info_out->music_prob_max = prob_max;
+
+   /* printf("%f %f %f %f %f\n", prob_min, prob_max, prob_avg/prob_count, vad_prob, info_out->music_prob); */
    tonal->read_subframe += len/(tonal->Fs/400);
    while (tonal->read_subframe>=8)
    {
@@ -278,21 +385,6 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int
    }
    if (tonal->read_pos>=DETECT_SIZE)
       tonal->read_pos-=DETECT_SIZE;
-
-   /* The -1 is to compensate for the delay in the features themselves. */
-   curr_lookahead = IMAX(curr_lookahead-1, 0);
-
-   psum=0;
-   /* Summing the probability of transition patterns that involve music at
-      time (DETECT_SIZE-curr_lookahead-1) */
-   for (i=0;i<DETECT_SIZE-curr_lookahead;i++)
-      psum += tonal->pmusic[i];
-   for (;i<DETECT_SIZE;i++)
-      psum += tonal->pspeech[i];
-   psum = psum*tonal->music_confidence + (1-psum)*tonal->speech_confidence;
-   /*printf("%f %f %f %f %f\n", psum, info_out->music_prob, info_out->vad_prob, info_out->activity_probability, info_out->tonality);*/
-
-   info_out->music_prob = psum;
 }
 
 static const float std_feature_bias[9] = {
@@ -352,6 +444,7 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
     float band_log2[NB_TBANDS+1];
     float leakage_from[NB_TBANDS+1];
     float leakage_to[NB_TBANDS+1];
+    float layer_out[MAX_NEURONS];
     SAVE_STACK;
 
     alpha = 1.f/IMIN(10, 1+tonal->count);
@@ -368,12 +461,6 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
        offset = 3*offset/2;
     }
 
-    if (tonal->count<4) {
-       if (tonal->application == OPUS_APPLICATION_VOIP)
-          tonal->music_prob = .1f;
-       else
-          tonal->music_prob = .625f;
-    }
     kfft = celt_mode->mdct.kfft[0];
     if (tonal->count==0)
        tonal->mem_fill = 240;
@@ -761,139 +848,17 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt
     features[23] = info->tonality_slope + 0.069216f;
     features[24] = tonal->lowECount - 0.067930f;
 
-    mlp_process(&net, features, frame_probs);
-    frame_probs[0] = .5f*(frame_probs[0]+1);
-    /* Curve fitting between the MLP probability and the actual probability */
-    /*frame_probs[0] = .01f + 1.21f*frame_probs[0]*frame_probs[0] - .23f*(float)pow(frame_probs[0], 10);*/
-    /* Probability of active audio (as opposed to silence) */
-    frame_probs[1] = .5f*frame_probs[1]+.5f;
-    frame_probs[1] *= frame_probs[1];
+    compute_dense(&layer0, layer_out, features);
+    compute_gru(&layer1, tonal->rnn_state, layer_out);
+    compute_dense(&layer2, frame_probs, tonal->rnn_state);
 
     /* Probability of speech or music vs noise */
     info->activity_probability = frame_probs[1];
+    /* It seems like the RNN tends to have a bias towards speech and this
+       warping of the probabilities compensates for it. */
+    info->music_prob = frame_probs[0] * (2 - frame_probs[0]);
 
-    /*printf("%f %f\n", frame_probs[0], frame_probs[1]);*/
-    {
-       /* Probability of state transition */
-       float tau;
-       /* Represents independence of the MLP probabilities, where
-          beta=1 means fully independent. */
-       float beta;
-       /* Denormalized probability of speech (p0) and music (p1) after update */
-       float p0, p1;
-       /* Probabilities for "all speech" and "all music" */
-       float s0, m0;
-       /* Probability sum for renormalisation */
-       float psum;
-       /* Instantaneous probability of speech and music, with beta pre-applied. */
-       float speech0;
-       float music0;
-       float p, q;
-
-       /* More silence transitions for speech than for music. */
-       tau = .001f*tonal->music_prob + .01f*(1-tonal->music_prob);
-       p = MAX16(.05f,MIN16(.95f,frame_probs[1]));
-       q = MAX16(.05f,MIN16(.95f,tonal->vad_prob));
-       beta = .02f+.05f*ABS16(p-q)/(p*(1-q)+q*(1-p));
-       /* p0 and p1 are the probabilities of speech and music at this frame
-          using only information from previous frame and applying the
-          state transition model */
-       p0 = (1-tonal->vad_prob)*(1-tau) +    tonal->vad_prob *tau;
-       p1 =    tonal->vad_prob *(1-tau) + (1-tonal->vad_prob)*tau;
-       /* We apply the current probability with exponent beta to work around
-          the fact that the probability estimates aren't independent. */
-       p0 *= (float)pow(1-frame_probs[1], beta);
-       p1 *= (float)pow(frame_probs[1], beta);
-       /* Normalise the probabilities to get the Marokv probability of music. */
-       tonal->vad_prob = p1/(p0+p1);
-       info->vad_prob = tonal->vad_prob;
-       /* Consider that silence has a 50-50 probability of being speech or music. */
-       frame_probs[0] = tonal->vad_prob*frame_probs[0] + (1-tonal->vad_prob)*.5f;
-
-       /* One transition every 3 minutes of active audio */
-       tau = .0001f;
-       /* Adapt beta based on how "unexpected" the new prob is */
-       p = MAX16(.05f,MIN16(.95f,frame_probs[0]));
-       q = MAX16(.05f,MIN16(.95f,tonal->music_prob));
-       beta = .02f+.05f*ABS16(p-q)/(p*(1-q)+q*(1-p));
-       /* p0 and p1 are the probabilities of speech and music at this frame
-          using only information from previous frame and applying the
-          state transition model */
-       p0 = (1-tonal->music_prob)*(1-tau) +    tonal->music_prob *tau;
-       p1 =    tonal->music_prob *(1-tau) + (1-tonal->music_prob)*tau;
-       /* We apply the current probability with exponent beta to work around
-          the fact that the probability estimates aren't independent. */
-       p0 *= (float)pow(1-frame_probs[0], beta);
-       p1 *= (float)pow(frame_probs[0], beta);
-       /* Normalise the probabilities to get the Marokv probability of music. */
-       tonal->music_prob = p1/(p0+p1);
-       info->music_prob = tonal->music_prob;
-
-       /*printf("%f %f %f %f\n", frame_probs[0], frame_probs[1], tonal->music_prob, tonal->vad_prob);*/
-       /* This chunk of code deals with delayed decision. */
-       psum=1e-20f;
-       /* Instantaneous probability of speech and music, with beta pre-applied. */
-       speech0 = (float)pow(1-frame_probs[0], beta);
-       music0  = (float)pow(frame_probs[0], beta);
-       if (tonal->count==1)
-       {
-          if (tonal->application == OPUS_APPLICATION_VOIP)
-             tonal->pmusic[0] = .1f;
-          else
-             tonal->pmusic[0] = .625f;
-          tonal->pspeech[0] = 1-tonal->pmusic[0];
-       }
-       /* Updated probability of having only speech (s0) or only music (m0),
-          before considering the new observation. */
-       s0 = tonal->pspeech[0] + tonal->pspeech[1];
-       m0 = tonal->pmusic [0] + tonal->pmusic [1];
-       /* Updates s0 and m0 with instantaneous probability. */
-       tonal->pspeech[0] = s0*(1-tau)*speech0;
-       tonal->pmusic [0] = m0*(1-tau)*music0;
-       /* Propagate the transition probabilities */
-       for (i=1;i<DETECT_SIZE-1;i++)
-       {
-          tonal->pspeech[i] = tonal->pspeech[i+1]*speech0;
-          tonal->pmusic [i] = tonal->pmusic [i+1]*music0;
-       }
-       /* Probability that the latest frame is speech, when all the previous ones were music. */
-       tonal->pspeech[DETECT_SIZE-1] = m0*tau*speech0;
-       /* Probability that the latest frame is music, when all the previous ones were speech. */
-       tonal->pmusic [DETECT_SIZE-1] = s0*tau*music0;
-
-       /* Renormalise probabilities to 1 */
-       for (i=0;i<DETECT_SIZE;i++)
-          psum += tonal->pspeech[i] + tonal->pmusic[i];
-       psum = 1.f/psum;
-       for (i=0;i<DETECT_SIZE;i++)
-       {
-          tonal->pspeech[i] *= psum;
-          tonal->pmusic [i] *= psum;
-       }
-       psum = tonal->pmusic[0];
-       for (i=1;i<DETECT_SIZE;i++)
-          psum += tonal->pspeech[i];
-
-       /* Estimate our confidence in the speech/music decisions */
-       if (frame_probs[1]>.75)
-       {
-          if (tonal->music_prob>.9)
-          {
-             float adapt;
-             adapt = 1.f/(++tonal->music_confidence_count);
-             tonal->music_confidence_count = IMIN(tonal->music_confidence_count, 500);
-             tonal->music_confidence += adapt*MAX16(-.2f,frame_probs[0]-tonal->music_confidence);
-          }
-          if (tonal->music_prob<.1)
-          {
-             float adapt;
-             adapt = 1.f/(++tonal->speech_confidence_count);
-             tonal->speech_confidence_count = IMIN(tonal->speech_confidence_count, 500);
-             tonal->speech_confidence += adapt*MIN16(.2f,frame_probs[0]-tonal->speech_confidence);
-          }
-       }
-    }
-    tonal->last_music = tonal->music_prob>.5f;
+    /*printf("%f %f %f\n", frame_probs[0], frame_probs[1], info->music_prob);*/
 #ifdef MLP_TRAINING
     for (i=0;i<25;i++)
        printf("%f ", features[i]);
index cac51df..289c845 100644 (file)
@@ -30,6 +30,7 @@
 
 #include "celt.h"
 #include "opus_private.h"
+#include "mlp.h"
 
 #define NB_FRAMES 8
 #define NB_TBANDS 18
@@ -64,28 +65,16 @@ typedef struct {
    float mem[32];
    float cmean[8];
    float std[9];
-   float music_prob;
-   float vad_prob;
    float Etracker;
    float lowECount;
    int E_count;
-   int last_music;
    int count;
    int analysis_offset;
-   /** Probability of having speech for time i to DETECT_SIZE-1 (and music before).
-       pspeech[0] is the probability that all frames in the window are speech. */
-   float pspeech[DETECT_SIZE];
-   /** Probability of having music for time i to DETECT_SIZE-1 (and speech before).
-       pmusic[0] is the probability that all frames in the window are music. */
-   float pmusic[DETECT_SIZE];
-   float speech_confidence;
-   float music_confidence;
-   int speech_confidence_count;
-   int music_confidence_count;
    int write_pos;
    int read_pos;
    int read_subframe;
    float hp_ener_accum;
+   float rnn_state[MAX_NEURONS];
    opus_val32 downmix_state[3];
    AnalysisInfo info[DETECT_SIZE];
 } TonalityAnalysisState;
index ff9e50d..0e5ef16 100644 (file)
--- a/src/mlp.c
+++ b/src/mlp.c
@@ -1,5 +1,5 @@
 /* Copyright (c) 2008-2011 Octasic Inc.
-   Written by Jean-Marc Valin */
+                 2012-2017 Jean-Marc Valin */
 /*
    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
 #include "config.h"
 #endif
 
+#include <math.h>
 #include "opus_types.h"
 #include "opus_defines.h"
-
-#include <math.h>
-#include "mlp.h"
 #include "arch.h"
 #include "tansig_table.h"
-#define MAX_NEURONS 100
+#include "mlp.h"
 
-#if 0
-static OPUS_INLINE opus_val16 tansig_approx(opus_val32 _x) /* Q19 */
-{
-    int i;
-    opus_val16 xx; /* Q11 */
-    /*double x, y;*/
-    opus_val16 dy, yy; /* Q14 */
-    /*x = 1.9073e-06*_x;*/
-    if (_x>=QCONST32(8,19))
-        return QCONST32(1.,14);
-    if (_x<=-QCONST32(8,19))
-        return -QCONST32(1.,14);
-    xx = EXTRACT16(SHR32(_x, 8));
-    /*i = lrint(25*x);*/
-    i = SHR32(ADD32(1024,MULT16_16(25, xx)),11);
-    /*x -= .04*i;*/
-    xx -= EXTRACT16(SHR32(MULT16_16(20972,i),8));
-    /*x = xx*(1./2048);*/
-    /*y = tansig_table[250+i];*/
-    yy = tansig_table[250+i];
-    /*y = yy*(1./16384);*/
-    dy = 16384-MULT16_16_Q14(yy,yy);
-    yy = yy + MULT16_16_Q14(MULT16_16_Q11(xx,dy),(16384 - MULT16_16_Q11(yy,xx)));
-    return yy;
-}
-#else
-/*extern const float tansig_table[501];*/
 static OPUS_INLINE float tansig_approx(float x)
 {
     int i;
@@ -92,54 +63,79 @@ static OPUS_INLINE float tansig_approx(float x)
     y = y + x*dy*(1 - y*x);
     return sign*y;
 }
-#endif
 
-#if 0
-void mlp_process(const MLP *m, const opus_val16 *in, opus_val16 *out)
+static OPUS_INLINE float sigmoid_approx(float x)
 {
-    int j;
-    opus_val16 hidden[MAX_NEURONS];
-    const opus_val16 *W = m->weights;
-    /* Copy to tmp_in */
-    for (j=0;j<m->topo[1];j++)
-    {
-        int k;
-        opus_val32 sum = SHL32(EXTEND32(*W++),8);
-        for (k=0;k<m->topo[0];k++)
-            sum = MAC16_16(sum, in[k],*W++);
-        hidden[j] = tansig_approx(sum);
-    }
-    for (j=0;j<m->topo[2];j++)
-    {
-        int k;
-        opus_val32 sum = SHL32(EXTEND32(*W++),14);
-        for (k=0;k<m->topo[1];k++)
-            sum = MAC16_16(sum, hidden[k], *W++);
-        out[j] = tansig_approx(EXTRACT16(PSHR32(sum,17)));
-    }
+   return .5 + .5*tansig_approx(.5*x);
 }
-#else
-void mlp_process(const MLP *m, const float *in, float *out)
+
+void compute_dense(const DenseLayer *layer, float *output, const float *input)
 {
-    int j;
-    float hidden[MAX_NEURONS];
-    const float *W = m->weights;
-    /* Copy to tmp_in */
-    for (j=0;j<m->topo[1];j++)
-    {
-        int k;
-        float sum = *W++;
-        for (k=0;k<m->topo[0];k++)
-            sum = sum + in[k]**W++;
-        hidden[j] = tansig_approx(sum);
-    }
-    for (j=0;j<m->topo[2];j++)
-    {
-        int k;
-        float sum = *W++;
-        for (k=0;k<m->topo[1];k++)
-            sum = sum + hidden[k]**W++;
-        out[j] = tansig_approx(sum);
-    }
+   int i, j;
+   int N, M;
+   int stride;
+   M = layer->nb_inputs;
+   N = layer->nb_neurons;
+   stride = N;
+   for (i=0;i<N;i++)
+   {
+      /* Compute update gate. */
+      float sum = layer->bias[i];
+      for (j=0;j<M;j++)
+         sum += layer->input_weights[j*stride + i]*input[j];
+      output[i] = WEIGHTS_SCALE*sum;
+   }
+   if (layer->sigmoid) {
+      for (i=0;i<N;i++)
+         output[i] = sigmoid_approx(output[i]);
+   } else {
+      for (i=0;i<N;i++)
+         output[i] = tansig_approx(output[i]);
+   }
 }
-#endif
+
+void compute_gru(const GRULayer *gru, float *state, const float *input)
+{
+   int i, j;
+   int N, M;
+   int stride;
+   float z[MAX_NEURONS];
+   float r[MAX_NEURONS];
+   float h[MAX_NEURONS];
+   M = gru->nb_inputs;
+   N = gru->nb_neurons;
+   stride = 3*N;
+   for (i=0;i<N;i++)
+   {
+      /* Compute update gate. */
+      float sum = gru->bias[i];
+      for (j=0;j<M;j++)
+         sum += gru->input_weights[j*stride + i]*input[j];
+      for (j=0;j<N;j++)
+         sum += gru->recurrent_weights[j*stride + i]*state[j];
+      z[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
+   }
+   for (i=0;i<N;i++)
+   {
+      /* Compute reset gate. */
+      float sum = gru->bias[N + i];
+      for (j=0;j<M;j++)
+         sum += gru->input_weights[N + j*stride + i]*input[j];
+      for (j=0;j<N;j++)
+         sum += gru->recurrent_weights[N + j*stride + i]*state[j];
+      r[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
+   }
+   for (i=0;i<N;i++)
+   {
+      /* Compute output. */
+      float sum = gru->bias[2*N + i];
+      for (j=0;j<M;j++)
+         sum += gru->input_weights[2*N + j*stride + i]*input[j];
+      for (j=0;j<N;j++)
+         sum += gru->recurrent_weights[2*N + j*stride + i]*state[j]*r[j];
+      h[i] = z[i]*state[i] + (1-z[i])*tansig_approx(WEIGHTS_SCALE*sum);
+   }
+   for (i=0;i<N;i++)
+      state[i] = h[i];
+}
+
index 618e246..e3d1e9e 100644 (file)
--- a/src/mlp.h
+++ b/src/mlp.h
@@ -1,5 +1,4 @@
-/* Copyright (c) 2008-2011 Octasic Inc.
-   Written by Jean-Marc Valin */
+/* Copyright (c) 2017 Jean-Marc Valin */
 /*
    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
 #ifndef _MLP_H_
 #define _MLP_H_
 
-#include "arch.h"
+#include "opus_types.h"
+
+#define WEIGHTS_SCALE (1.f/8192)
+
+#define MAX_NEURONS 20
 
 typedef struct {
-    int layers;
-    const int *topo;
-    const float *weights;
-} MLP;
+  const opus_int16 *bias;
+  const opus_int16 *input_weights;
+  int nb_inputs;
+  int nb_neurons;
+  int sigmoid;
+} DenseLayer;
+
+typedef struct {
+  const opus_int16 *bias;
+  const opus_int16 *input_weights;
+  const opus_int16 *recurrent_weights;
+  int nb_inputs;
+  int nb_neurons;
+} GRULayer;
+
+extern const DenseLayer layer0;
+extern const GRULayer layer1;
+extern const DenseLayer layer2;
 
-extern const MLP net;
+void compute_dense(const DenseLayer *layer, float *output, const float *input);
 
-void mlp_process(const MLP *m, const float *in, float *out);
+void compute_gru(const GRULayer *gru, float *state, const float *input);
 
 #endif /* _MLP_H_ */
index a819880..10b787d 100644 (file)
+/*This file is automatically generated from a Keras model*/
+
 #ifdef HAVE_CONFIG_H
 #include "config.h"
 #endif
 
 #include "mlp.h"
 
-/* RMS error was 0.280492, seed was 1480478173 */
-/* 0.005976 0.031821 (0.280494 0.280492) done */
+static const opus_int16 layer0_weights[400] = {
+   -249, 690, -57, 358, -560, -144, 186, 75,
+   -804, -1176, -433, -78, 125, -1141, -857, -2,
+   1892, 91, 976, 1112, -1636, -73, -1740, -1604,
+   2012, -1043, 828, 230, 8698, -92, -665, -747,
+   1530, -1315, 2317, 697, 2885, -1399, 2661, 483,
+   -1628, 502, -592, 299, 3910, -781, 2738, 1338,
+   -1562, -149, 3468, 1448, 3057, 1202, 2098, 2777,
+   -1540, -3018, -249, 4656, 2508, 373, 2412, -776,
+   7160, -519, -917, -155, -1311, -1239, -637, -1245,
+   -1450, 1963, 3297, 1489, 1582, -123, -549, 1004,
+   -4085, 8792, -2145, 220, 2741, 624, -3560, 106,
+   -2476, 661, 1601, 2177, -1793, -623, 3349, 1959,
+   2777, -4635, 451, -996, -3260, -665, 1103, 201,
+   -2566, 3033, 1065, 1866, 989, -102, -1328, 126,
+   1, 4365, 82, 2355, -1011, -107, -5323, -1758,
+   -691, 1744, 683, -2732, 1309, -1135, -726, 1071,
+   9423, 1120, -705, -188, -200, -2668, -750, -1839,
+   793, 718, -1011, 222, 567, 31, -1520, 3142,
+   -5491, -3549, -2718, -276, 2078, -706, -779, -2304,
+   -2983, -660, 1664, -999, -3297, -1200, 1017, -499,
+   -764, 3215, -720, 255, 1539, -1142, -3604, -351,
+   -982, 846, 4069, 481, 5673, -1184, -2883, -1387,
+   519, -1617, 315, 1875, -119, 2383, 1141, 1583,
+   1013, -531, 349, 121, -139, 327, 531, 611,
+   853, 1118, 2013, -294, -1150, 693, 531, 583,
+   -1506, 224, -818, 655, 1981, 1056, -2327, -1457,
+   -2846, 3779, 1230, -2587, -191, 1647, -3484, -3450,
+   -3384, -93, -1028, 825, 868, 38, 557, -125,
+   1830, 1981, 1063, 9906, -455, 172, -1788, 4417,
+   472, -1398, -4638, 999, -6158, 1943, 4703, -2986,
+   -938, 3053, -631, -384, 848, -3909, 1352, -2362,
+   -2306, 515, 2385, -2373, -1642, 582, -262, -571,
+   8, 1615, -2501, 1225, -660, -857, -522, 2419,
+   654, -1137, 67, -890, 83, 23, 2166, 524,
+   -978, 5330, 1237, 1163, -2251, -142, -2331, 3034,
+   395, -1799, 944, 1978, -2788, 1324, 3271, -4643,
+   -1313, -2472, 1296, -2316, -1803, -10224, -8577, 8271,
+   -1920, -3366, -1704, 3250, -2514, 11995, 6655, 4298,
+   1046, 483, 651, -901, -1417, 804, 396, -2617,
+   1000, 2265, 5354, -1050, 2505, 41, 3928, 1878,
+   -21057, 12783, 32767, -8139, -32768, 1106, -12076, -26511,
+   -3484, 24604, 8938, 22944, -9490, -6208, -22142, 23250,
+   -12708, -299, 14432, -2311, -11941, -797, -3287, -4744,
+   -10758, 10226, -851, 8565, 4104, -4002, 4456, 12642,
+   1685, -7093, -997, 16081, 814, -5316, -13491, 12766,
+   -1637, -213, 7271, -3037, -6772, 3053, -12425, -6955,
+   12553, 7635, -32768, -18611, 22929, 3056, 11196, 5202,
+   31582, 5741, -22206, 6145, -673, -25488, -7005, -16479,
+   10693, -11369, -10848, -1895, 8051, 7360, 1067, -220,
+   6643, 17077, -12356, 3288, 4619, 9751, -656, -1217
+};
+
+static const opus_int16 layer0_bias[16] = {
+   -164, 2802, -2100, 410, 4003, -888, 3010, -644,
+   4499, -121, 3753, -1606, -4855, -1828, -682, -79
+};
+
+static const opus_int16 layer1_weights[576] = {
+   543, 2150, 143, 1450, 7898, -3201, -2648, -4311,
+   7028, -2608, 1844, 126, -858, 4572, -347, -11298,
+   11315, -4344, 1858, -5906, -5962, 2847, -3894, -1496,
+   5309, -651, -3143, -3141, 429, -679, -1524, -1966,
+   -1175, 2917, 97, -1094, -3186, 4346, 832, 3726,
+   5452, 1371, 505, -1282, -435, 3438, 691, -2692,
+   -872, -1332, 3722, 841, -1081, 2414, -1275, 2131,
+   -7351, -962, -2295, 1141, 2810, -839, 1444, -1005,
+   3900, 1160, 1070, -801, -1856, 2152, -79, 122,
+   -2790, -5641, -2021, -4328, 992, 664, 1078, 4919,
+   -5314, -665, -4650, -4734, 3417, -300, -3038, 6124,
+   -1161, -1786, -2922, 10536, 2726, 1200, -1840, 3752,
+   -3420, 1710, 2414, -2704, 918, 518, 1057, 1837,
+   3098, 1665, 2780, 1636, -3883, -150, -3216, -5393,
+   1819, -3555, -3063, -3252, -2948, 8249, -3856, -3981,
+   406, -5407, -2135, 3006, -1920, -694, 1349, 2321,
+   -3114, -1262, -1296, -406, -712, 185, 1802, 62,
+   -1559, -62, 2270, -195, -1043, 2092, -3543, 1833,
+   1193, 1880, 3076, 6353, 1671, -634, 3180, -21,
+   -612, 800, 6405, 2825, 1187, 583, -2961, -6221,
+   -1035, -1686, 3563, 7102, 7122, 3946, 3264, -2081,
+   574, -2400, 22, 112, 1073, -2386, -3224, -3508,
+   -1347, -3521, 992, -2582, -7175, 1241, -1368, -6035,
+   -2555, -6012, -11198, -2492, -4061, -7604, -3521, -5613,
+   -3823, -6300, 6377, -6267, -3568, -1121, -2755, -6177,
+   2627, -2735, -4447, -2327, -577, 824, 2159, -1206,
+   47, -3988, -3918, -1073, -540, -595, 2777, -1114,
+   985, 407, -1907, -3836, -7385, 9579, 120, 4717,
+   -1921, -5036, 1388, -2388, -1476, 2967, 2905, 3306,
+   -631, -1730, 4974, 51, -1131, -3307, -1678, -354,
+   2481, -1133, 997, -1374, 2350, 1945, -274, -2238,
+   -1642, 869, 139, -2974, -1210, -362, 3461, -3912,
+   -7937, -1246, 5396, -6235, -6650, -9613, -5547, 2541,
+   -330, -2843, -3100, -227, 1859, 3371, 5094, 4045,
+   -8379, -2052, 363, 2005, 2248, 772, -872, 1686,
+   -3885, 1413, 704, -379, -1130, -703, -3406, 179,
+   2895, 11203, -1085, -2496, -10569, 877, 2982, 4245,
+   7216, -3703, 2468, 1361, -66, 236, -958, -3101,
+   2424, -2604, 1854, -5674, 2951, -1898, 3078, 20,
+   1217, -3799, 802, -458, -1522, -3094, -2448, -2067,
+   658, -3163, 1976, -1577, -8063, 380, -1328, 5963,
+   -7396, -5218, -7379, -9166, -616, -1731, 2383, 3735,
+   10889, -5348, 1128, -6396, -4613, -1547, 2619, -2967,
+   2229, 3582, -156, -3970, -2606, -3270, 2515, -568,
+   -2800, -3145, -2641, 2530, 1079, 3184, -814, -1762,
+   2128, -6864, 5163, -3934, 2410, 2574, 1568, -5281,
+   -1199, -2462, 713, -1456, 4651, -8439, -2239, -4620,
+   316, 1772, 89, -2021, -658, -9442, -1249, -195,
+   -1311, -1129, 1734, 1991, 421, 579, 833, 2917,
+   1025, -3243, -2909, 1950, -2845, 898, -1011, 5505,
+   4705, 2989, -4835, -939, 3768, -1641, 10910, 34,
+   -938, 1839, 4835, -2526, -1699, -9939, 4135, 2330,
+   746, -2420, 898, 588, -3496, -2904, -3896, 639,
+   1046, 440, 1254, 2025, 2089, 3468, 697, 888,
+   4553, 2152, 4522, 2916, 3432, 4376, -717, -8019,
+   8063, -1602, -5389, -1549, 4541, 412, 413, -5267,
+   5859, 147, 2962, 6490, -2794, 1448, -1348, -815,
+   -1089, -934, 1485, -1420, 827, -2345, -403, 2359,
+   -1298, 238, 1127, 1984, 3667, -6776, 1191, -1049,
+   6323, 3381, 4703, 5709, 1693, -3948, -4716, 5403,
+   -3221, -1108, 478, -4250, 2643, 1458, -4684, -5321,
+   -1610, -1048, 4730, 1253, 1975, 1904, 2112, -1591,
+   -5355, 1317, -2438, 113, -1285, 4023, -1129, 3054,
+   -5091, 1484, -742, -1258, 1044, -1035, -442, 789,
+   1525, 10987, -897, 2773, 357, 4770, 1942, 524,
+   1315, 3575, -656, 1394, -14, -4854, 2764, 5455,
+   1649, 1005, -1792, 1558, -1490, 3447, -1066, 662,
+   -974, -870, 1611, 2541, -2744, -1782, -1456, -820,
+   261, -1722, -3869, -9244, 4372, 4013, -2733, -13592,
+   5458, -6824, -634, 707, 742, 4432, -3446, -4348,
+   916, 505, 3267, -9216, -3492, 2121, -4923, 4175,
+   -119, -1497, 1421, 3593, 1398, 273, 2351, 404
+};
+
+static const opus_int16 layer1_recur_weights[432] = {
+   381, -8053, -3581, -73, 5728, -10914, -4592, -14935,
+   2526, -3600, 3424, 5804, -2523, 2785, -2245, 734,
+   1045, -2857, 3888, -11398, 3406, -2679, 4999, -103,
+   6707, -7102, 1158, -4524, 3212, 2065, -255, -4255,
+   1682, -987, 333, 1958, 2943, -1600, 6811, 2103,
+   4030, -4778, 5490, -11909, -1505, 3493, -9066, -3412,
+   -1673, -7387, -1995, 451, -2989, -2608, 317, 2076,
+   -6350, 4404, -1222, -3854, -4675, 12616, 3739, 126,
+   1343, 8117, 620, -415, -1140, -931, -2678, -1561,
+   -1454, 1010, 1821, -1230, -3869, 3745, 2041, -1243,
+   -196, -4974, -9547, -6367, 3797, 105, -698, -1409,
+   -7030, 5843, -6749, -7885, -1051, 3730, -1202, 2938,
+   1536, 2797, 4495, -309, 1954, 1637, 3972, 723,
+   1782, 4101, 5525, -6803, 3625, 4203, -3680, -4308,
+   -5662, 2223, 1929, 1113, 7828, 61, -5548, -10833,
+   8655, 3489, 3680, -829, -496, 6740, 1317, -1402,
+   2411, 402, 1420, 1971, -3876, 4533, 4610, 6555,
+   2928, -2090, -1689, 1243, 3253, 1051, 4787, -3870,
+   -2253, 4030, -507, 3956, -7122, 6049, 3373, 5868,
+   782, 3961, -2132, -3936, 3944, -195, 1283, -382,
+   -141, 1447, 2272, 4714, 579, 3492, -2719, 937,
+   3498, -5240, 3375, 3040, 290, -7514, -2126, -7146,
+   3084, 1281, 4354, 338, 5197, -1488, 1623, 1854,
+   -2707, -2176, 3413, -2245, 851, 1715, -2870, 1309,
+   -1127, 662, -1673, 7551, -4901, -4459, 1943, -5998,
+   -4459, 1988, -1437, -6808, -530, 812, 6763, 1088,
+   -108, -547, -2758, 5672, 857, 2366, 1770, -3537,
+   -8239, 63, 6457, 3256, 2453, 5478, 3192, 4728,
+   -5188, -1048, -1468, 1944, -1620, -4830, 8233, 4379,
+   887, -1339, 1825, 8806, -7448, 5491, 2284, 1983,
+   4417, -50, -411, -1528, -609, 3553, -7104, 2208,
+   -4777, -877, -3517, 939, -5368, -7444, 4267, -994,
+   -3320, 3897, 1161, 3366, -6309, 6119, -3928, -2835,
+   1384, -1238, 1558, -90, -1277, 3429, -2350, 929,
+   -7380, 705, -1443, -6141, -4110, 5939, 3391, -2137,
+   222, 408, 619, 5516, 6060, 471, -2335, 31,
+   636, -7196, 2346, -2082, 2530, -2093, 1603, -7208,
+   -6764, 2089, -10548, -3235, -3035, -9519, 5596, -5862,
+   -264, -514, -5881, 2064, 2158, -688, 1983, 9081,
+   -395, 1106, 1501, 506, -466, -3651, -879, 9723,
+   5714, -1403, 3090, 2208, -127, -6849, -579, -1405,
+   6088, -8262, -8095, -1043, -9232, -1771, -2790, -5700,
+   -1568, -1509, -1257, -2664, -1594, 560, -7664, -3712,
+   -971, 3808, -3434, -1332, -3769, -1509, 316, 3281,
+   1581, -2888, -2234, -118, 919, 3520, 8085, -2894,
+   1110, 12122, -1275, -2171, -1876, 8625, 1850, 1449,
+   6177, 1800, 627, -5902, 3864, 4634, -3149, -1776,
+   1389, 2766, 481, 2372, -71, 1265, -357, 1275,
+   -2011, 2432, 8081, 2382, 8879, 1983, -1742, -4043,
+   -361, 6496, 5009, -320, 4582, -2144, -4184, -1141,
+   -2661, -3733, -380, -1826, -17320, -3020, -11362, -10212,
+   -2959, -897, -2687, 1760, 2843, 836, -1765, 2219,
+   -3431, 298, 1666, -4254, 1589, -244, -745, -1628,
+   1684, 2892, -4366, 2072, -6710, -1399, -8910, 2407
+};
+
+static const opus_int16 layer1_bias[36] = {
+   14206, 6258, 9052, 6611, -3603, 8785, 5625, 9775,
+   6516, 4736, 8943, 3466, -888, -778, 5042, -3041,
+   2719, 1724, 1216, 1698, 805, 2729, 1820, 4066,
+   -3456, 3091, 1570, 542, 599, 2583, 2052, 1258,
+   -2255, 1508, 1183, -5095
+};
+
+static const opus_int16 layer2_weights[24] = {
+   946, -14834, -5002, 14299, 10342, 1471, 7109, -508,
+   11745, -1786, -621, 15227, -4577, 30114, 5174, 12698,
+   22279, -527, 7727, 2246, 9892, -2297, -15579, 853
+};
 
-static const float weights[450] = {
+static const opus_int16 layer2_bias[2] = {
+   3700, 8418
+};
 
-/* hidden layer */
--0.514624f, 0.0234227f, -0.14329f, -0.0878216f, -0.00187827f,
--0.0257443f, 0.108524f, 0.00333881f, 0.00585017f, -0.0246132f,
-0.142723f, -0.00436494f, 0.0101354f, -0.11124f, -0.0809367f,
--0.0750772f, 0.0295524f, 0.00823944f, 0.150392f, 0.0320876f,
--0.0710564f, -1.43818f, 0.652076f, 0.0650744f, -1.54821f,
-0.168949f, -1.92724f, 0.0517976f, -0.0670737f, -0.0690121f,
-0.00247528f, -0.0522024f, 0.0631368f, 0.0532776f, 0.047751f,
--0.011715f, 0.142374f, -0.0290885f, -0.279263f, -0.433499f,
--0.0795174f, -0.380458f, -0.051263f, 0.218537f, -0.322478f,
-1.06667f, -0.104607f, -4.70108f, 0.312037f, 0.277397f,
--2.71859f, 1.70037f, -0.141845f, 0.0115618f, 0.0629883f,
-0.0403871f, 0.0139428f, -0.00430733f, -0.0429038f, -0.0590318f,
--0.0501526f, -0.0284802f, -0.0415686f, -0.0438999f, 0.0822666f,
-0.197194f, 0.0363275f, -0.0584307f, 0.0752364f, -0.0799796f,
--0.146275f, 0.161661f, -0.184585f, 0.145568f, 0.442823f,
-1.61221f, 1.11162f, 2.62177f, -2.482f, -0.112599f,
--0.110366f, -0.140794f, -0.181694f, 0.0648674f, 0.0842248f,
-0.0933993f, 0.150122f, 0.129171f, 0.176848f, 0.141758f,
--0.271822f, 0.235113f, 0.0668579f, -0.433957f, 0.113633f,
--0.169348f, -1.40091f, 0.62861f, -0.134236f, 0.402173f,
-1.86373f, 1.53998f, -4.32084f, 0.735343f, 0.800214f,
--0.00968415f, 0.0425904f, 0.0196811f, -0.018426f, -0.000343953f,
--0.00416389f, 0.00111558f, 0.0173069f, -0.00998596f, -0.025898f,
-0.00123764f, -0.00520373f, -0.0565033f, 0.0637394f, 0.0051213f,
-0.0221361f, 0.00819962f, -0.0467061f, -0.0548258f, -0.00314063f,
--1.18332f, 1.88091f, -0.41148f, -2.95727f, -0.521449f,
--0.271641f, 0.124946f, -0.0532936f, 0.101515f, 0.000208564f,
--0.0488748f, 0.0642388f, -0.0383848f, 0.0135046f, -0.0413592f,
--0.0326402f, -0.0137421f, -0.0225219f, -0.0917294f, -0.277759f,
--0.185418f, 0.0471128f, -0.125879f, 0.262467f, -0.212794f,
--0.112931f, -1.99885f, -0.404787f, 0.224402f, 0.637962f,
--0.27808f, -0.0723953f, -0.0537655f, -0.0336359f, -0.0906601f,
--0.0641309f, -0.0713542f, 0.0524317f, 0.00608819f, 0.0754101f,
--0.0488401f, -0.00671865f, 0.0418239f, 0.0536284f, -0.132639f,
-0.0267648f, -0.248432f, -0.0104153f, 0.035544f, -0.212753f,
--0.302895f, -0.0357854f, 0.376838f, 0.597025f, -0.664647f,
-0.268422f, -0.376772f, -1.05472f, 0.0144178f, 0.179122f,
-0.0360155f, 0.220262f, -0.0056381f, 0.0317197f, 0.0621066f,
--0.00779298f, 0.00789378f, 0.00350605f, 0.0104809f, 0.0362871f,
--0.157708f, -0.0659779f, -0.0926278f, 0.00770791f, 0.0631621f,
-0.0817343f, -0.424295f, -0.0437727f, -0.24251f, 0.711217f,
--0.736455f, -2.194f, -0.107612f, -0.175156f, -0.0366573f,
--0.0123156f, -0.0628516f, -0.0218977f, -0.00693699f, 0.00695185f,
-0.00507362f, 0.00359334f, 0.0052661f, 0.035561f, 0.0382701f,
-0.0342179f, -0.00790271f, -0.0170925f, 0.047029f, 0.0197362f,
--0.0153435f, 0.0644152f, -0.36862f, -0.0674876f, -2.82672f,
-1.34122f, -0.0788029f, -3.47792f, 0.507246f, -0.816378f,
--0.0142383f, -0.127349f, -0.106926f, -0.0359524f, 0.105045f,
-0.291554f, 0.195413f, 0.0866214f, -0.066577f, -0.102188f,
-0.0979466f, -0.12982f, 0.400181f, -0.409336f, -0.0593326f,
--0.0656203f, -0.204474f, 0.179802f, 0.000509084f, 0.0995954f,
--2.377f, -0.686359f, 0.934861f, 1.10261f, 1.3901f,
--4.33616f, -0.00264017f, 0.00713045f, 0.106264f, 0.143726f,
--0.0685305f, -0.054656f, -0.0176725f, -0.0772669f, -0.0264526f,
--0.0103824f, -0.0269872f, -0.00687f, 0.225804f, 0.407751f,
--0.0612611f, -0.0576863f, -0.180131f, -0.222772f, -0.461742f,
-0.335236f, 1.03399f, 4.24112f, -0.345796f, -0.594549f,
--76.1407f, -0.265276f, 0.0507719f, 0.0643044f, 0.0384832f,
-0.0424459f, -0.0387817f, -0.0235996f, -0.0740556f, -0.0270029f,
-0.00882177f, -0.0552371f, -0.00485851f, 0.314295f, 0.360431f,
--0.0787085f, 0.110355f, -0.415958f, -0.385088f, -0.272224f,
--1.55108f, -0.141848f, 0.448877f, -0.563447f, -2.31403f,
--0.120077f, -1.49918f, -0.817726f, -0.0495854f, -0.0230782f,
--0.0224014f, 0.117076f, 0.0393216f, 0.051997f, 0.0330763f,
--0.110796f, 0.0211117f, -0.0197258f, 0.0187461f, 0.0125183f,
-0.14876f, 0.0920565f, -0.342475f, 0.135272f, -0.168155f,
--0.033423f, -0.0604611f, -0.128835f, 0.664947f, -0.144997f,
-2.27649f, 1.28663f, 0.841217f, -2.42807f, 0.0230471f,
-0.226709f, -0.0374803f, 0.155436f, 0.0400342f, -0.184686f,
-0.128488f, -0.0939518f, -0.0578559f, 0.0265967f, -0.0999322f,
--0.0322768f, -0.322994f, -0.189371f, -0.738069f, -0.0754914f,
-0.214717f, -0.093728f, -0.695741f, 0.0899298f, -2.06188f,
--0.273719f, -0.896977f, 0.130553f, 0.134638f, 1.29355f,
-0.00520749f, -0.0324224f, 0.00530451f, 0.0192385f, 0.00328708f,
-0.0250838f, 0.0053365f, -0.0177321f, 0.00618789f, 0.00525364f,
-0.00104596f, -0.0360459f, 0.0402403f, -0.0406351f, 0.0136883f,
-0.0880722f, -0.0197449f, 0.089938f, 0.0100456f, -0.0475638f,
--0.73267f, 0.037433f, -0.146551f, -0.230221f, -3.06489f,
--1.40194f, 0.0198483f, 0.0397953f, -0.0190239f, 0.0470715f,
--0.131363f, -0.191721f, -0.0176224f, -0.0480352f, -0.221799f,
--0.26794f, -0.0292615f, 0.0612127f, -0.129877f, 0.00628332f,
--0.085918f, 0.0175379f, 0.0541011f, -0.0810874f, -0.380809f,
--0.222056f, -0.508859f, -0.473369f, 0.484958f, -2.28411f,
-0.0139516f,
-/* output layer */
-3.90017f, 1.71789f, -1.43372f, -2.70839f, 1.77107f,
-5.48006f, 1.44661f, 2.01134f, -1.88383f, -3.64958f,
--1.26351f, 0.779421f, 2.11357f, 3.10409f, 1.68846f,
--4.46197f, -1.61455f, 3.59832f, 2.43531f, -1.26458f,
-0.417941f, 1.47437f, 2.16635f, -1.909f, -0.828869f,
-1.38805f, -2.67975f, -0.110044f, 1.95596f, 0.697931f,
--0.313226f, -0.889315f, 0.283236f, 0.946102f, };
+const DenseLayer layer0 = {
+   layer0_bias,
+   layer0_weights,
+   25, 16, 0
+};
 
-static const int topo[3] = {25, 16, 2};
+const GRULayer layer1 = {
+   layer1_bias,
+   layer1_weights,
+   layer1_recur_weights,
+   16, 12
+};
 
-const MLP net = {
-    3,
-    topo,
-    weights
+const DenseLayer layer2 = {
+   layer2_bias,
+   layer2_weights,
+   12, 2, 1
 };
+
diff --git a/src/mlp_train.c b/src/mlp_train.c
deleted file mode 100644 (file)
index 8d9d127..0000000
+++ /dev/null
@@ -1,501 +0,0 @@
-/* Copyright (c) 2008-2011 Octasic Inc.
-   Written by Jean-Marc Valin */
-/*
-   Redistribution and use in source and binary forms, with or without
-   modification, are permitted provided that the following conditions
-   are met:
-
-   - Redistributions of source code must retain the above copyright
-   notice, this list of conditions and the following disclaimer.
-
-   - Redistributions in binary form must reproduce the above copyright
-   notice, this list of conditions and the following disclaimer in the
-   documentation and/or other materials provided with the distribution.
-
-   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
-   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
-   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
-   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-*/
-
-
-#include "mlp_train.h"
-#include <stdlib.h>
-#include <stdio.h>
-#include <string.h>
-#include <semaphore.h>
-#include <pthread.h>
-#include <time.h>
-#include <signal.h>
-
-int stopped = 0;
-
-void handler(int sig)
-{
-    stopped = 1;
-    signal(sig, handler);
-}
-
-MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples)
-{
-    int i, j, k;
-    MLPTrain *net;
-    int inDim, outDim;
-    net = malloc(sizeof(*net));
-    net->topo = malloc(nbLayers*sizeof(net->topo[0]));
-    for (i=0;i<nbLayers;i++)
-        net->topo[i] = topo[i];
-    inDim = topo[0];
-    outDim = topo[nbLayers-1];
-    net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0]));
-    net->weights = malloc((nbLayers-1)*sizeof(net->weights));
-    net->best_weights = malloc((nbLayers-1)*sizeof(net->weights));
-    for (i=0;i<nbLayers-1;i++)
-    {
-        net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
-        net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
-    }
-    double inMean[inDim];
-    for (j=0;j<inDim;j++)
-    {
-        double std=0;
-        inMean[j] = 0;
-        for (i=0;i<nbSamples;i++)
-        {
-            inMean[j] += inputs[i*inDim+j];
-            std += inputs[i*inDim+j]*inputs[i*inDim+j];
-        }
-        inMean[j] /= nbSamples;
-        std /= nbSamples;
-        net->in_rate[1+j] = .5/(.0001+std);
-        std = std-inMean[j]*inMean[j];
-        if (std<.001)
-            std = .001;
-        std = 1/sqrt(inDim*std);
-        for (k=0;k<topo[1];k++)
-            net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
-    }
-    net->in_rate[0] = 1;
-    for (j=0;j<topo[1];j++)
-    {
-        double sum = 0;
-        for (k=0;k<inDim;k++)
-            sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1];
-        net->weights[0][j*(topo[0]+1)] = -sum;
-    }
-    for (j=0;j<outDim;j++)
-    {
-        double mean = 0;
-        double std;
-        for (i=0;i<nbSamples;i++)
-            mean += outputs[i*outDim+j];
-        mean /= nbSamples;
-        std = 1/sqrt(topo[nbLayers-2]);
-        net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean;
-        for (k=0;k<topo[nbLayers-2];k++)
-            net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std);
-    }
-    return net;
-}
-
-#define MAX_NEURONS 100
-#define MAX_OUT 10
-
-double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate)
-{
-    int i,j;
-    int s;
-    int inDim, outDim, hiddenDim;
-    int *topo;
-    double *W0, *W1;
-    double rms=0;
-    int W0_size, W1_size;
-    double hidden[MAX_NEURONS];
-    double netOut[MAX_NEURONS];
-    double error[MAX_NEURONS];
-
-    topo = net->topo;
-    inDim = net->topo[0];
-    hiddenDim = net->topo[1];
-    outDim = net->topo[2];
-    W0_size = (topo[0]+1)*topo[1];
-    W1_size = (topo[1]+1)*topo[2];
-    W0 = net->weights[0];
-    W1 = net->weights[1];
-    memset(W0_grad, 0, W0_size*sizeof(double));
-    memset(W1_grad, 0, W1_size*sizeof(double));
-    for (i=0;i<outDim;i++)
-        netOut[i] = outputs[i];
-    for (i=0;i<outDim;i++)
-        error_rate[i] = 0;
-    for (s=0;s<nbSamples;s++)
-    {
-        float *in, *out;
-        float inp[inDim];
-        in = inputs+s*inDim;
-        out = outputs + s*outDim;
-        for (j=0;j<inDim;j++)
-           inp[j] = in[j];
-        for (i=0;i<hiddenDim;i++)
-        {
-            double sum = W0[i*(inDim+1)];
-            for (j=0;j<inDim;j++)
-                sum += W0[i*(inDim+1)+j+1]*inp[j];
-            hidden[i] = tansig_approx(sum);
-        }
-        for (i=0;i<outDim;i++)
-        {
-            double sum = W1[i*(hiddenDim+1)];
-            for (j=0;j<hiddenDim;j++)
-                sum += W1[i*(hiddenDim+1)+j+1]*hidden[j];
-            netOut[i] = tansig_approx(sum);
-            error[i] = out[i] - netOut[i];
-            if (out[i] == 0) error[i] *= .0;
-            error_rate[i] += fabs(error[i])>1;
-            if (i==0) error[i] *= 5;
-            rms += error[i]*error[i];
-            /*error[i] = error[i]/(1+fabs(error[i]));*/
-        }
-        /* Back-propagate error */
-        for (i=0;i<outDim;i++)
-        {
-            double grad = 1-netOut[i]*netOut[i];
-            W1_grad[i*(hiddenDim+1)] += error[i]*grad;
-            for (j=0;j<hiddenDim;j++)
-                W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j];
-        }
-        for (i=0;i<hiddenDim;i++)
-        {
-            double grad;
-            grad = 0;
-            for (j=0;j<outDim;j++)
-                grad += error[j]*W1[j*(hiddenDim+1)+i+1];
-            grad *= 1-hidden[i]*hidden[i];
-            W0_grad[i*(inDim+1)] += grad;
-            for (j=0;j<inDim;j++)
-                W0_grad[i*(inDim+1)+j+1] += grad*inp[j];
-        }
-    }
-    return rms;
-}
-
-#define NB_THREADS 8
-
-sem_t sem_begin[NB_THREADS];
-sem_t sem_end[NB_THREADS];
-
-struct GradientArg {
-    int id;
-    int done;
-    MLPTrain *net;
-    float *inputs;
-    float *outputs;
-    int nbSamples;
-    double *W0_grad;
-    double *W1_grad;
-    double rms;
-    double error_rate[MAX_OUT];
-};
-
-void *gradient_thread_process(void *_arg)
-{
-    int W0_size, W1_size;
-    struct GradientArg *arg = _arg;
-    int *topo = arg->net->topo;
-    W0_size = (topo[0]+1)*topo[1];
-    W1_size = (topo[1]+1)*topo[2];
-    double W0_grad[W0_size];
-    double W1_grad[W1_size];
-    arg->W0_grad = W0_grad;
-    arg->W1_grad = W1_grad;
-    while (1)
-    {
-        sem_wait(&sem_begin[arg->id]);
-        if (arg->done)
-            break;
-        arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate);
-        sem_post(&sem_end[arg->id]);
-    }
-    fprintf(stderr, "done\n");
-    return NULL;
-}
-
-float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate)
-{
-    int i, j;
-    int e;
-    float best_rms = 1e10;
-    int inDim, outDim, hiddenDim;
-    int *topo;
-    double *W0, *W1, *best_W0, *best_W1;
-    double *W0_grad, *W1_grad;
-    double *W0_oldgrad, *W1_oldgrad;
-    double *W0_rate, *W1_rate;
-    double *best_W0_rate, *best_W1_rate;
-    int W0_size, W1_size;
-    topo = net->topo;
-    W0_size = (topo[0]+1)*topo[1];
-    W1_size = (topo[1]+1)*topo[2];
-    struct GradientArg args[NB_THREADS];
-    pthread_t thread[NB_THREADS];
-    int samplePerPart = nbSamples/NB_THREADS;
-    int count_worse=0;
-    int count_retries=0;
-
-    topo = net->topo;
-    inDim = net->topo[0];
-    hiddenDim = net->topo[1];
-    outDim = net->topo[2];
-    W0 = net->weights[0];
-    W1 = net->weights[1];
-    best_W0 = net->best_weights[0];
-    best_W1 = net->best_weights[1];
-    W0_grad = malloc(W0_size*sizeof(double));
-    W1_grad = malloc(W1_size*sizeof(double));
-    W0_oldgrad = malloc(W0_size*sizeof(double));
-    W1_oldgrad = malloc(W1_size*sizeof(double));
-    W0_rate = malloc(W0_size*sizeof(double));
-    W1_rate = malloc(W1_size*sizeof(double));
-    best_W0_rate = malloc(W0_size*sizeof(double));
-    best_W1_rate = malloc(W1_size*sizeof(double));
-    memset(W0_grad, 0, W0_size*sizeof(double));
-    memset(W0_oldgrad, 0, W0_size*sizeof(double));
-    memset(W1_grad, 0, W1_size*sizeof(double));
-    memset(W1_oldgrad, 0, W1_size*sizeof(double));
-
-    rate /= nbSamples;
-    for (i=0;i<hiddenDim;i++)
-        for (j=0;j<inDim+1;j++)
-            W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j];
-    for (i=0;i<W1_size;i++)
-        W1_rate[i] = rate;
-
-    for (i=0;i<NB_THREADS;i++)
-    {
-        args[i].net = net;
-        args[i].inputs = inputs+i*samplePerPart*inDim;
-        args[i].outputs = outputs+i*samplePerPart*outDim;
-        args[i].nbSamples = samplePerPart;
-        args[i].id = i;
-        args[i].done = 0;
-        sem_init(&sem_begin[i], 0, 0);
-        sem_init(&sem_end[i], 0, 0);
-        pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]);
-    }
-    for (e=0;e<nbEpoch;e++)
-    {
-        double rms=0;
-        double error_rate[2] = {0,0};
-        for (i=0;i<NB_THREADS;i++)
-        {
-            sem_post(&sem_begin[i]);
-        }
-        memset(W0_grad, 0, W0_size*sizeof(double));
-        memset(W1_grad, 0, W1_size*sizeof(double));
-        for (i=0;i<NB_THREADS;i++)
-        {
-            sem_wait(&sem_end[i]);
-            rms += args[i].rms;
-            error_rate[0] += args[i].error_rate[0];
-            error_rate[1] += args[i].error_rate[1];
-            for (j=0;j<W0_size;j++)
-                W0_grad[j] += args[i].W0_grad[j];
-            for (j=0;j<W1_size;j++)
-                W1_grad[j] += args[i].W1_grad[j];
-        }
-
-        float mean_rate = 0, min_rate = 1e10;
-        rms = (rms/(outDim*nbSamples));
-        error_rate[0] = (error_rate[0]/(nbSamples));
-        error_rate[1] = (error_rate[1]/(nbSamples));
-        fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms);
-        if (rms < best_rms)
-        {
-            best_rms = rms;
-            for (i=0;i<W0_size;i++)
-            {
-                best_W0[i] = W0[i];
-                best_W0_rate[i] = W0_rate[i];
-            }
-            for (i=0;i<W1_size;i++)
-            {
-                best_W1[i] = W1[i];
-                best_W1_rate[i] = W1_rate[i];
-            }
-            count_worse=0;
-            count_retries=0;
-        } else {
-            count_worse++;
-            if (count_worse>30)
-            {
-                count_retries++;
-                count_worse=0;
-                for (i=0;i<W0_size;i++)
-                {
-                    W0[i] = best_W0[i];
-                    best_W0_rate[i] *= .7;
-                    if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15;
-                    W0_rate[i] = best_W0_rate[i];
-                    W0_grad[i] = 0;
-                }
-                for (i=0;i<W1_size;i++)
-                {
-                    W1[i] = best_W1[i];
-                    best_W1_rate[i] *= .8;
-                    if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
-                    W1_rate[i] = best_W1_rate[i];
-                    W1_grad[i] = 0;
-                }
-            }
-        }
-        if (count_retries>10)
-            break;
-        for (i=0;i<W0_size;i++)
-        {
-            if (W0_oldgrad[i]*W0_grad[i] > 0)
-                W0_rate[i] *= 1.01;
-            else if (W0_oldgrad[i]*W0_grad[i] < 0)
-                W0_rate[i] *= .9;
-            mean_rate += W0_rate[i];
-            if (W0_rate[i] < min_rate)
-                min_rate = W0_rate[i];
-            if (W0_rate[i] < 1e-15)
-                W0_rate[i] = 1e-15;
-            /*if (W0_rate[i] > .01)
-                W0_rate[i] = .01;*/
-            W0_oldgrad[i] = W0_grad[i];
-            W0[i] += W0_grad[i]*W0_rate[i];
-        }
-        for (i=0;i<W1_size;i++)
-        {
-            if (W1_oldgrad[i]*W1_grad[i] > 0)
-                W1_rate[i] *= 1.01;
-            else if (W1_oldgrad[i]*W1_grad[i] < 0)
-                W1_rate[i] *= .9;
-            mean_rate += W1_rate[i];
-            if (W1_rate[i] < min_rate)
-                min_rate = W1_rate[i];
-            if (W1_rate[i] < 1e-15)
-                W1_rate[i] = 1e-15;
-            W1_oldgrad[i] = W1_grad[i];
-            W1[i] += W1_grad[i]*W1_rate[i];
-        }
-        mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
-        fprintf (stderr, "%g %d", mean_rate, e);
-        if (count_retries)
-            fprintf(stderr, " %d", count_retries);
-        fprintf(stderr, "\n");
-        if (stopped)
-            break;
-    }
-    for (i=0;i<NB_THREADS;i++)
-    {
-        args[i].done = 1;
-        sem_post(&sem_begin[i]);
-        pthread_join(thread[i], NULL);
-        fprintf (stderr, "joined %d\n", i);
-    }
-    free(W0_grad);
-    free(W0_oldgrad);
-    free(W1_grad);
-    free(W1_oldgrad);
-    free(W0_rate);
-    free(best_W0_rate);
-    free(W1_rate);
-    free(best_W1_rate);
-    return best_rms;
-}
-
-int main(int argc, char **argv)
-{
-    int i, j;
-    int nbInputs;
-    int nbOutputs;
-    int nbHidden;
-    int nbSamples;
-    int nbEpoch;
-    int nbRealInputs;
-    unsigned int seed;
-    int ret;
-    float rms;
-    float *inputs;
-    float *outputs;
-    if (argc!=6)
-    {
-        fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n");
-        return 1;
-    }
-    nbInputs = atoi(argv[1]);
-    nbHidden = atoi(argv[2]);
-    nbOutputs = atoi(argv[3]);
-    nbSamples = atoi(argv[4]);
-    nbEpoch = atoi(argv[5]);
-    nbRealInputs = nbInputs;
-    inputs = malloc(nbInputs*nbSamples*sizeof(*inputs));
-    outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs));
-
-    seed = time(NULL);
-    /*seed = 1452209040;*/
-    fprintf (stderr, "Seed is %u\n", seed);
-    srand(seed);
-    build_tansig_table();
-    signal(SIGTERM, handler);
-    signal(SIGINT, handler);
-    signal(SIGHUP, handler);
-    for (i=0;i<nbSamples;i++)
-    {
-        for (j=0;j<nbRealInputs;j++)
-            ret = scanf(" %f", &inputs[i*nbInputs+j]);
-        for (j=0;j<nbOutputs;j++)
-            ret = scanf(" %f", &outputs[i*nbOutputs+j]);
-        if (feof(stdin))
-        {
-            nbSamples = i;
-            break;
-        }
-    }
-    int topo[3] = {nbInputs, nbHidden, nbOutputs};
-    MLPTrain *net;
-
-    fprintf (stderr, "Got %d samples\n", nbSamples);
-    net = mlp_init(topo, 3, inputs, outputs, nbSamples);
-    rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1);
-    printf ("#ifdef HAVE_CONFIG_H\n");
-    printf ("#include \"config.h\"\n");
-    printf ("#endif\n\n");
-    printf ("#include \"mlp.h\"\n\n");
-    printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed);
-    printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]);
-    printf ("\n/* hidden layer */\n");
-    for (i=0;i<(topo[0]+1)*topo[1];i++)
-    {
-        printf ("%gf,", net->weights[0][i]);
-        if (i%5==4)
-            printf("\n");
-        else
-            printf(" ");
-    }
-    printf ("\n/* output layer */\n");
-    for (i=0;i<(topo[1]+1)*topo[2];i++)
-    {
-        printf ("%gf,", net->weights[1][i]);
-        if (i%5==4)
-            printf("\n");
-        else
-            printf(" ");
-    }
-    printf ("};\n\n");
-    printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]);
-    printf ("const MLP net = {\n");
-    printf ("    3,\n");
-    printf ("    topo,\n");
-    printf ("    weights\n};\n");
-    return 0;
-}
diff --git a/src/mlp_train.h b/src/mlp_train.h
deleted file mode 100644 (file)
index 4940415..0000000
+++ /dev/null
@@ -1,86 +0,0 @@
-/* Copyright (c) 2008-2011 Octasic Inc.
-   Written by Jean-Marc Valin */
-/*
-   Redistribution and use in source and binary forms, with or without
-   modification, are permitted provided that the following conditions
-   are met:
-
-   - Redistributions of source code must retain the above copyright
-   notice, this list of conditions and the following disclaimer.
-
-   - Redistributions in binary form must reproduce the above copyright
-   notice, this list of conditions and the following disclaimer in the
-   documentation and/or other materials provided with the distribution.
-
-   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
-   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
-   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
-   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-*/
-
-#ifndef _MLP_TRAIN_H_
-#define _MLP_TRAIN_H_
-
-#include <math.h>
-#include <stdlib.h>
-
-double tansig_table[501];
-static inline double tansig_double(double x)
-{
-    return 2./(1.+exp(-2.*x)) - 1.;
-}
-static inline void build_tansig_table(void)
-{
-    int i;
-    for (i=0;i<501;i++)
-        tansig_table[i] = tansig_double(.04*(i-250));
-}
-
-static inline double tansig_approx(double x)
-{
-    int i;
-    double y, dy;
-    if (x>=10)
-        return 1;
-    if (x<=-10)
-        return -1;
-    i = lrint(25*x);
-    x -= .04*i;
-    y = tansig_table[250+i];
-    dy = 1-y*y;
-    y = y + x*dy*(1 - y*x);
-    return y;
-}
-
-static inline float randn(float sd)
-{
-   float U1, U2, S, x;
-   do {
-      U1 = ((float)rand())/RAND_MAX;
-      U2 = ((float)rand())/RAND_MAX;
-      U1 = 2*U1-1;
-      U2 = 2*U2-1;
-      S = U1*U1 + U2*U2;
-   } while (S >= 1 || S == 0.0f);
-   x = sd*sqrt(-2 * log(S) / S) * U1;
-   return x;
-}
-
-
-typedef struct {
-    int layers;
-    int *topo;
-    double **weights;
-    double **best_weights;
-    double *in_rate;
-} MLPTrain;
-
-
-#endif /* _MLP_TRAIN_H_ */
index 3770fc6..0494170 100644 (file)
@@ -1189,7 +1189,16 @@ opus_int32 opus_encode_native(OpusEncoder *st, const opus_val16 *pcm, int frame_
     {
        int analysis_bandwidth;
        if (st->signal_type == OPUS_AUTO)
-          st->voice_ratio = (int)floor(.5+100*(1-analysis_info.music_prob));
+       {
+          float prob;
+          if (st->prev_mode == 0)
+             prob = analysis_info.music_prob;
+          else if (st->prev_mode == MODE_CELT_ONLY)
+             prob = analysis_info.music_prob_max;
+          else
+             prob = analysis_info.music_prob_min;
+          st->voice_ratio = (int)floor(.5+100*(1-prob));
+       }
 
        analysis_bandwidth = analysis_info.bandwidth;
        if (analysis_bandwidth<=12)