Adds low-energy feature, training with noise
authorJean-Marc Valin <jmvalin@jmvalin.ca>
Wed, 23 Nov 2011 18:52:44 +0000 (13:52 -0500)
committerJean-Marc Valin <jmvalin@jmvalin.ca>
Fri, 13 Jul 2012 18:50:35 +0000 (14:50 -0400)
src/analysis.c
src/mlp_data.c
src/mlp_train.c

index e55d68a..8f2971e 100644 (file)
@@ -84,7 +84,11 @@ typedef struct {
    float cmean[8];
    float std[9];
    float music_prob;
+   float Etracker;
+   float lowECount;
    int E_count;
+   int last_music;
+   int last_transition;
    int count;
 } TonalityAnalysisState;
 
@@ -103,7 +107,7 @@ void tonality_analysis(TonalityAnalysisState *tonal, AnalysisInfo *info, CELTEnc
     float band_tonality[NB_TBANDS];
     float logE[NB_TBANDS];
     float BFCC[8];
-    float features[27];
+    float features[100];
     float frame_tonality;
     float frame_noisiness;
     const float pi4 = M_PI*M_PI*M_PI*M_PI;
@@ -111,10 +115,13 @@ void tonality_analysis(TonalityAnalysisState *tonal, AnalysisInfo *info, CELTEnc
     float frame_stationarity;
     float relativeE;
     float frame_prob;
-    float alpha;
+    float alpha, alphaE;
+    float frame_loudness;
     celt_encoder_ctl(celt_enc, CELT_GET_MODE(&mode));
 
+    tonal->last_transition++;
     alpha = 1.f/IMIN(20, 1+tonal->count);
+    alphaE = 1.f/IMIN(50, 1+tonal->count);
 
     if (tonal->count<4)
        tonal->music_prob = .5;
@@ -193,6 +200,7 @@ void tonality_analysis(TonalityAnalysisState *tonal, AnalysisInfo *info, CELTEnc
     relativeE = 0;
     info->boost_amount[0]=info->boost_amount[1]=0;
     info->boost_band[0]=info->boost_band[1]=0;
+    frame_loudness = 0;
     for (b=0;b<NB_TBANDS;b++)
     {
        float E=0, tE=0, nE=0;
@@ -209,7 +217,9 @@ void tonality_analysis(TonalityAnalysisState *tonal, AnalysisInfo *info, CELTEnc
        tonal->E[tonal->E_count][b] = E;
        frame_noisiness += nE/(1e-15+E);
 
-       logE[b] = log(E+EPSILON);
+       frame_loudness += sqrt(E+1e-10);
+       /* Add a reasonable noise floor */
+       logE[b] = log(E+1e-10);
        tonal->lowE[b] = MIN32(logE[b], tonal->lowE[b]+.01);
        tonal->highE[b] = MAX32(logE[b], tonal->highE[b]-.1);
        if (tonal->highE[b] < tonal->lowE[b]+1)
@@ -250,6 +260,11 @@ void tonality_analysis(TonalityAnalysisState *tonal, AnalysisInfo *info, CELTEnc
        }
        tonal->prev_band_tonality[b] = band_tonality[b];
     }
+    frame_loudness = 20*log10(frame_loudness);
+    tonal->Etracker = MAX32(tonal->Etracker-.03, frame_loudness);
+    tonal->lowECount *= (1-alphaE);
+    if (frame_loudness < tonal->Etracker-30)
+       tonal->lowECount += alphaE;
 
     for (i=0;i<8;i++)
     {
@@ -288,21 +303,21 @@ void tonality_analysis(TonalityAnalysisState *tonal, AnalysisInfo *info, CELTEnc
     tonal->count++;
     info->tonality = frame_tonality;
 
-    for (i=0;i<5;i++)
+    for (i=0;i<4;i++)
        features[i] = -0.12299*(BFCC[i]+tonal->mem[i+24]) + 0.49195*(tonal->mem[i]+tonal->mem[i+16]) + 0.69693*tonal->mem[i+8] - 1.4349*tonal->cmean[i];
 
-    for (i=0;i<5;i++)
-       tonal->cmean[i] = (1-alpha)*tonal->cmean[i] + alpha*(i==0)*BFCC[i];
+    for (i=0;i<4;i++)
+       tonal->cmean[i] = (1-alpha)*tonal->cmean[i] + alpha*BFCC[i];
 
-    for (i=0;i<5;i++)
-        features[5+i] = 0.63246*(BFCC[i]-tonal->mem[i+24]) + 0.31623*(tonal->mem[i]-tonal->mem[i+16]);
     for (i=0;i<4;i++)
-        features[10+i] = 0.53452*(BFCC[i]+tonal->mem[i+24]) - 0.26726*(tonal->mem[i]+tonal->mem[i+16]) -0.53452*tonal->mem[i+8];
+        features[4+i] = 0.63246*(BFCC[i]-tonal->mem[i+24]) + 0.31623*(tonal->mem[i]-tonal->mem[i+16]);
+    for (i=0;i<3;i++)
+        features[8+i] = 0.53452*(BFCC[i]+tonal->mem[i+24]) - 0.26726*(tonal->mem[i]+tonal->mem[i+16]) -0.53452*tonal->mem[i+8];
 
     if (tonal->count > 5)
     {
        for (i=0;i<9;i++)
-          tonal->std[i] = (1-alpha)*tonal->std[i] + alpha*features[5+i]*features[5+i];
+          tonal->std[i] = (1-alpha)*tonal->std[i] + alpha*features[i]*features[i];
     }
 
     for (i=0;i<8;i++)
@@ -312,36 +327,45 @@ void tonality_analysis(TonalityAnalysisState *tonal, AnalysisInfo *info, CELTEnc
        tonal->mem[i+8] = tonal->mem[i];
        tonal->mem[i] = BFCC[i];
     }
-    features[14] = info->tonality;
-    features[15] = info->activity;
-    features[16] = frame_stationarity;
-    features[17] = info->tonality_slope;
-
     for (i=0;i<9;i++)
-       features[18+i] = sqrt(tonal->std[i]);
+       features[11+i] = sqrt(tonal->std[i]);
+    features[20] = info->tonality;
+    features[21] = info->activity;
+    features[22] = frame_stationarity;
+    features[23] = info->tonality_slope;
+    features[24] = tonal->lowECount;
+
 #ifndef FIXED_POINT
     mlp_process(&net, features, &frame_prob);
+    /* Adds a "probability dead zone", with a cap on certainty */
+    frame_prob = .90*frame_prob*frame_prob*frame_prob;
+
     frame_prob = .5*(frame_prob+1);
-    frame_prob = MAX16(.01f, MIN16(0.99f, frame_prob));
-    /*frame_prob = .45*frame_prob + .55*frame_prob*frame_prob*frame_prob;*/
+
     /*printf("%f\n", frame_prob);*/
     {
        float tau, beta;
        float p0, p1;
-       tau = .0001;
+       float max_certainty;
+       /* One transition every 3 minutes */
+       tau = .00005;
        beta = .1;
+       max_certainty = 1.f/(10+1*tonal->last_transition);
        p0 = (1-tonal->music_prob)*(1-tau) +    tonal->music_prob *tau;
        p1 =    tonal->music_prob *(1-tau) + (1-tonal->music_prob)*tau;
        p0 *= pow(1-frame_prob, beta);
        p1 *= pow(frame_prob, beta);
-       tonal->music_prob = MAX16(0.01f, MIN16(0.99f, p1/(p0+p1)));
+       tonal->music_prob = MAX16(max_certainty, MIN16(1-max_certainty, p1/(p0+p1)));
        info->music_prob = tonal->music_prob;
        /*printf("%f %f\n", frame_prob, info->music_prob);*/
     }
+    if (tonal->last_music != (tonal->music_prob>.5))
+       tonal->last_transition=0;
+    tonal->last_music = tonal->music_prob>.5;
 #else
     info->music_prob = 0;
 #endif
-    /*for (i=0;i<27;i++)
+    /*for (i=0;i<25;i++)
        printf("%f ", features[i]);
     printf("\n");*/
 
index 2a18349..5f7c5fb 100644 (file)
 #include "mlp.h"
 
-/* RMS error was 0.069845, seed was 1322025605 */
+/* RMS error was 0.213119, seed was 1322073261 */
 
-static const float weights[581] = {
+static const float weights[271] = {
 
 /* hidden layer */
-1.52937, 0.240116, -0.0808422, -0.019036, -0.199579, 
--0.0777284, 0.0506183, -0.0155793, -0.181522, 0.03296, 
--0.0133165, 0.179916, -0.122064, -0.0236821, -0.192921, 
--7.41934, -0.1775, 3.17407, 6.74356, -0.976582, 
--0.206271, -2.36372, -0.187823, 0.0721386, 0.182445, 
-1.05044, 0.0536177, -0.079352, -7.53152, -0.000840837, 
-0.16012, -0.737235, 0.907956, -0.977085, -0.20257, 
-0.240006, -0.125651, 0.0292286, 0.0881101, -0.00420089, 
-0.144169, -0.411473, 0.518353, 24.7599, 7.13795, 
--0.198564, -8.46416, 1.62151, 1.97394, 1.06731, 
-0.212259, -0.294498, -1.43245, -1.7405, -2.24067, 
-0.108053, -7.95142, -0.124435, -0.0220739, -0.0720463, 
-0.0344748, 0.0543195, -0.0278653, -0.00880633, -0.0800235, 
--0.0628277, 0.018996, -0.136465, -0.0202844, -0.0871884, 
--0.00116258, -1.47936, 0.305648, 2.91397, -0.628477, 
-0.379082, 1.01097, 0.108782, 0.719828, 1.80274, 
--0.180042, -1.30241, -0.645234, 0.208385, 0.838486, 
-0.214029, -0.00758414, 0.418987, 0.325509, -0.350113, 
--0.229126, 0.0648517, -0.061675, -0.146279, -0.0269004, 
-0.117271, -0.105326, 0.345759, 0.143439, -3.98419, 
-0.755833, -1.65706, -0.890625, 1.08994, 0.87214, 
-1.44122, -0.243985, -0.952904, -0.838194, -1.98792, 
--0.638013, 0.597385, 5.6797, 0.0436189, -0.013998, 
-0.0319089, -0.0968088, -0.0172178, 0.0481363, -0.0418244, 
--0.0131255, -0.0462831, 0.0483697, 0.0402757, -0.127525, 
-0.0414987, -0.187244, 3.89027, -4.62985, 0.516058, 
--1.22865, -0.190653, -0.165429, 0.100515, -1.03228, 
--0.321134, 0.0992221, 0.606127, 0.0554443, 1.45482, 
--0.253561, 0.096237, -0.0246694, 0.371363, -0.0914336, 
-0.103553, 0.0189161, -0.0105249, 0.0604053, 0.0411006, 
-0.0639854, 0.109094, -0.00979879, 0.324731, -0.0567231, 
--2.01856, 1.16363, 1.30608, -0.289826, 0.159106, 
--0.0876134, 0.327582, 0.0923214, -0.247681, 0.253649, 
-0.138327, -0.872563, -2.63221, 0.86812, 0.0144657, 
-0.0157262, 0.0286942, 0.0295632, 0.0478701, -0.00398791, 
-0.0521685, 0.0127316, -0.0668749, -0.0383492, -0.00951385, 
--0.00556075, 0.0322671, 0.00116312, 10.1702, 0.908796, 
-0.344389, -9.78791, -0.0103152, -0.814278, 0.224774, 
--0.84427, -1.20048, 0.111694, 1.02544, 0.475563, 
-0.18777, -1.12989, 0.0435492, -0.121442, 0.275871, 
-0.16293, -0.544925, 0.508003, 1.13018, 0.472551, 
--0.270614, 0.65337, 0.57463, 0.597287, 0.351477, 
-0.122146, -13.7731, -1.05173, -0.607099, -1.8646, 
--0.185848, -0.17991, -0.247395, 0.270694, 0.674929, 
-0.138423, 0.542756, -0.399379, -2.26304, 0.474791, 
-0.107511, 0.280356, -0.76117, -0.711896, 2.34404, 
--0.246417, -0.0171513, -0.743835, 0.221894, -0.0688801, 
-0.769435, -0.172876, 0.0275787, -1.09694, 43.5206, 
-4.98303, 9.28984, -9.01778, 0.412928, 0.797472, 
-0.504419, 0.143125, -2.30534, 0.187757, 0.0600608, 
-0.0396294, 2.83625, -2.27696, 0.0577414, 0.0259672, 
--0.20545, 0.0843962, -0.241982, -0.181173, -0.0303534, 
-0.368835, -0.325961, 0.300561, -0.0341177, 0.0938862, 
--0.123384, -0.0835186, -6.73898, 1.21737, 2.19072, 
-1.04479, 0.0516565, -0.809079, 0.149924, -0.397419, 
--2.47109, -0.320068, -0.0366975, -0.18933, 0.952215, 
-2.08882, 0.0150159, 0.00712614, 0.139391, -0.0632642, 
--0.00458523, 0.0274453, 0.00354731, -0.014494, -0.000608929, 
-0.0232959, 0.00615738, 0.0222414, 0.109995, -0.0635244, 
-1.51811, -0.00460887, 0.899197, 1.56449, -0.00806591, 
-0.310737, -0.441759, -0.615191, 0.305784, -0.118165, 
--0.690831, -0.932832, -0.468231, 0.869044, 0.0748202, 
--0.0780407, 0.0780088, 0.024609, -0.00519675, -0.0687518, 
-0.044041, -0.0570666, 0.037407, -0.14082, 0.0348575, 
--0.14069, 0.180557, -0.0571276, -32.4574, -0.0710406, 
--5.39569, 2.65794, 0.181025, 0.493114, 0.111346, 
-0.506378, 1.16452, -0.563642, -0.278853, -0.447802, 
-0.232193, -6.42728, -0.109856, 0.022866, -0.0839836, 
--0.0839169, -0.319109, -0.902373, -0.157901, 0.298015, 
--0.16787, 0.0928949, -0.71114, 0.0373198, -0.0722619, 
--0.122185, -17.1527, 4.88383, 2.38947, 8.88169, 
--0.00947956, 0.0823654, 0.799126, 1.28023, 0.526565, 
--0.0149172, -1.12657, -0.24462, -3.23915, 0.0058726, 
-0.127453, 0.29968, -0.208872, 0.0242737, 0.479791, 
--3.21354, 1.52516, -0.692431, -0.165378, -0.731346, 
-0.314575, -0.569414, -0.0801118, 0.086923, 8.02887, 
--0.235296, -0.276748, -7.72231, -0.115556, -0.976338, 
-0.0980647, -2.55159, -0.410249, 0.968147, -0.131815, 
--0.511169, 0.0891097, -1.20927, -0.013714, 0.210441, 
-0.0838065, 0.161028, -0.01217, -0.00352592, 0.0893854, 
--0.0787796, 0.0651729, -0.0219344, -0.000346421, 0.180829, 
-0.0847809, 0.116443, 0.0632044, -1.31752, 1.0611, 
--1.73675, -0.0728349, 0.183918, 0.0451227, -0.525749, 
--0.255804, 0.10796, -0.300506, 0.595564, 0.697484, 
--1.05565, 0.0174366, 0.273086, 0.00330815, 0.133234, 
--0.0170445, 0.385873, -0.392262, 0.441475, -0.244255, 
-0.159125, 0.102305, 0.271859, -0.0477384, 0.0994522, 
-12.1298, -1.23951, 1.33346, -0.669615, 0.0928112, 
-0.301609, -0.0872949, -0.530852, -0.791418, -0.211807, 
-0.233519, 0.00779643, -0.208301, -11.4151, -0.0142366, 
-0.0636537, 0.0716135, -0.0404406, -0.116361, -0.172063, 
--0.229208, -0.0866464, -0.134274, 0.0473229, -0.190149, 
--0.112906, -0.0548081, -0.15995, -0.761112, 6.38996, 
-0.687383, -5.54174, 0.27056, 0.156832, 0.410845, 
-1.73844, 0.868191, -0.415126, -0.362902, -0.488269, 
--0.679177, 2.90133, 0.0325332, -0.13787, -0.099454, 
--0.125178, 0.0312495, 0.0215725, -0.0180029, 0.000786626, 
--0.0165868, 0.00228741, 0.0489981, -0.142374, -0.070654, 
--0.0980396, -5.4804, -1.42786, 0.370574, 5.29591, 
--0.103996, -0.291983, -0.220591, -0.445399, 0.153502, 
-0.0372166, 0.181633, 0.0616784, 0.69087, -0.246005, 
-0.179546, -0.0553659, 0.0486791, -0.429761, 0.379239, 
--0.293269, 0.370228, -0.232783, 0.192865, -0.153421, 
-0.412135, -0.181689, 0.0816143, -0.245055, 0.851942, 
--0.771025, -1.52187, -0.656314, 0.638177, 0.882082, 
--0.466803, -0.193064, 0.233788, 1.40801, 0.388835, 
--0.206663, 1.3398, -10.2146, -0.086454, -0.33979, 
--0.11021, -0.044777, 0.0383833, -0.171588, -0.0686855, 
-0.0452209, 0.0578143, 0.0676435, -0.24006, -0.40246, 
--0.0549284, 0.00786321, -11.5892, 13.4008, -0.148449, 
-1.4845, -0.198285, -0.280658, -2.56881, -0.597918, 
-0.0713039, -0.0129557, 0.907337, -1.36003, 0.357266, 
+0.693025, 0.12016, -0.2263, 0.254033, -0.128153, 
+0.14498, -0.139098, 0.160911, -0.101749, -0.0495703, 
+-0.064263, 0.0583359, 0.00431816, 0.137356, 0.199892, 
+0.0859346, -0.159615, -0.109472, -0.483732, -0.0564564, 
+-0.1935, 1.94157, 0.975237, -1.99915, -0.566324, 
+0.457616, 1.05455, -0.0896938, 0.119609, -0.135437, 
+0.0485118, -0.168008, 0.099568, -0.164749, 0.0837494, 
+0.043713, -0.00298809, 0.0312529, 0.0128313, 0.0217679, 
+-0.0185004, 0.0452, -0.0287335, -0.0872755, 0.145857, 
+-0.534325, 0.255553, 1.68378, -0.212433, -0.416954, 
+0.558858, 0.295516, 2.43878, 0.105586, 0.0929094, 
+0.0461483, 0.0715373, 0.054567, 0.0499543, -0.0589255, 
+0.0518142, 0.143022, 0.0927234, 0.0891022, -0.0743216, 
+-0.0497012, -0.150656, 0.0945407, -0.438528, -0.272849, 
+-0.613949, -0.0907026, -0.0846951, -3.10146, -1.64289, 
+-0.538651, -1.76321, -2.90649, -0.738564, 0.103786, 
+0.0706484, 0.194004, 0.0256865, -0.00755591, -0.0148712, 
+0.0249828, -0.0288777, 0.0913917, 0.0510661, 0.117025, 
+-0.00722344, -0.175031, 0.283402, -0.312663, 0.19557, 
+0.0678725, -0.0488421, 0.296382, 0.123097, 6.72257, 
+0.296421, 0.219984, -12.444, -1.80343, 4.56011, 
+0.0481676, 0.027474, 0.0752634, 0.0243683, -0.0166126, 
+0.00484121, 0.00782469, -0.0210589, -0.000731938, -0.0244603, 
+-0.0222773, -0.0433534, -0.250382, -0.118675, -1.05746, 
+-0.217, -0.0425872, -0.231335, -0.500574, 0.477752, 
+0.247352, 0.611455, 0.956451, 4.80434, 0.670325, 
+-0.394506, 0.0344075, 0.148722, 0.136525, 0.110126, 
+-0.0668243, 0.012418, -0.0307898, 0.0666452, -0.0171564, 
+0.136063, 0.0268092, -0.0725383, 0.0749385, 0.0845712, 
+0.48549, 0.0295497, 0.276011, 0.479614, 0.640422, 
+-0.245035, -0.751258, -1.0434, -3.62517, -2.84518, 
+3.95485, -1.64489, 0.209162, -0.132563, 0.252206, 
+-0.135669, 0.0428919, -0.0799343, 0.0932631, -0.0781831, 
+0.189111, -0.0977977, 0.110201, 0.0235682, 0.0302204, 
+-0.240156, -0.289576, 0.136261, 0.363298, 0.484825, 
+0.389438, 0.193011, -0.566411, -0.000798943, 1.53221, 
+-2.99065, -0.226329, 2.19154, -0.00811221, -0.0912965, 
+-0.0275327, -0.0861904, 0.00839052, 0.0904659, -0.00893919, 
+0.0360984, -0.015071, -0.0647292, -0.0734295, -0.0261419, 
+-0.402868, 0.176918, -0.212574, -0.00624752, -0.0633636, 
+0.406343, -0.274925, 0.432039, 4.29364, 0.231697, 
+-1.06323, 2.08855, -30.2379, 0.0637952, 0.120849, 
+-0.0639096, 0.13732, -0.180987, 0.191846, 0.368452, 
+0.209571, 0.164176, -0.0378771, -0.11027, 0.0156665, 
+0.0207418, 0.0259591, -0.0239737, -0.00352567, -0.0545034, 
+0.0656311, -0.0817674, 0.521953, 0.241775, 7.41463, 
+0.765259, 0.863685, -2.13211, 0.310299, 2.10354, 
+0.0360669, -0.171618, -0.148112, -0.279292, 0.709935, 
+0.845869, 0.4868, 0.224496, 0.343149, 0.202645, 
+-0.091977, 0.0201541, 0.077029, 0.0259409, 0.0167369, 
+0.0780503, 0.118852, -0.18355, -0.158253, 0.145728, 
+5.04716, -0.126459, 1.02251, 1.57638, 0.0646518, 
 
 /* output layer */
--9.51428, -0.855928, 0.674433, -1.45903, -1.15718, 
-3.25902, -0.85739, 1.45401, -0.346373, 0.563214, 
--0.97603, 1.52396, -0.804053, -0.56299, 0.213345, 
--2.50068, -1.06777, -1.073, -2.88991, -1.10272, 
--2.77165, };
+-3.6362, 1.96839, 1.93382, -2.10686, 2.27415, 
+0.599088, -0.615374, -2.02317, 0.383689, 0.360266, 
+0.532961, };
 
-static const int topo[3] = {27, 20, 1};
+static const int topo[3] = {25, 10, 1};
 
 const MLP net = {
        3,
index a9d548b..f0fb40d 100644 (file)
@@ -154,6 +154,7 @@ double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamp
                        error[i] = out[i] - netOut[i];
                        rms += error[i]*error[i];
                        *error_rate += fabs(error[i])>1;
+                        //error[i] = error[i]/(1+fabs(error[i]));
                }
                /* Back-propagate error */
                for (i=0;i<outDim;i++)