/* mlp_train.c -- MLP training tool for speech/music discrimination
   (not used for anything yet) */
/* Copyright (c) 2008-2011 Octasic Inc.
   Written by Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#include "mlp_train.h"
#include <math.h>      /* sqrt(), fabs() */
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <semaphore.h>
#include <pthread.h>
#include <time.h>
#include <signal.h>

/* Set by the signal handler to request a clean stop at the end of the
   current epoch. */
volatile sig_atomic_t stopped = 0;

void handler(int sig)
{
        stopped = 1;
        /* Re-install the handler in case the system resets it to the
           default (old SysV semantics). */
        signal(sig, handler);
}

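/* Allocates the network and initializes it from the training data: input
   weights are Gaussian with a variance scaled by the per-feature input
   statistics, hidden biases are set so each unit starts centered on its
   mean input, and output biases start at the target means. */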
MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples)
{
        int i, j, k;
        MLPTrain *net;
        int inDim, outDim;
        net = malloc(sizeof(*net));
        net->topo = malloc(nbLayers*sizeof(net->topo[0]));
        for (i=0;i<nbLayers;i++)
                net->topo[i] = topo[i];
        inDim = topo[0];
        outDim = topo[nbLayers-1];
        net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0]));
        net->weights = malloc((nbLayers-1)*sizeof(net->weights[0]));
        net->best_weights = malloc((nbLayers-1)*sizeof(net->best_weights[0]));
        for (i=0;i<nbLayers-1;i++)
        {
                net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
                net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->best_weights[0][0]));
        }
        double inMean[inDim];
        for (j=0;j<inDim;j++)
        {
                double std=0;
                inMean[j] = 0;
                for (i=0;i<nbSamples;i++)
                {
                        inMean[j] += inputs[i*inDim+j];
                        std += inputs[i*inDim+j]*inputs[i*inDim+j];
                }
                inMean[j] /= nbSamples;
                std /= nbSamples;
                /* At this point std holds the second moment E[x^2]; the
                   per-feature learning-rate scale uses its inverse. */
                net->in_rate[1+j] = .5/(.0001+std);
                std = std-inMean[j]*inMean[j];
                if (std<.001)
                        std = .001;
                std = 1/sqrt(inDim*std);
                for (k=0;k<topo[1];k++)
                        net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
        }
        net->in_rate[0] = 1;
        /* Set the hidden-layer biases so each hidden unit starts centered
           on the mean of its inputs. */
        for (j=0;j<topo[1];j++)
        {
                double sum = 0;
                for (k=0;k<inDim;k++)
                        sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1];
                net->weights[0][j*(topo[0]+1)] = -sum;
        }
        for (j=0;j<outDim;j++)
        {
                double mean = 0;
                double std;
                for (i=0;i<nbSamples;i++)
                        mean += outputs[i*outDim+j];
                mean /= nbSamples;
                std = 1/sqrt(topo[nbLayers-2]);
                /* Output bias starts at the target mean. */
                net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean;
                for (k=0;k<topo[nbLayers-2];k++)
                        net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std);
        }
        return net;
}

#define MAX_NEURONS 100

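/* Runs a forward and backward pass over nbSamples samples and accumulates
   the weight gradients for a 3-layer (input/hidden/output) network with
   tansig activations.  Returns the sum of squared errors; *error_rate gets
   the count of outputs that are off by more than 1.  Layer widths must not
   exceed MAX_NEURONS (the buffers below are fixed-size). */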
double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate)
{
        int i,j;
        int s;
        int inDim, outDim, hiddenDim;
        int *topo;
        double *W0, *W1;
        double rms=0;
        int W0_size, W1_size;
        double hidden[MAX_NEURONS];
        double netOut[MAX_NEURONS];
        double error[MAX_NEURONS];

        *error_rate = 0;
        topo = net->topo;
        inDim = net->topo[0];
        hiddenDim = net->topo[1];
        outDim = net->topo[2];
        W0_size = (topo[0]+1)*topo[1];
        W1_size = (topo[1]+1)*topo[2];
        W0 = net->weights[0];
        W1 = net->weights[1];
        memset(W0_grad, 0, W0_size*sizeof(double));
        memset(W1_grad, 0, W1_size*sizeof(double));
        for (i=0;i<outDim;i++)
                netOut[i] = outputs[i];
        for (s=0;s<nbSamples;s++)
        {
                float *in, *out;
                in = inputs+s*inDim;
                out = outputs + s*outDim;
                /* Forward pass */
                for (i=0;i<hiddenDim;i++)
                {
                        double sum = W0[i*(inDim+1)];
                        for (j=0;j<inDim;j++)
                                sum += W0[i*(inDim+1)+j+1]*in[j];
                        hidden[i] = tansig_approx(sum);
                }
                for (i=0;i<outDim;i++)
                {
                        double sum = W1[i*(hiddenDim+1)];
                        for (j=0;j<hiddenDim;j++)
                                sum += W1[i*(hiddenDim+1)+j+1]*hidden[j];
                        netOut[i] = tansig_approx(sum);
                        error[i] = out[i] - netOut[i];
                        rms += error[i]*error[i];
                        /* Count an error when the output is off by more
                           than 1 (i.e. the wrong side for +/-1 targets). */
                        *error_rate += fabs(error[i])>1;
                }
                /* Back-propagate error */
                for (i=0;i<outDim;i++)
                {
                        double grad = 1-netOut[i]*netOut[i];
                        W1_grad[i*(hiddenDim+1)] += error[i]*grad;
                        for (j=0;j<hiddenDim;j++)
                                W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j];
                }
                for (i=0;i<hiddenDim;i++)
                {
                        double grad;
                        grad = 0;
                        for (j=0;j<outDim;j++)
                                grad += error[j]*W1[j*(hiddenDim+1)+i+1];
                        grad *= 1-hidden[i]*hidden[i];
                        W0_grad[i*(inDim+1)] += grad;
                        for (j=0;j<inDim;j++)
                                W0_grad[i*(inDim+1)+j+1] += grad*in[j];
                }
        }
        return rms;
}

#define NB_THREADS 8

sem_t sem_begin[NB_THREADS];
sem_t sem_end[NB_THREADS];

struct GradientArg {
        int id;
        int done;
        MLPTrain *net;
        float *inputs;
        float *outputs;
        int nbSamples;
        double *W0_grad;
        double *W1_grad;
        double rms;
        double error_rate;
};

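/* Worker thread: waits on its "begin" semaphore, computes the gradient for
   its slice of the training data, then signals "end".  W0_grad/W1_grad
   point into the thread's own stack; that is safe here because the thread
   only returns once the trainer sets arg->done at shutdown. */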
void *gradient_thread_process(void *_arg)
{
        int W0_size, W1_size;
        struct GradientArg *arg = _arg;
        int *topo = arg->net->topo;
        W0_size = (topo[0]+1)*topo[1];
        W1_size = (topo[1]+1)*topo[2];
        double W0_grad[W0_size];
        double W1_grad[W1_size];
        arg->W0_grad = W0_grad;
        arg->W1_grad = W1_grad;
        while (1)
        {
                sem_wait(&sem_begin[arg->id]);
                if (arg->done)
                        break;
                arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, &arg->error_rate);
                sem_post(&sem_end[arg->id]);
        }
        fprintf(stderr, "done\n");
        return NULL;
}

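/* Full-batch training loop.  The samples are split across NB_THREADS
   worker threads; each epoch sums their gradients and applies a sign-based
   adaptive update (rprop-like) with a per-weight step size.  If the RMS
   error fails to improve for more than 30 consecutive epochs, training
   restarts from the best weights seen so far with reduced rates; after
   more than 10 such retries it stops and returns the best RMS error. */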
float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate)
{
        int i, j;
        int e;
        float best_rms = 1e10;
        int inDim, outDim, hiddenDim;
        int *topo;
        double *W0, *W1, *best_W0, *best_W1;
        double *W0_old, *W1_old;
        double *W0_old2, *W1_old2;
        double *W0_grad, *W1_grad;
        double *W0_oldgrad, *W1_oldgrad;
        double *W0_rate, *W1_rate;
        double *best_W0_rate, *best_W1_rate;
        int W0_size, W1_size;
        topo = net->topo;
        W0_size = (topo[0]+1)*topo[1];
        W1_size = (topo[1]+1)*topo[2];
        struct GradientArg args[NB_THREADS];
        pthread_t thread[NB_THREADS];
        /* Note: any remainder of nbSamples/NB_THREADS is left unused. */
        int samplePerPart = nbSamples/NB_THREADS;
        int count_worse=0;
        int count_retries=0;

        inDim = net->topo[0];
        hiddenDim = net->topo[1];
        outDim = net->topo[2];
        W0 = net->weights[0];
        W1 = net->weights[1];
        best_W0 = net->best_weights[0];
        best_W1 = net->best_weights[1];
        W0_old = malloc(W0_size*sizeof(double));
        W1_old = malloc(W1_size*sizeof(double));
        W0_old2 = malloc(W0_size*sizeof(double));
        W1_old2 = malloc(W1_size*sizeof(double));
        W0_grad = malloc(W0_size*sizeof(double));
        W1_grad = malloc(W1_size*sizeof(double));
        W0_oldgrad = malloc(W0_size*sizeof(double));
        W1_oldgrad = malloc(W1_size*sizeof(double));
        W0_rate = malloc(W0_size*sizeof(double));
        W1_rate = malloc(W1_size*sizeof(double));
        best_W0_rate = malloc(W0_size*sizeof(double));
        best_W1_rate = malloc(W1_size*sizeof(double));
        memcpy(W0_old, W0, W0_size*sizeof(double));
        memcpy(W0_old2, W0, W0_size*sizeof(double));
        memset(W0_grad, 0, W0_size*sizeof(double));
        memset(W0_oldgrad, 0, W0_size*sizeof(double));
        memcpy(W1_old, W1, W1_size*sizeof(double));
        memcpy(W1_old2, W1, W1_size*sizeof(double));
        memset(W1_grad, 0, W1_size*sizeof(double));
        memset(W1_oldgrad, 0, W1_size*sizeof(double));

        /* Initial per-weight rates; input-layer rates are scaled by the
           per-feature statistics computed in mlp_init(). */
        rate /= nbSamples;
        for (i=0;i<hiddenDim;i++)
                for (j=0;j<inDim+1;j++)
                        W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j];
        for (i=0;i<W1_size;i++)
                W1_rate[i] = rate;

        for (i=0;i<NB_THREADS;i++)
        {
                args[i].net = net;
                args[i].inputs = inputs+i*samplePerPart*inDim;
                args[i].outputs = outputs+i*samplePerPart*outDim;
                args[i].nbSamples = samplePerPart;
                args[i].id = i;
                args[i].done = 0;
                sem_init(&sem_begin[i], 0, 0);
                sem_init(&sem_end[i], 0, 0);
                pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]);
        }
        for (e=0;e<nbEpoch;e++)
        {
                double rms=0;
                double error_rate = 0;
                for (i=0;i<NB_THREADS;i++)
                {
                        sem_post(&sem_begin[i]);
                }
                memset(W0_grad, 0, W0_size*sizeof(double));
                memset(W1_grad, 0, W1_size*sizeof(double));
                for (i=0;i<NB_THREADS;i++)
                {
                        sem_wait(&sem_end[i]);
                        rms += args[i].rms;
                        error_rate += args[i].error_rate;
                        for (j=0;j<W0_size;j++)
                                W0_grad[j] += args[i].W0_grad[j];
                        for (j=0;j<W1_size;j++)
                                W1_grad[j] += args[i].W1_grad[j];
                }

                /* min_rate is computed below but currently unused. */
                float mean_rate = 0, min_rate = 1e10;
                rms = (rms/(outDim*nbSamples));
                error_rate = (error_rate/(outDim*nbSamples));
                fprintf (stderr, "%f (%f %f) ", error_rate, rms, best_rms);
                if (rms < best_rms)
                {
                        /* New best: snapshot the weights and rates. */
                        best_rms = rms;
                        for (i=0;i<W0_size;i++)
                        {
                                best_W0[i] = W0[i];
                                best_W0_rate[i] = W0_rate[i];
                        }
                        for (i=0;i<W1_size;i++)
                        {
                                best_W1[i] = W1[i];
                                best_W1_rate[i] = W1_rate[i];
                        }
                        count_worse=0;
                        count_retries=0;
                } else {
                        count_worse++;
                        if (count_worse>30)
                        {
                                /* No improvement for a while: restart from
                                   the best weights with reduced rates. */
                                count_retries++;
                                count_worse=0;
                                for (i=0;i<W0_size;i++)
                                {
                                        W0[i] = best_W0[i];
                                        best_W0_rate[i] *= .7;
                                        if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15;
                                        W0_rate[i] = best_W0_rate[i];
                                        W0_grad[i] = 0;
                                }
                                for (i=0;i<W1_size;i++)
                                {
                                        W1[i] = best_W1[i];
                                        best_W1_rate[i] *= .8;
                                        if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
                                        W1_rate[i] = best_W1_rate[i];
                                        W1_grad[i] = 0;
                                }
                        }
                }
                if (count_retries>10)
                        break;
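                /* Sign-based adaptive update (rprop-like): grow a weight's
                   rate slightly while its gradient keeps the same sign,
                   shrink it on a sign flip, then take a gradient step. */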
                for (i=0;i<W0_size;i++)
                {
                        if (W0_oldgrad[i]*W0_grad[i] > 0)
                                W0_rate[i] *= 1.01;
                        else if (W0_oldgrad[i]*W0_grad[i] < 0)
                                W0_rate[i] *= .9;
                        mean_rate += W0_rate[i];
                        if (W0_rate[i] < min_rate)
                                min_rate = W0_rate[i];
                        if (W0_rate[i] < 1e-15)
                                W0_rate[i] = 1e-15;
                        /*if (W0_rate[i] > .01)
                                W0_rate[i] = .01;*/
                        W0_oldgrad[i] = W0_grad[i];
                        W0_old2[i] = W0_old[i];
                        W0_old[i] = W0[i];
                        W0[i] += W0_grad[i]*W0_rate[i];
                }
                for (i=0;i<W1_size;i++)
                {
                        if (W1_oldgrad[i]*W1_grad[i] > 0)
                                W1_rate[i] *= 1.01;
                        else if (W1_oldgrad[i]*W1_grad[i] < 0)
                                W1_rate[i] *= .9;
                        mean_rate += W1_rate[i];
                        if (W1_rate[i] < min_rate)
                                min_rate = W1_rate[i];
                        if (W1_rate[i] < 1e-15)
                                W1_rate[i] = 1e-15;
                        W1_oldgrad[i] = W1_grad[i];
                        W1_old2[i] = W1_old[i];
                        W1_old[i] = W1[i];
                        W1[i] += W1_grad[i]*W1_rate[i];
                }
                mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
                fprintf (stderr, "%g %d", mean_rate, e);
                if (count_retries)
                        fprintf(stderr, " %d", count_retries);
                fprintf(stderr, "\n");
                if (stopped)
                        break;
        }
        /* Shut down the worker threads and release everything. */
        for (i=0;i<NB_THREADS;i++)
        {
                args[i].done = 1;
                sem_post(&sem_begin[i]);
                pthread_join(thread[i], NULL);
                fprintf (stderr, "joined %d\n", i);
                sem_destroy(&sem_begin[i]);
                sem_destroy(&sem_end[i]);
        }
        free(W0_old);
        free(W1_old);
        free(W0_old2);
        free(W1_old2);
        free(W0_grad);
        free(W1_grad);
        free(W0_oldgrad);
        free(W1_oldgrad);
        free(W0_rate);
        free(W1_rate);
        free(best_W0_rate);
        free(best_W1_rate);
        return best_rms;
}

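/* Reads whitespace-separated floats from stdin (<inputs> then <outputs>
   for each sample), trains the network, and writes the trained weights to
   stdout as a C source file for use with mlp.h. */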
int main(int argc, char **argv)
{
        int i, j;
        int nbInputs;
        int nbOutputs;
        int nbHidden;
        int nbSamples;
        int nbEpoch;
        int nbRealInputs;
        unsigned int seed;
        int ret;
        float rms;
        float *inputs;
        float *outputs;
        if (argc!=6)
        {
                fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n");
                return 1;
        }
        nbInputs = atoi(argv[1]);
        nbHidden = atoi(argv[2]);
        nbOutputs = atoi(argv[3]);
        nbSamples = atoi(argv[4]);
        nbEpoch = atoi(argv[5]);
        /* compute_gradient() uses fixed-size buffers for the hidden and
           output layers. */
        if (nbHidden > MAX_NEURONS || nbOutputs > MAX_NEURONS)
        {
                fprintf (stderr, "layer sizes must not exceed %d\n", MAX_NEURONS);
                return 1;
        }
        nbRealInputs = nbInputs;
        inputs = malloc(nbInputs*nbSamples*sizeof(*inputs));
        outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs));

        seed = time(NULL);
        fprintf (stderr, "Seed is %u\n", seed);
        srand(seed);
        build_tansig_table();
        signal(SIGTERM, handler);
        signal(SIGINT, handler);
        signal(SIGHUP, handler);
        for (i=0;i<nbSamples;i++)
        {
                ret = 1;
                for (j=0;ret==1 && j<nbRealInputs;j++)
                        ret = scanf(" %f", &inputs[i*nbInputs+j]);
                for (j=0;ret==1 && j<nbOutputs;j++)
                        ret = scanf(" %f", &outputs[i*nbOutputs+j]);
                /* Stop early on EOF or a malformed value. */
                if (ret != 1 || feof(stdin))
                {
                        nbSamples = i;
                        break;
                }
        }
        int topo[3] = {nbInputs, nbHidden, nbOutputs};
        MLPTrain *net;

        fprintf (stderr, "Got %d samples\n", nbSamples);
        net = mlp_init(topo, 3, inputs, outputs, nbSamples);
        rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1);
        printf ("#include \"mlp.h\"\n\n");
        printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed);
        printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]);
        printf ("\n/* hidden layer */\n");
        for (i=0;i<(topo[0]+1)*topo[1];i++)
        {
                printf ("%g, ", net->weights[0][i]);
                if (i%5==4)
                        printf("\n");
        }
        printf ("\n/* output layer */\n");
        for (i=0;i<(topo[1]+1)*topo[2];i++)
        {
                printf ("%g, ", net->weights[1][i]);
                if (i%5==4)
                        printf("\n");
        }
        printf ("};\n\n");
        printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]);
        printf ("const MLP net = {\n");
        printf ("\t3,\n");
        printf ("\ttopo,\n");
        printf ("\tweights\n};\n");
        free(inputs);
        free(outputs);
        return 0;
}