/* Copyright (c) 2008-2011 Octasic Inc.
   Written by Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/


#include "mlp_train.h"
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>   /* for sqrt() and fabs() */
#include <semaphore.h>
#include <pthread.h>
#include <time.h>
#include <signal.h>

/* Set asynchronously by the signal handler so the training loop can
   finish the current epoch and exit cleanly. */
volatile sig_atomic_t stopped = 0;

void handler(int sig)
{
        stopped = 1;
        /* Re-arm the handler: signal() may reset the disposition to
           SIG_DFL on some systems. */
        signal(sig, handler);
}

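/* Allocate a network with the given topology and initialize it:
   hidden-layer weights are Gaussian (via randn(), assumed provided by
   mlp_train.h) scaled by the per-input deviation, hidden biases center
   each unit on the input means, and output biases start at the mean of
   the training targets. Also fills in_rate with per-input rate scales
   used later by mlp_train_backprop(). */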
MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples)
{
        int i, j, k;
        MLPTrain *net;
        int inDim, outDim;
        net = malloc(sizeof(*net));
        net->topo = malloc(nbLayers*sizeof(net->topo[0]));
        for (i=0;i<nbLayers;i++)
                net->topo[i] = topo[i];
        inDim = topo[0];
        outDim = topo[nbLayers-1];
        net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0]));
        net->weights = malloc((nbLayers-1)*sizeof(net->weights[0]));
        net->best_weights = malloc((nbLayers-1)*sizeof(net->best_weights[0]));
        for (i=0;i<nbLayers-1;i++)
        {
                net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
                net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
        }
        double inMean[inDim];
        for (j=0;j<inDim;j++)
        {
                double std=0;
                inMean[j] = 0;
                for (i=0;i<nbSamples;i++)
                {
                        inMean[j] += inputs[i*inDim+j];
                        std += inputs[i*inDim+j]*inputs[i*inDim+j];
                }
                inMean[j] /= nbSamples;
                std /= nbSamples;
                /* Note: at this point std still holds the raw second moment
                   E[x^2], which is what the per-input rate scale uses. */
                net->in_rate[1+j] = .5/(.0001+std);
                std = std-inMean[j]*inMean[j];
                if (std<.001)
                        std = .001;
                std = 1/sqrt(inDim*std);
                for (k=0;k<topo[1];k++)
                        net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
        }
        net->in_rate[0] = 1;
        for (j=0;j<topo[1];j++)
        {
                double sum = 0;
                for (k=0;k<inDim;k++)
                        sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1];
                net->weights[0][j*(topo[0]+1)] = -sum;
        }
        for (j=0;j<outDim;j++)
        {
                double mean = 0;
                double std;
                for (i=0;i<nbSamples;i++)
                        mean += outputs[i*outDim+j];
                mean /= nbSamples;
                std = 1/sqrt(topo[nbLayers-2]);
                net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean;
                for (k=0;k<topo[nbLayers-2];k++)
                        net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std);
        }
        return net;
}

#define MAX_NEURONS 100

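/* One pass over nbSamples samples of a 3-layer (input, hidden, output)
   network with tanh units on both layers. Accumulates (does not average)
   the error gradients into W0_grad and W1_grad, counts outputs whose
   absolute error exceeds 1 into *error_rate, and returns the total sum
   of squared output errors. */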
double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate)
{
        int i,j;
        int s;
        int inDim, outDim, hiddenDim;
        int *topo;
        double *W0, *W1;
        double rms=0;
        int W0_size, W1_size;
        double hidden[MAX_NEURONS];
        double netOut[MAX_NEURONS];
        double error[MAX_NEURONS];

        *error_rate = 0;
        topo = net->topo;
        inDim = net->topo[0];
        hiddenDim = net->topo[1];
        outDim = net->topo[2];
        W0_size = (topo[0]+1)*topo[1];
        W1_size = (topo[1]+1)*topo[2];
        W0 = net->weights[0];
        W1 = net->weights[1];
        memset(W0_grad, 0, W0_size*sizeof(double));
        memset(W1_grad, 0, W1_size*sizeof(double));
        for (i=0;i<outDim;i++)
                netOut[i] = outputs[i];
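        /* For each sample: forward pass, then backpropagation. With
           y = tanh(x), dy/dx = 1 - y*y, so each delta below is the error
           scaled by that derivative. */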
        for (s=0;s<nbSamples;s++)
        {
                float *in, *out;
                in = inputs+s*inDim;
                out = outputs + s*outDim;
                for (i=0;i<hiddenDim;i++)
                {
                        double sum = W0[i*(inDim+1)];
                        for (j=0;j<inDim;j++)
                                sum += W0[i*(inDim+1)+j+1]*in[j];
                        hidden[i] = tansig_approx(sum);
                }
                for (i=0;i<outDim;i++)
                {
                        double sum = W1[i*(hiddenDim+1)];
                        for (j=0;j<hiddenDim;j++)
                                sum += W1[i*(hiddenDim+1)+j+1]*hidden[j];
                        netOut[i] = tansig_approx(sum);
                        error[i] = out[i] - netOut[i];
                        rms += error[i]*error[i];
                        *error_rate += fabs(error[i])>1;
                        /*error[i] = error[i]/(1+fabs(error[i]));*/
                }
                /* Back-propagate error */
                for (i=0;i<outDim;i++)
                {
                        float grad = 1-netOut[i]*netOut[i];
                        W1_grad[i*(hiddenDim+1)] += error[i]*grad;
                        for (j=0;j<hiddenDim;j++)
                                W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j];
                }
                for (i=0;i<hiddenDim;i++)
                {
                        double grad;
                        grad = 0;
                        for (j=0;j<outDim;j++)
                                grad += error[j]*W1[j*(hiddenDim+1)+i+1];
                        grad *= 1-hidden[i]*hidden[i];
                        W0_grad[i*(inDim+1)] += grad;
                        for (j=0;j<inDim;j++)
                                W0_grad[i*(inDim+1)+j+1] += grad*in[j];
                }
        }
        return rms;
}

#define NB_THREADS 8

sem_t sem_begin[NB_THREADS];
sem_t sem_end[NB_THREADS];

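/* Work unit for one gradient worker. The main thread fills in the sample
   range once at startup; every epoch it posts sem_begin[id] to start the
   worker and waits on sem_end[id] before reading rms, error_rate and the
   gradient buffers. */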
struct GradientArg {
        int id;
        int done;
        MLPTrain *net;
        float *inputs;
        float *outputs;
        int nbSamples;
        double *W0_grad;
        double *W1_grad;
        double rms;
        double error_rate;
};

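/* Worker loop. W0_grad/W1_grad are VLAs on this thread's stack; they stay
   valid while the thread is alive, and the worker only writes them between
   sem_begin and sem_end, so the main thread can safely read them after
   sem_wait(&sem_end[id]) returns. */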
void *gradient_thread_process(void *_arg)
{
        int W0_size, W1_size;
        struct GradientArg *arg = _arg;
        int *topo = arg->net->topo;
        W0_size = (topo[0]+1)*topo[1];
        W1_size = (topo[1]+1)*topo[2];
        double W0_grad[W0_size];
        double W1_grad[W1_size];
        arg->W0_grad = W0_grad;
        arg->W1_grad = W1_grad;
        while (1)
        {
                sem_wait(&sem_begin[arg->id]);
                if (arg->done)
                        break;
                arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, &arg->error_rate);
                sem_post(&sem_end[arg->id]);
        }
        fprintf(stderr, "done\n");
        return NULL;
}

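/* Train the 3-layer network with full-batch backpropagation distributed
   over NB_THREADS worker threads. Each weight has its own adaptive
   learning rate: it grows by 1% while its gradient keeps the same sign
   and shrinks by 10% on a sign flip. If the best MSE has not improved
   for 30 consecutive epochs, the best weights are restored with reduced
   rates; after 10 such retries training stops early. Returns the best
   MSE seen. */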
float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate)
{
        int i, j;
        int e;
        float best_rms = 1e10;
        int inDim, outDim, hiddenDim;
        int *topo;
        double *W0, *W1, *best_W0, *best_W1;
        double *W0_old, *W1_old;
        double *W0_old2, *W1_old2;
        double *W0_grad, *W1_grad;
        double *W0_oldgrad, *W1_oldgrad;
        double *W0_rate, *W1_rate;
        double *best_W0_rate, *best_W1_rate;
        int W0_size, W1_size;
        topo = net->topo;
        W0_size = (topo[0]+1)*topo[1];
        W1_size = (topo[1]+1)*topo[2];
        struct GradientArg args[NB_THREADS];
        pthread_t thread[NB_THREADS];
        /* Any remainder of nbSamples/NB_THREADS is simply not used. */
        int samplePerPart = nbSamples/NB_THREADS;
        int count_worse=0;
        int count_retries=0;

        inDim = net->topo[0];
        hiddenDim = net->topo[1];
        outDim = net->topo[2];
        W0 = net->weights[0];
        W1 = net->weights[1];
        best_W0 = net->best_weights[0];
        best_W1 = net->best_weights[1];
        W0_old = malloc(W0_size*sizeof(double));
        W1_old = malloc(W1_size*sizeof(double));
        W0_old2 = malloc(W0_size*sizeof(double));
        W1_old2 = malloc(W1_size*sizeof(double));
        W0_grad = malloc(W0_size*sizeof(double));
        W1_grad = malloc(W1_size*sizeof(double));
        W0_oldgrad = malloc(W0_size*sizeof(double));
        W1_oldgrad = malloc(W1_size*sizeof(double));
        W0_rate = malloc(W0_size*sizeof(double));
        W1_rate = malloc(W1_size*sizeof(double));
        best_W0_rate = malloc(W0_size*sizeof(double));
        best_W1_rate = malloc(W1_size*sizeof(double));
        memcpy(W0_old, W0, W0_size*sizeof(double));
        memcpy(W0_old2, W0, W0_size*sizeof(double));
        memset(W0_grad, 0, W0_size*sizeof(double));
        memset(W0_oldgrad, 0, W0_size*sizeof(double));
        memcpy(W1_old, W1, W1_size*sizeof(double));
        memcpy(W1_old2, W1, W1_size*sizeof(double));
        memset(W1_grad, 0, W1_size*sizeof(double));
        memset(W1_oldgrad, 0, W1_size*sizeof(double));

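        /* Gradients are summed over samples rather than averaged, so the
           base rate is divided by nbSamples; input-layer rates are further
           scaled by the inverse input power stored in net->in_rate. */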
        rate /= nbSamples;
        for (i=0;i<hiddenDim;i++)
                for (j=0;j<inDim+1;j++)
                        W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j];
        for (i=0;i<W1_size;i++)
                W1_rate[i] = rate;

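        /* Spawn the workers once; each epoch they are kicked via
           sem_begin[] and report back via sem_end[]. */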
        for (i=0;i<NB_THREADS;i++)
        {
                args[i].net = net;
                args[i].inputs = inputs+i*samplePerPart*inDim;
                args[i].outputs = outputs+i*samplePerPart*outDim;
                args[i].nbSamples = samplePerPart;
                args[i].id = i;
                args[i].done = 0;
                sem_init(&sem_begin[i], 0, 0);
                sem_init(&sem_end[i], 0, 0);
                pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]);
        }
        for (e=0;e<nbEpoch;e++)
        {
                double rms=0;
                double error_rate = 0;
                for (i=0;i<NB_THREADS;i++)
                {
                        sem_post(&sem_begin[i]);
                }
                memset(W0_grad, 0, W0_size*sizeof(double));
                memset(W1_grad, 0, W1_size*sizeof(double));
                for (i=0;i<NB_THREADS;i++)
                {
                        sem_wait(&sem_end[i]);
                        rms += args[i].rms;
                        error_rate += args[i].error_rate;
                        for (j=0;j<W0_size;j++)
                                W0_grad[j] += args[i].W0_grad[j];
                        for (j=0;j<W1_size;j++)
                                W1_grad[j] += args[i].W1_grad[j];
                }

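                /* Despite its name, rms holds the mean squared error per
                   output here; no square root is taken. */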
                float mean_rate = 0, min_rate = 1e10;
                rms = (rms/(outDim*nbSamples));
                error_rate = (error_rate/(outDim*nbSamples));
                fprintf (stderr, "%f (%f %f) ", error_rate, rms, best_rms);
                if (rms < best_rms)
                {
                        best_rms = rms;
                        for (i=0;i<W0_size;i++)
                        {
                                best_W0[i] = W0[i];
                                best_W0_rate[i] = W0_rate[i];
                        }
                        for (i=0;i<W1_size;i++)
                        {
                                best_W1[i] = W1[i];
                                best_W1_rate[i] = W1_rate[i];
                        }
                        count_worse=0;
                        count_retries=0;
                } else {
                        count_worse++;
                        if (count_worse>30)
                        {
                                count_retries++;
                                count_worse=0;
                                for (i=0;i<W0_size;i++)
                                {
                                        W0[i] = best_W0[i];
                                        best_W0_rate[i] *= .7;
                                        if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15;
                                        W0_rate[i] = best_W0_rate[i];
                                        W0_grad[i] = 0;
                                }
                                for (i=0;i<W1_size;i++)
                                {
                                        W1[i] = best_W1[i];
                                        best_W1_rate[i] *= .8;
                                        if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
                                        W1_rate[i] = best_W1_rate[i];
                                        W1_grad[i] = 0;
                                }
                        }
                }
                if (count_retries>10)
                        break;
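                /* Sign-based rate adaptation: grow a weight's rate while
                   its gradient keeps the same sign, shrink it when the sign
                   flips (overshoot), and clamp it from below. */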
                for (i=0;i<W0_size;i++)
                {
                        if (W0_oldgrad[i]*W0_grad[i] > 0)
                                W0_rate[i] *= 1.01;
                        else if (W0_oldgrad[i]*W0_grad[i] < 0)
                                W0_rate[i] *= .9;
                        mean_rate += W0_rate[i];
                        if (W0_rate[i] < min_rate)
                                min_rate = W0_rate[i];
                        if (W0_rate[i] < 1e-15)
                                W0_rate[i] = 1e-15;
                        /*if (W0_rate[i] > .01)
                                W0_rate[i] = .01;*/
                        W0_oldgrad[i] = W0_grad[i];
                        W0_old2[i] = W0_old[i];
                        W0_old[i] = W0[i];
                        W0[i] += W0_grad[i]*W0_rate[i];
                }
                for (i=0;i<W1_size;i++)
                {
                        if (W1_oldgrad[i]*W1_grad[i] > 0)
                                W1_rate[i] *= 1.01;
                        else if (W1_oldgrad[i]*W1_grad[i] < 0)
                                W1_rate[i] *= .9;
                        mean_rate += W1_rate[i];
                        if (W1_rate[i] < min_rate)
                                min_rate = W1_rate[i];
                        if (W1_rate[i] < 1e-15)
                                W1_rate[i] = 1e-15;
                        W1_oldgrad[i] = W1_grad[i];
                        W1_old2[i] = W1_old[i];
                        W1_old[i] = W1[i];
                        W1[i] += W1_grad[i]*W1_rate[i];
                }
                mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
                fprintf (stderr, "%g %d", mean_rate, e);
                if (count_retries)
                        fprintf(stderr, " %d", count_retries);
                fprintf(stderr, "\n");
                if (stopped)
                        break;
        }
        for (i=0;i<NB_THREADS;i++)
        {
                args[i].done = 1;
                sem_post(&sem_begin[i]);
                pthread_join(thread[i], NULL);
                sem_destroy(&sem_begin[i]);
                sem_destroy(&sem_end[i]);
                fprintf (stderr, "joined %d\n", i);
        }
        free(W0_old);
        free(W1_old);
        free(W0_old2);
        free(W1_old2);
        free(W0_grad);
        free(W1_grad);
        free(W0_oldgrad);
        free(W1_oldgrad);
        free(W0_rate);
        free(W1_rate);
        free(best_W0_rate);
        free(best_W1_rate);
        return best_rms;
}

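/* Usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>
   Reads whitespace-separated training samples (inputs then outputs, one
   sample after another) from stdin, trains the network, and writes the
   result to stdout as C source: a weights table, the topology, and an
   MLP struct as declared in mlp.h. */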
int main(int argc, char **argv)
{
        int i, j;
        int nbInputs;
        int nbOutputs;
        int nbHidden;
        int nbSamples;
        int nbEpoch;
        int nbRealInputs;
        unsigned int seed;
        int ret;
        float rms;
        float *inputs;
        float *outputs;
        if (argc!=6)
        {
                fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n");
                return 1;
        }
        nbInputs = atoi(argv[1]);
        nbHidden = atoi(argv[2]);
        nbOutputs = atoi(argv[3]);
        nbSamples = atoi(argv[4]);
        nbEpoch = atoi(argv[5]);
        nbRealInputs = nbInputs;
        inputs = malloc(nbInputs*nbSamples*sizeof(*inputs));
        outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs));

        seed = time(NULL);
        fprintf (stderr, "Seed is %u\n", seed);
        srand(seed);
        build_tansig_table();
        signal(SIGTERM, handler);
        signal(SIGINT, handler);
        signal(SIGHUP, handler);
        for (i=0;i<nbSamples;i++)
        {
                for (j=0;j<nbRealInputs;j++)
                        ret = scanf(" %f", &inputs[i*nbInputs+j]);
                for (j=0;j<nbOutputs;j++)
                        ret = scanf(" %f", &outputs[i*nbOutputs+j]);
                if (feof(stdin))
                {
                        nbSamples = i;
                        break;
                }
        }
        (void)ret; /* scanf() return values are not checked; EOF ends the loop */
        int topo[3] = {nbInputs, nbHidden, nbOutputs};
        MLPTrain *net;

        fprintf (stderr, "Got %d samples\n", nbSamples);
        net = mlp_init(topo, 3, inputs, outputs, nbSamples);
        rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1);
        printf ("#include \"mlp.h\"\n\n");
        printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed);
        printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]);
        printf ("\n/* hidden layer */\n");
        for (i=0;i<(topo[0]+1)*topo[1];i++)
        {
                printf ("%gf, ", net->weights[0][i]);
                if (i%5==4)
                        printf("\n");
        }
        printf ("\n/* output layer */\n");
        for (i=0;i<(topo[1]+1)*topo[2];i++)
        {
                printf ("%gf, ", net->weights[1][i]);
                if (i%5==4)
                        printf("\n");
        }
        printf ("};\n\n");
        printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]);
        printf ("const MLP net = {\n");
        printf ("\t3,\n");
        printf ("\ttopo,\n");
        printf ("\tweights\n};\n");
        return 0;
}