a41f927114e9f9630ee6e54dca0b1258bfbf14dd
[opus.git] / src / mlp_train.c
1 /* Copyright (c) 2008-2011 Octasic Inc.
2    Written by Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28
29 #include "mlp_train.h"
30 #include <stdlib.h>
31 #include <stdio.h>
32 #include <string.h>
33 #include <semaphore.h>
34 #include <pthread.h>
35 #include <time.h>
36 #include <signal.h>
37
38 int stopped = 0;
39
40 void handler(int sig)
41 {
42     stopped = 1;
43     signal(sig, handler);
44 }
45
46 MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples)
47 {
48     int i, j, k;
49     MLPTrain *net;
50     int inDim, outDim;
51     net = malloc(sizeof(*net));
52     net->topo = malloc(nbLayers*sizeof(net->topo[0]));
53     for (i=0;i<nbLayers;i++)
54         net->topo[i] = topo[i];
55     inDim = topo[0];
56     outDim = topo[nbLayers-1];
57     net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0]));
58     net->weights = malloc((nbLayers-1)*sizeof(net->weights));
59     net->best_weights = malloc((nbLayers-1)*sizeof(net->weights));
60     for (i=0;i<nbLayers-1;i++)
61     {
62         net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
63         net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
64     }
65     double inMean[inDim];
66     for (j=0;j<inDim;j++)
67     {
68         double std=0;
69         inMean[j] = 0;
70         for (i=0;i<nbSamples;i++)
71         {
72             inMean[j] += inputs[i*inDim+j];
73             std += inputs[i*inDim+j]*inputs[i*inDim+j];
74         }
75         inMean[j] /= nbSamples;
76         std /= nbSamples;
77         net->in_rate[1+j] = .5/(.0001+std);
78         std = std-inMean[j]*inMean[j];
79         if (std<.001)
80             std = .001;
81         std = 1/sqrt(inDim*std);
82         for (k=0;k<topo[1];k++)
83             net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
84     }
85     net->in_rate[0] = 1;
86     for (j=0;j<topo[1];j++)
87     {
88         double sum = 0;
89         for (k=0;k<inDim;k++)
90             sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1];
91         net->weights[0][j*(topo[0]+1)] = -sum;
92     }
93     for (j=0;j<outDim;j++)
94     {
95         double mean = 0;
96         double std;
97         for (i=0;i<nbSamples;i++)
98             mean += outputs[i*outDim+j];
99         mean /= nbSamples;
100         std = 1/sqrt(topo[nbLayers-2]);
101         net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean;
102         for (k=0;k<topo[nbLayers-2];k++)
103             net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std);
104     }
105     return net;
106 }
107
108 #define MAX_NEURONS 100
109 #define MAX_OUT 10
110
111 double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate)
112 {
113     int i,j;
114     int s;
115     int inDim, outDim, hiddenDim;
116     int *topo;
117     double *W0, *W1;
118     double rms=0;
119     int W0_size, W1_size;
120     double hidden[MAX_NEURONS];
121     double netOut[MAX_NEURONS];
122     double error[MAX_NEURONS];
123
124     topo = net->topo;
125     inDim = net->topo[0];
126     hiddenDim = net->topo[1];
127     outDim = net->topo[2];
128     W0_size = (topo[0]+1)*topo[1];
129     W1_size = (topo[1]+1)*topo[2];
130     W0 = net->weights[0];
131     W1 = net->weights[1];
132     memset(W0_grad, 0, W0_size*sizeof(double));
133     memset(W1_grad, 0, W1_size*sizeof(double));
134     for (i=0;i<outDim;i++)
135         netOut[i] = outputs[i];
136     for (i=0;i<outDim;i++)
137         error_rate[i] = 0;
138     for (s=0;s<nbSamples;s++)
139     {
140         float *in, *out;
141         float inp[inDim];
142         in = inputs+s*inDim;
143         out = outputs + s*outDim;
144         for (j=0;j<inDim;j++)
145            inp[j] = in[j];
146         for (i=0;i<hiddenDim;i++)
147         {
148             double sum = W0[i*(inDim+1)];
149             for (j=0;j<inDim;j++)
150                 sum += W0[i*(inDim+1)+j+1]*inp[j];
151             hidden[i] = tansig_approx(sum);
152         }
153         for (i=0;i<outDim;i++)
154         {
155             double sum = W1[i*(hiddenDim+1)];
156             for (j=0;j<hiddenDim;j++)
157                 sum += W1[i*(hiddenDim+1)+j+1]*hidden[j];
158             netOut[i] = tansig_approx(sum);
159             error[i] = out[i] - netOut[i];
160             if (out[i] == 0) error[i] *= .0;
161             error_rate[i] += fabs(error[i])>1;
162             if (i==0) error[i] *= 5;
163             rms += error[i]*error[i];
164             /*error[i] = error[i]/(1+fabs(error[i]));*/
165         }
166         /* Back-propagate error */
167         for (i=0;i<outDim;i++)
168         {
169             double grad = 1-netOut[i]*netOut[i];
170             W1_grad[i*(hiddenDim+1)] += error[i]*grad;
171             for (j=0;j<hiddenDim;j++)
172                 W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j];
173         }
174         for (i=0;i<hiddenDim;i++)
175         {
176             double grad;
177             grad = 0;
178             for (j=0;j<outDim;j++)
179                 grad += error[j]*W1[j*(hiddenDim+1)+i+1];
180             grad *= 1-hidden[i]*hidden[i];
181             W0_grad[i*(inDim+1)] += grad;
182             for (j=0;j<inDim;j++)
183                 W0_grad[i*(inDim+1)+j+1] += grad*inp[j];
184         }
185     }
186     return rms;
187 }
188
189 #define NB_THREADS 8
190
191 sem_t sem_begin[NB_THREADS];
192 sem_t sem_end[NB_THREADS];
193
194 struct GradientArg {
195     int id;
196     int done;
197     MLPTrain *net;
198     float *inputs;
199     float *outputs;
200     int nbSamples;
201     double *W0_grad;
202     double *W1_grad;
203     double rms;
204     double error_rate[MAX_OUT];
205 };
206
207 void *gradient_thread_process(void *_arg)
208 {
209     int W0_size, W1_size;
210     struct GradientArg *arg = _arg;
211     int *topo = arg->net->topo;
212     W0_size = (topo[0]+1)*topo[1];
213     W1_size = (topo[1]+1)*topo[2];
214     double W0_grad[W0_size];
215     double W1_grad[W1_size];
216     arg->W0_grad = W0_grad;
217     arg->W1_grad = W1_grad;
218     while (1)
219     {
220         sem_wait(&sem_begin[arg->id]);
221         if (arg->done)
222             break;
223         arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate);
224         sem_post(&sem_end[arg->id]);
225     }
226     fprintf(stderr, "done\n");
227     return NULL;
228 }
229
230 float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate)
231 {
232     int i, j;
233     int e;
234     float best_rms = 1e10;
235     int inDim, outDim, hiddenDim;
236     int *topo;
237     double *W0, *W1, *best_W0, *best_W1;
238     double *W0_grad, *W1_grad;
239     double *W0_oldgrad, *W1_oldgrad;
240     double *W0_rate, *W1_rate;
241     double *best_W0_rate, *best_W1_rate;
242     int W0_size, W1_size;
243     topo = net->topo;
244     W0_size = (topo[0]+1)*topo[1];
245     W1_size = (topo[1]+1)*topo[2];
246     struct GradientArg args[NB_THREADS];
247     pthread_t thread[NB_THREADS];
248     int samplePerPart = nbSamples/NB_THREADS;
249     int count_worse=0;
250     int count_retries=0;
251
252     topo = net->topo;
253     inDim = net->topo[0];
254     hiddenDim = net->topo[1];
255     outDim = net->topo[2];
256     W0 = net->weights[0];
257     W1 = net->weights[1];
258     best_W0 = net->best_weights[0];
259     best_W1 = net->best_weights[1];
260     W0_grad = malloc(W0_size*sizeof(double));
261     W1_grad = malloc(W1_size*sizeof(double));
262     W0_oldgrad = malloc(W0_size*sizeof(double));
263     W1_oldgrad = malloc(W1_size*sizeof(double));
264     W0_rate = malloc(W0_size*sizeof(double));
265     W1_rate = malloc(W1_size*sizeof(double));
266     best_W0_rate = malloc(W0_size*sizeof(double));
267     best_W1_rate = malloc(W1_size*sizeof(double));
268     memset(W0_grad, 0, W0_size*sizeof(double));
269     memset(W0_oldgrad, 0, W0_size*sizeof(double));
270     memset(W1_grad, 0, W1_size*sizeof(double));
271     memset(W1_oldgrad, 0, W1_size*sizeof(double));
272
273     rate /= nbSamples;
274     for (i=0;i<hiddenDim;i++)
275         for (j=0;j<inDim+1;j++)
276             W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j];
277     for (i=0;i<W1_size;i++)
278         W1_rate[i] = rate;
279
280     for (i=0;i<NB_THREADS;i++)
281     {
282         args[i].net = net;
283         args[i].inputs = inputs+i*samplePerPart*inDim;
284         args[i].outputs = outputs+i*samplePerPart*outDim;
285         args[i].nbSamples = samplePerPart;
286         args[i].id = i;
287         args[i].done = 0;
288         sem_init(&sem_begin[i], 0, 0);
289         sem_init(&sem_end[i], 0, 0);
290         pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]);
291     }
292     for (e=0;e<nbEpoch;e++)
293     {
294         double rms=0;
295         double error_rate[2] = {0,0};
296         for (i=0;i<NB_THREADS;i++)
297         {
298             sem_post(&sem_begin[i]);
299         }
300         memset(W0_grad, 0, W0_size*sizeof(double));
301         memset(W1_grad, 0, W1_size*sizeof(double));
302         for (i=0;i<NB_THREADS;i++)
303         {
304             sem_wait(&sem_end[i]);
305             rms += args[i].rms;
306             error_rate[0] += args[i].error_rate[0];
307             error_rate[1] += args[i].error_rate[1];
308             for (j=0;j<W0_size;j++)
309                 W0_grad[j] += args[i].W0_grad[j];
310             for (j=0;j<W1_size;j++)
311                 W1_grad[j] += args[i].W1_grad[j];
312         }
313
314         float mean_rate = 0, min_rate = 1e10;
315         rms = (rms/(outDim*nbSamples));
316         error_rate[0] = (error_rate[0]/(nbSamples));
317         error_rate[1] = (error_rate[1]/(nbSamples));
318         fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms);
319         if (rms < best_rms)
320         {
321             best_rms = rms;
322             for (i=0;i<W0_size;i++)
323             {
324                 best_W0[i] = W0[i];
325                 best_W0_rate[i] = W0_rate[i];
326             }
327             for (i=0;i<W1_size;i++)
328             {
329                 best_W1[i] = W1[i];
330                 best_W1_rate[i] = W1_rate[i];
331             }
332             count_worse=0;
333             count_retries=0;
334         } else {
335             count_worse++;
336             if (count_worse>30)
337             {
338                 count_retries++;
339                 count_worse=0;
340                 for (i=0;i<W0_size;i++)
341                 {
342                     W0[i] = best_W0[i];
343                     best_W0_rate[i] *= .7;
344                     if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15;
345                     W0_rate[i] = best_W0_rate[i];
346                     W0_grad[i] = 0;
347                 }
348                 for (i=0;i<W1_size;i++)
349                 {
350                     W1[i] = best_W1[i];
351                     best_W1_rate[i] *= .8;
352                     if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
353                     W1_rate[i] = best_W1_rate[i];
354                     W1_grad[i] = 0;
355                 }
356             }
357         }
358         if (count_retries>10)
359             break;
360         for (i=0;i<W0_size;i++)
361         {
362             if (W0_oldgrad[i]*W0_grad[i] > 0)
363                 W0_rate[i] *= 1.01;
364             else if (W0_oldgrad[i]*W0_grad[i] < 0)
365                 W0_rate[i] *= .9;
366             mean_rate += W0_rate[i];
367             if (W0_rate[i] < min_rate)
368                 min_rate = W0_rate[i];
369             if (W0_rate[i] < 1e-15)
370                 W0_rate[i] = 1e-15;
371             /*if (W0_rate[i] > .01)
372                 W0_rate[i] = .01;*/
373             W0_oldgrad[i] = W0_grad[i];
374             W0[i] += W0_grad[i]*W0_rate[i];
375         }
376         for (i=0;i<W1_size;i++)
377         {
378             if (W1_oldgrad[i]*W1_grad[i] > 0)
379                 W1_rate[i] *= 1.01;
380             else if (W1_oldgrad[i]*W1_grad[i] < 0)
381                 W1_rate[i] *= .9;
382             mean_rate += W1_rate[i];
383             if (W1_rate[i] < min_rate)
384                 min_rate = W1_rate[i];
385             if (W1_rate[i] < 1e-15)
386                 W1_rate[i] = 1e-15;
387             W1_oldgrad[i] = W1_grad[i];
388             W1[i] += W1_grad[i]*W1_rate[i];
389         }
390         mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
391         fprintf (stderr, "%g %d", mean_rate, e);
392         if (count_retries)
393             fprintf(stderr, " %d", count_retries);
394         fprintf(stderr, "\n");
395         if (stopped)
396             break;
397     }
398     for (i=0;i<NB_THREADS;i++)
399     {
400         args[i].done = 1;
401         sem_post(&sem_begin[i]);
402         pthread_join(thread[i], NULL);
403         fprintf (stderr, "joined %d\n", i);
404     }
405     free(W0_grad);
406     free(W0_oldgrad);
407     free(W1_grad);
408     free(W1_oldgrad);
409     free(W0_rate);
410     free(best_W0_rate);
411     free(W1_rate);
412     free(best_W1_rate);
413     return best_rms;
414 }
415
416 int main(int argc, char **argv)
417 {
418     int i, j;
419     int nbInputs;
420     int nbOutputs;
421     int nbHidden;
422     int nbSamples;
423     int nbEpoch;
424     int nbRealInputs;
425     unsigned int seed;
426     int ret;
427     float rms;
428     float *inputs;
429     float *outputs;
430     if (argc!=6)
431     {
432         fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n");
433         return 1;
434     }
435     nbInputs = atoi(argv[1]);
436     nbHidden = atoi(argv[2]);
437     nbOutputs = atoi(argv[3]);
438     nbSamples = atoi(argv[4]);
439     nbEpoch = atoi(argv[5]);
440     nbRealInputs = nbInputs;
441     inputs = malloc(nbInputs*nbSamples*sizeof(*inputs));
442     outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs));
443
444     seed = time(NULL);
445     /*seed = 1452209040;*/
446     fprintf (stderr, "Seed is %u\n", seed);
447     srand(seed);
448     build_tansig_table();
449     signal(SIGTERM, handler);
450     signal(SIGINT, handler);
451     signal(SIGHUP, handler);
452     for (i=0;i<nbSamples;i++)
453     {
454         for (j=0;j<nbRealInputs;j++)
455             ret = scanf(" %f", &inputs[i*nbInputs+j]);
456         for (j=0;j<nbOutputs;j++)
457             ret = scanf(" %f", &outputs[i*nbOutputs+j]);
458         if (feof(stdin))
459         {
460             nbSamples = i;
461             break;
462         }
463     }
464     int topo[3] = {nbInputs, nbHidden, nbOutputs};
465     MLPTrain *net;
466
467     fprintf (stderr, "Got %d samples\n", nbSamples);
468     net = mlp_init(topo, 3, inputs, outputs, nbSamples);
469     rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1);
470     printf ("#ifdef HAVE_CONFIG_H\n");
471     printf ("#include \"config.h\"\n");
472     printf ("#endif\n\n");
473     printf ("#include \"mlp.h\"\n\n");
474     printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed);
475     printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]);
476     printf ("\n/* hidden layer */\n");
477     for (i=0;i<(topo[0]+1)*topo[1];i++)
478     {
479         printf ("%gf,", net->weights[0][i]);
480         if (i%5==4)
481             printf("\n");
482         else
483             printf(" ");
484     }
485     printf ("\n/* output layer */\n");
486     for (i=0;i<(topo[1]+1)*topo[2];i++)
487     {
488         printf ("%g,", net->weights[1][i]);
489         if (i%5==4)
490             printf("\n");
491         else
492             printf(" ");
493     }
494     printf ("};\n\n");
495     printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]);
496     printf ("const MLP net = {\n");
497     printf ("\t3,\n");
498     printf ("\ttopo,\n");
499     printf ("\tweights\n};\n");
500     return 0;
501 }