Fixes floating-point bug introduced in be9e747bcc542c277d30f6c78a57b0940e0c5b5e
[opus.git] / src / mlp_train.c
1 /* Copyright (c) 2008-2011 Octasic Inc.
2    Written by Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28
29 #include "mlp_train.h"
30 #include <stdlib.h>
31 #include <stdio.h>
32 #include <string.h>
33 #include <semaphore.h>
34 #include <pthread.h>
35 #include <time.h>
36 #include <signal.h>
37
38 int stopped = 0;
39
40 void handler(int sig)
41 {
42     stopped = 1;
43     signal(sig, handler);
44 }
45
46 MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples)
47 {
48     int i, j, k;
49     MLPTrain *net;
50     int inDim, outDim;
51     net = malloc(sizeof(*net));
52     net->topo = malloc(nbLayers*sizeof(net->topo[0]));
53     for (i=0;i<nbLayers;i++)
54         net->topo[i] = topo[i];
55     inDim = topo[0];
56     outDim = topo[nbLayers-1];
57     net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0]));
58     net->weights = malloc((nbLayers-1)*sizeof(net->weights));
59     net->best_weights = malloc((nbLayers-1)*sizeof(net->weights));
60     for (i=0;i<nbLayers-1;i++)
61     {
62         net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
63         net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
64     }
65     double inMean[inDim];
66     for (j=0;j<inDim;j++)
67     {
68         double std=0;
69         inMean[j] = 0;
70         for (i=0;i<nbSamples;i++)
71         {
72             inMean[j] += inputs[i*inDim+j];
73             std += inputs[i*inDim+j]*inputs[i*inDim+j];
74         }
75         inMean[j] /= nbSamples;
76         std /= nbSamples;
77         net->in_rate[1+j] = .5/(.0001+std);
78         std = std-inMean[j]*inMean[j];
79         if (std<.001)
80             std = .001;
81         std = 1/sqrt(inDim*std);
82         for (k=0;k<topo[1];k++)
83             net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
84     }
85     net->in_rate[0] = 1;
86     for (j=0;j<topo[1];j++)
87     {
88         double sum = 0;
89         for (k=0;k<inDim;k++)
90             sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1];
91         net->weights[0][j*(topo[0]+1)] = -sum;
92     }
93     for (j=0;j<outDim;j++)
94     {
95         double mean = 0;
96         double std;
97         for (i=0;i<nbSamples;i++)
98             mean += outputs[i*outDim+j];
99         mean /= nbSamples;
100         std = 1/sqrt(topo[nbLayers-2]);
101         net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean;
102         for (k=0;k<topo[nbLayers-2];k++)
103             net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std);
104     }
105     return net;
106 }
107
108 #define MAX_NEURONS 100
109 #define MAX_OUT 10
110
111 double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate)
112 {
113     int i,j;
114     int s;
115     int inDim, outDim, hiddenDim;
116     int *topo;
117     double *W0, *W1;
118     double rms=0;
119     int W0_size, W1_size;
120     double hidden[MAX_NEURONS];
121     double netOut[MAX_NEURONS];
122     double error[MAX_NEURONS];
123
124     topo = net->topo;
125     inDim = net->topo[0];
126     hiddenDim = net->topo[1];
127     outDim = net->topo[2];
128     W0_size = (topo[0]+1)*topo[1];
129     W1_size = (topo[1]+1)*topo[2];
130     W0 = net->weights[0];
131     W1 = net->weights[1];
132     memset(W0_grad, 0, W0_size*sizeof(double));
133     memset(W1_grad, 0, W1_size*sizeof(double));
134     for (i=0;i<outDim;i++)
135         netOut[i] = outputs[i];
136     for (i=0;i<outDim;i++)
137         error_rate[i] = 0;
138     for (s=0;s<nbSamples;s++)
139     {
140         float *in, *out;
141         in = inputs+s*inDim;
142         out = outputs + s*outDim;
143         for (i=0;i<hiddenDim;i++)
144         {
145             double sum = W0[i*(inDim+1)];
146             for (j=0;j<inDim;j++)
147                 sum += W0[i*(inDim+1)+j+1]*in[j];
148             hidden[i] = tansig_approx(sum);
149         }
150         for (i=0;i<outDim;i++)
151         {
152             double sum = W1[i*(hiddenDim+1)];
153             for (j=0;j<hiddenDim;j++)
154                 sum += W1[i*(hiddenDim+1)+j+1]*hidden[j];
155             netOut[i] = tansig_approx(sum);
156             error[i] = out[i] - netOut[i];
157             rms += error[i]*error[i];
158             error_rate[i] += fabs(error[i])>1;
159             /*error[i] = error[i]/(1+fabs(error[i]));*/
160         }
161         /* Back-propagate error */
162         for (i=0;i<outDim;i++)
163         {
164             float grad = 1-netOut[i]*netOut[i];
165             W1_grad[i*(hiddenDim+1)] += error[i]*grad;
166             for (j=0;j<hiddenDim;j++)
167                 W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j];
168         }
169         for (i=0;i<hiddenDim;i++)
170         {
171             double grad;
172             grad = 0;
173             for (j=0;j<outDim;j++)
174                 grad += error[j]*W1[j*(hiddenDim+1)+i+1];
175             grad *= 1-hidden[i]*hidden[i];
176             W0_grad[i*(inDim+1)] += grad;
177             for (j=0;j<inDim;j++)
178                 W0_grad[i*(inDim+1)+j+1] += grad*in[j];
179         }
180     }
181     return rms;
182 }
183
184 #define NB_THREADS 8
185
186 sem_t sem_begin[NB_THREADS];
187 sem_t sem_end[NB_THREADS];
188
189 struct GradientArg {
190     int id;
191     int done;
192     MLPTrain *net;
193     float *inputs;
194     float *outputs;
195     int nbSamples;
196     double *W0_grad;
197     double *W1_grad;
198     double rms;
199     double error_rate[MAX_OUT];
200 };
201
202 void *gradient_thread_process(void *_arg)
203 {
204     int W0_size, W1_size;
205     struct GradientArg *arg = _arg;
206     int *topo = arg->net->topo;
207     W0_size = (topo[0]+1)*topo[1];
208     W1_size = (topo[1]+1)*topo[2];
209     double W0_grad[W0_size];
210     double W1_grad[W1_size];
211     arg->W0_grad = W0_grad;
212     arg->W1_grad = W1_grad;
213     while (1)
214     {
215         sem_wait(&sem_begin[arg->id]);
216         if (arg->done)
217             break;
218         arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate);
219         sem_post(&sem_end[arg->id]);
220     }
221     fprintf(stderr, "done\n");
222     return NULL;
223 }
224
225 float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate)
226 {
227     int i, j;
228     int e;
229     float best_rms = 1e10;
230     int inDim, outDim, hiddenDim;
231     int *topo;
232     double *W0, *W1, *best_W0, *best_W1;
233     double *W0_old, *W1_old;
234     double *W0_old2, *W1_old2;
235     double *W0_grad, *W1_grad;
236     double *W0_oldgrad, *W1_oldgrad;
237     double *W0_rate, *W1_rate;
238     double *best_W0_rate, *best_W1_rate;
239     int W0_size, W1_size;
240     topo = net->topo;
241     W0_size = (topo[0]+1)*topo[1];
242     W1_size = (topo[1]+1)*topo[2];
243     struct GradientArg args[NB_THREADS];
244     pthread_t thread[NB_THREADS];
245     int samplePerPart = nbSamples/NB_THREADS;
246     int count_worse=0;
247     int count_retries=0;
248
249     topo = net->topo;
250     inDim = net->topo[0];
251     hiddenDim = net->topo[1];
252     outDim = net->topo[2];
253     W0 = net->weights[0];
254     W1 = net->weights[1];
255     best_W0 = net->best_weights[0];
256     best_W1 = net->best_weights[1];
257     W0_old = malloc(W0_size*sizeof(double));
258     W1_old = malloc(W1_size*sizeof(double));
259     W0_old2 = malloc(W0_size*sizeof(double));
260     W1_old2 = malloc(W1_size*sizeof(double));
261     W0_grad = malloc(W0_size*sizeof(double));
262     W1_grad = malloc(W1_size*sizeof(double));
263     W0_oldgrad = malloc(W0_size*sizeof(double));
264     W1_oldgrad = malloc(W1_size*sizeof(double));
265     W0_rate = malloc(W0_size*sizeof(double));
266     W1_rate = malloc(W1_size*sizeof(double));
267     best_W0_rate = malloc(W0_size*sizeof(double));
268     best_W1_rate = malloc(W1_size*sizeof(double));
269     memcpy(W0_old, W0, W0_size*sizeof(double));
270     memcpy(W0_old2, W0, W0_size*sizeof(double));
271     memset(W0_grad, 0, W0_size*sizeof(double));
272     memset(W0_oldgrad, 0, W0_size*sizeof(double));
273     memcpy(W1_old, W1, W1_size*sizeof(double));
274     memcpy(W1_old2, W1, W1_size*sizeof(double));
275     memset(W1_grad, 0, W1_size*sizeof(double));
276     memset(W1_oldgrad, 0, W1_size*sizeof(double));
277
278     rate /= nbSamples;
279     for (i=0;i<hiddenDim;i++)
280         for (j=0;j<inDim+1;j++)
281             W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j];
282     for (i=0;i<W1_size;i++)
283         W1_rate[i] = rate;
284
285     for (i=0;i<NB_THREADS;i++)
286     {
287         args[i].net = net;
288         args[i].inputs = inputs+i*samplePerPart*inDim;
289         args[i].outputs = outputs+i*samplePerPart*outDim;
290         args[i].nbSamples = samplePerPart;
291         args[i].id = i;
292         args[i].done = 0;
293         sem_init(&sem_begin[i], 0, 0);
294         sem_init(&sem_end[i], 0, 0);
295         pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]);
296     }
297     for (e=0;e<nbEpoch;e++)
298     {
299         double rms=0;
300         double error_rate[2] = {0,0};
301         for (i=0;i<NB_THREADS;i++)
302         {
303             sem_post(&sem_begin[i]);
304         }
305         memset(W0_grad, 0, W0_size*sizeof(double));
306         memset(W1_grad, 0, W1_size*sizeof(double));
307         for (i=0;i<NB_THREADS;i++)
308         {
309             sem_wait(&sem_end[i]);
310             rms += args[i].rms;
311             error_rate[0] += args[i].error_rate[0];
312             error_rate[1] += args[i].error_rate[1];
313             for (j=0;j<W0_size;j++)
314                 W0_grad[j] += args[i].W0_grad[j];
315             for (j=0;j<W1_size;j++)
316                 W1_grad[j] += args[i].W1_grad[j];
317         }
318
319         float mean_rate = 0, min_rate = 1e10;
320         rms = (rms/(outDim*nbSamples));
321         error_rate[0] = (error_rate[0]/(nbSamples));
322         error_rate[1] = (error_rate[1]/(nbSamples));
323         fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms);
324         if (rms < best_rms)
325         {
326             best_rms = rms;
327             for (i=0;i<W0_size;i++)
328             {
329                 best_W0[i] = W0[i];
330                 best_W0_rate[i] = W0_rate[i];
331             }
332             for (i=0;i<W1_size;i++)
333             {
334                 best_W1[i] = W1[i];
335                 best_W1_rate[i] = W1_rate[i];
336             }
337             count_worse=0;
338             count_retries=0;
339         } else {
340             count_worse++;
341             if (count_worse>30)
342             {
343                 count_retries++;
344                 count_worse=0;
345                 for (i=0;i<W0_size;i++)
346                 {
347                     W0[i] = best_W0[i];
348                     best_W0_rate[i] *= .7;
349                     if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15;
350                     W0_rate[i] = best_W0_rate[i];
351                     W0_grad[i] = 0;
352                 }
353                 for (i=0;i<W1_size;i++)
354                 {
355                     W1[i] = best_W1[i];
356                     best_W1_rate[i] *= .8;
357                     if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
358                     W1_rate[i] = best_W1_rate[i];
359                     W1_grad[i] = 0;
360                 }
361             }
362         }
363         if (count_retries>10)
364             break;
365         for (i=0;i<W0_size;i++)
366         {
367             if (W0_oldgrad[i]*W0_grad[i] > 0)
368                 W0_rate[i] *= 1.01;
369             else if (W0_oldgrad[i]*W0_grad[i] < 0)
370                 W0_rate[i] *= .9;
371             mean_rate += W0_rate[i];
372             if (W0_rate[i] < min_rate)
373                 min_rate = W0_rate[i];
374             if (W0_rate[i] < 1e-15)
375                 W0_rate[i] = 1e-15;
376             /*if (W0_rate[i] > .01)
377                 W0_rate[i] = .01;*/
378             W0_oldgrad[i] = W0_grad[i];
379             W0_old2[i] = W0_old[i];
380             W0_old[i] = W0[i];
381             W0[i] += W0_grad[i]*W0_rate[i];
382         }
383         for (i=0;i<W1_size;i++)
384         {
385             if (W1_oldgrad[i]*W1_grad[i] > 0)
386                 W1_rate[i] *= 1.01;
387             else if (W1_oldgrad[i]*W1_grad[i] < 0)
388                 W1_rate[i] *= .9;
389             mean_rate += W1_rate[i];
390             if (W1_rate[i] < min_rate)
391                 min_rate = W1_rate[i];
392             if (W1_rate[i] < 1e-15)
393                 W1_rate[i] = 1e-15;
394             W1_oldgrad[i] = W1_grad[i];
395             W1_old2[i] = W1_old[i];
396             W1_old[i] = W1[i];
397             W1[i] += W1_grad[i]*W1_rate[i];
398         }
399         mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
400         fprintf (stderr, "%g %d", mean_rate, e);
401         if (count_retries)
402             fprintf(stderr, " %d", count_retries);
403         fprintf(stderr, "\n");
404         if (stopped)
405             break;
406     }
407     for (i=0;i<NB_THREADS;i++)
408     {
409         args[i].done = 1;
410         sem_post(&sem_begin[i]);
411         pthread_join(thread[i], NULL);
412         fprintf (stderr, "joined %d\n", i);
413     }
414     free(W0_old);
415     free(W1_old);
416     free(W0_grad);
417     free(W1_grad);
418     free(W0_rate);
419     free(W1_rate);
420     return best_rms;
421 }
422
423 int main(int argc, char **argv)
424 {
425     int i, j;
426     int nbInputs;
427     int nbOutputs;
428     int nbHidden;
429     int nbSamples;
430     int nbEpoch;
431     int nbRealInputs;
432     unsigned int seed;
433     int ret;
434     float rms;
435     float *inputs;
436     float *outputs;
437     if (argc!=6)
438     {
439         fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n");
440         return 1;
441     }
442     nbInputs = atoi(argv[1]);
443     nbHidden = atoi(argv[2]);
444     nbOutputs = atoi(argv[3]);
445     nbSamples = atoi(argv[4]);
446     nbEpoch = atoi(argv[5]);
447     nbRealInputs = nbInputs;
448     inputs = malloc(nbInputs*nbSamples*sizeof(*inputs));
449     outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs));
450
451     seed = time(NULL);
452     /*seed = 1361480659;*/
453     fprintf (stderr, "Seed is %u\n", seed);
454     srand(seed);
455     build_tansig_table();
456     signal(SIGTERM, handler);
457     signal(SIGINT, handler);
458     signal(SIGHUP, handler);
459     for (i=0;i<nbSamples;i++)
460     {
461         for (j=0;j<nbRealInputs;j++)
462             ret = scanf(" %f", &inputs[i*nbInputs+j]);
463         for (j=0;j<nbOutputs;j++)
464             ret = scanf(" %f", &outputs[i*nbOutputs+j]);
465         if (feof(stdin))
466         {
467             nbSamples = i;
468             break;
469         }
470     }
471     int topo[3] = {nbInputs, nbHidden, nbOutputs};
472     MLPTrain *net;
473
474     fprintf (stderr, "Got %d samples\n", nbSamples);
475     net = mlp_init(topo, 3, inputs, outputs, nbSamples);
476     rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1);
477     printf ("#include \"mlp.h\"\n\n");
478     printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed);
479     printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]);
480     printf ("\n/* hidden layer */\n");
481     for (i=0;i<(topo[0]+1)*topo[1];i++)
482     {
483         printf ("%gf, ", net->weights[0][i]);
484         if (i%5==4)
485             printf("\n");
486     }
487     printf ("\n/* output layer */\n");
488     for (i=0;i<(topo[1]+1)*topo[2];i++)
489     {
490         printf ("%g, ", net->weights[1][i]);
491         if (i%5==4)
492             printf("\n");
493     }
494     printf ("};\n\n");
495     printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]);
496     printf ("const MLP net = {\n");
497     printf ("\t3,\n");
498     printf ("\ttopo,\n");
499     printf ("\tweights\n};\n");
500     return 0;
501 }