Simplifying fast_atan2f()
[opus.git] / src / mlp_train.c
1 /* Copyright (c) 2008-2011 Octasic Inc.
2    Written by Jean-Marc Valin */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28
29 #include "mlp_train.h"
30 #include <stdlib.h>
31 #include <stdio.h>
32 #include <string.h>
33 #include <semaphore.h>
34 #include <pthread.h>
35 #include <time.h>
36 #include <signal.h>
37
38 int stopped = 0;
39
40 void handler(int sig)
41 {
42     stopped = 1;
43     signal(sig, handler);
44 }
45
46 MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples)
47 {
48     int i, j, k;
49     MLPTrain *net;
50     int inDim, outDim;
51     net = malloc(sizeof(*net));
52     net->topo = malloc(nbLayers*sizeof(net->topo[0]));
53     for (i=0;i<nbLayers;i++)
54         net->topo[i] = topo[i];
55     inDim = topo[0];
56     outDim = topo[nbLayers-1];
57     net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0]));
58     net->weights = malloc((nbLayers-1)*sizeof(net->weights));
59     net->best_weights = malloc((nbLayers-1)*sizeof(net->weights));
60     for (i=0;i<nbLayers-1;i++)
61     {
62         net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
63         net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
64     }
65     double inMean[inDim];
66     for (j=0;j<inDim;j++)
67     {
68         double std=0;
69         inMean[j] = 0;
70         for (i=0;i<nbSamples;i++)
71         {
72             inMean[j] += inputs[i*inDim+j];
73             std += inputs[i*inDim+j]*inputs[i*inDim+j];
74         }
75         inMean[j] /= nbSamples;
76         std /= nbSamples;
77         net->in_rate[1+j] = .5/(.0001+std);
78         std = std-inMean[j]*inMean[j];
79         if (std<.001)
80             std = .001;
81         std = 1/sqrt(inDim*std);
82         for (k=0;k<topo[1];k++)
83             net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
84     }
85     net->in_rate[0] = 1;
86     for (j=0;j<topo[1];j++)
87     {
88         double sum = 0;
89         for (k=0;k<inDim;k++)
90             sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1];
91         net->weights[0][j*(topo[0]+1)] = -sum;
92     }
93     for (j=0;j<outDim;j++)
94     {
95         double mean = 0;
96         double std;
97         for (i=0;i<nbSamples;i++)
98             mean += outputs[i*outDim+j];
99         mean /= nbSamples;
100         std = 1/sqrt(topo[nbLayers-2]);
101         net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean;
102         for (k=0;k<topo[nbLayers-2];k++)
103             net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std);
104     }
105     return net;
106 }
107
108 #define MAX_NEURONS 100
109 #define MAX_OUT 10
110
111 double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate)
112 {
113     int i,j;
114     int s;
115     int inDim, outDim, hiddenDim;
116     int *topo;
117     double *W0, *W1;
118     double rms=0;
119     int W0_size, W1_size;
120     double hidden[MAX_NEURONS];
121     double netOut[MAX_NEURONS];
122     double error[MAX_NEURONS];
123
124     topo = net->topo;
125     inDim = net->topo[0];
126     hiddenDim = net->topo[1];
127     outDim = net->topo[2];
128     W0_size = (topo[0]+1)*topo[1];
129     W1_size = (topo[1]+1)*topo[2];
130     W0 = net->weights[0];
131     W1 = net->weights[1];
132     memset(W0_grad, 0, W0_size*sizeof(double));
133     memset(W1_grad, 0, W1_size*sizeof(double));
134     for (i=0;i<outDim;i++)
135         netOut[i] = outputs[i];
136     for (i=0;i<outDim;i++)
137         error_rate[i] = 0;
138     for (s=0;s<nbSamples;s++)
139     {
140         float *in, *out;
141         in = inputs+s*inDim;
142         out = outputs + s*outDim;
143         for (i=0;i<hiddenDim;i++)
144         {
145             double sum = W0[i*(inDim+1)];
146             for (j=0;j<inDim;j++)
147                 sum += W0[i*(inDim+1)+j+1]*in[j];
148             hidden[i] = tansig_approx(sum);
149         }
150         for (i=0;i<outDim;i++)
151         {
152             double sum = W1[i*(hiddenDim+1)];
153             for (j=0;j<hiddenDim;j++)
154                 sum += W1[i*(hiddenDim+1)+j+1]*hidden[j];
155             netOut[i] = tansig_approx(sum);
156             error[i] = out[i] - netOut[i];
157             if (out[i] == 0) error[i] *= .0;
158             error_rate[i] += fabs(error[i])>1;
159             if (i==0) error[i] *= 3;
160             rms += error[i]*error[i];
161             /*error[i] = error[i]/(1+fabs(error[i]));*/
162         }
163         /* Back-propagate error */
164         for (i=0;i<outDim;i++)
165         {
166             float grad = 1-netOut[i]*netOut[i];
167             W1_grad[i*(hiddenDim+1)] += error[i]*grad;
168             for (j=0;j<hiddenDim;j++)
169                 W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j];
170         }
171         for (i=0;i<hiddenDim;i++)
172         {
173             double grad;
174             grad = 0;
175             for (j=0;j<outDim;j++)
176                 grad += error[j]*W1[j*(hiddenDim+1)+i+1];
177             grad *= 1-hidden[i]*hidden[i];
178             W0_grad[i*(inDim+1)] += grad;
179             for (j=0;j<inDim;j++)
180                 W0_grad[i*(inDim+1)+j+1] += grad*in[j];
181         }
182     }
183     return rms;
184 }
185
186 #define NB_THREADS 8
187
188 sem_t sem_begin[NB_THREADS];
189 sem_t sem_end[NB_THREADS];
190
191 struct GradientArg {
192     int id;
193     int done;
194     MLPTrain *net;
195     float *inputs;
196     float *outputs;
197     int nbSamples;
198     double *W0_grad;
199     double *W1_grad;
200     double rms;
201     double error_rate[MAX_OUT];
202 };
203
204 void *gradient_thread_process(void *_arg)
205 {
206     int W0_size, W1_size;
207     struct GradientArg *arg = _arg;
208     int *topo = arg->net->topo;
209     W0_size = (topo[0]+1)*topo[1];
210     W1_size = (topo[1]+1)*topo[2];
211     double W0_grad[W0_size];
212     double W1_grad[W1_size];
213     arg->W0_grad = W0_grad;
214     arg->W1_grad = W1_grad;
215     while (1)
216     {
217         sem_wait(&sem_begin[arg->id]);
218         if (arg->done)
219             break;
220         arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate);
221         sem_post(&sem_end[arg->id]);
222     }
223     fprintf(stderr, "done\n");
224     return NULL;
225 }
226
227 float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate)
228 {
229     int i, j;
230     int e;
231     float best_rms = 1e10;
232     int inDim, outDim, hiddenDim;
233     int *topo;
234     double *W0, *W1, *best_W0, *best_W1;
235     double *W0_old, *W1_old;
236     double *W0_old2, *W1_old2;
237     double *W0_grad, *W1_grad;
238     double *W0_oldgrad, *W1_oldgrad;
239     double *W0_rate, *W1_rate;
240     double *best_W0_rate, *best_W1_rate;
241     int W0_size, W1_size;
242     topo = net->topo;
243     W0_size = (topo[0]+1)*topo[1];
244     W1_size = (topo[1]+1)*topo[2];
245     struct GradientArg args[NB_THREADS];
246     pthread_t thread[NB_THREADS];
247     int samplePerPart = nbSamples/NB_THREADS;
248     int count_worse=0;
249     int count_retries=0;
250
251     topo = net->topo;
252     inDim = net->topo[0];
253     hiddenDim = net->topo[1];
254     outDim = net->topo[2];
255     W0 = net->weights[0];
256     W1 = net->weights[1];
257     best_W0 = net->best_weights[0];
258     best_W1 = net->best_weights[1];
259     W0_old = malloc(W0_size*sizeof(double));
260     W1_old = malloc(W1_size*sizeof(double));
261     W0_old2 = malloc(W0_size*sizeof(double));
262     W1_old2 = malloc(W1_size*sizeof(double));
263     W0_grad = malloc(W0_size*sizeof(double));
264     W1_grad = malloc(W1_size*sizeof(double));
265     W0_oldgrad = malloc(W0_size*sizeof(double));
266     W1_oldgrad = malloc(W1_size*sizeof(double));
267     W0_rate = malloc(W0_size*sizeof(double));
268     W1_rate = malloc(W1_size*sizeof(double));
269     best_W0_rate = malloc(W0_size*sizeof(double));
270     best_W1_rate = malloc(W1_size*sizeof(double));
271     memcpy(W0_old, W0, W0_size*sizeof(double));
272     memcpy(W0_old2, W0, W0_size*sizeof(double));
273     memset(W0_grad, 0, W0_size*sizeof(double));
274     memset(W0_oldgrad, 0, W0_size*sizeof(double));
275     memcpy(W1_old, W1, W1_size*sizeof(double));
276     memcpy(W1_old2, W1, W1_size*sizeof(double));
277     memset(W1_grad, 0, W1_size*sizeof(double));
278     memset(W1_oldgrad, 0, W1_size*sizeof(double));
279
280     rate /= nbSamples;
281     for (i=0;i<hiddenDim;i++)
282         for (j=0;j<inDim+1;j++)
283             W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j];
284     for (i=0;i<W1_size;i++)
285         W1_rate[i] = rate;
286
287     for (i=0;i<NB_THREADS;i++)
288     {
289         args[i].net = net;
290         args[i].inputs = inputs+i*samplePerPart*inDim;
291         args[i].outputs = outputs+i*samplePerPart*outDim;
292         args[i].nbSamples = samplePerPart;
293         args[i].id = i;
294         args[i].done = 0;
295         sem_init(&sem_begin[i], 0, 0);
296         sem_init(&sem_end[i], 0, 0);
297         pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]);
298     }
299     for (e=0;e<nbEpoch;e++)
300     {
301         double rms=0;
302         double error_rate[2] = {0,0};
303         for (i=0;i<NB_THREADS;i++)
304         {
305             sem_post(&sem_begin[i]);
306         }
307         memset(W0_grad, 0, W0_size*sizeof(double));
308         memset(W1_grad, 0, W1_size*sizeof(double));
309         for (i=0;i<NB_THREADS;i++)
310         {
311             sem_wait(&sem_end[i]);
312             rms += args[i].rms;
313             error_rate[0] += args[i].error_rate[0];
314             error_rate[1] += args[i].error_rate[1];
315             for (j=0;j<W0_size;j++)
316                 W0_grad[j] += args[i].W0_grad[j];
317             for (j=0;j<W1_size;j++)
318                 W1_grad[j] += args[i].W1_grad[j];
319         }
320
321         float mean_rate = 0, min_rate = 1e10;
322         rms = (rms/(outDim*nbSamples));
323         error_rate[0] = (error_rate[0]/(nbSamples));
324         error_rate[1] = (error_rate[1]/(nbSamples));
325         fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms);
326         if (rms < best_rms)
327         {
328             best_rms = rms;
329             for (i=0;i<W0_size;i++)
330             {
331                 best_W0[i] = W0[i];
332                 best_W0_rate[i] = W0_rate[i];
333             }
334             for (i=0;i<W1_size;i++)
335             {
336                 best_W1[i] = W1[i];
337                 best_W1_rate[i] = W1_rate[i];
338             }
339             count_worse=0;
340             count_retries=0;
341         } else {
342             count_worse++;
343             if (count_worse>30)
344             {
345                 count_retries++;
346                 count_worse=0;
347                 for (i=0;i<W0_size;i++)
348                 {
349                     W0[i] = best_W0[i];
350                     best_W0_rate[i] *= .7;
351                     if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15;
352                     W0_rate[i] = best_W0_rate[i];
353                     W0_grad[i] = 0;
354                 }
355                 for (i=0;i<W1_size;i++)
356                 {
357                     W1[i] = best_W1[i];
358                     best_W1_rate[i] *= .8;
359                     if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
360                     W1_rate[i] = best_W1_rate[i];
361                     W1_grad[i] = 0;
362                 }
363             }
364         }
365         if (count_retries>10)
366             break;
367         for (i=0;i<W0_size;i++)
368         {
369             if (W0_oldgrad[i]*W0_grad[i] > 0)
370                 W0_rate[i] *= 1.01;
371             else if (W0_oldgrad[i]*W0_grad[i] < 0)
372                 W0_rate[i] *= .9;
373             mean_rate += W0_rate[i];
374             if (W0_rate[i] < min_rate)
375                 min_rate = W0_rate[i];
376             if (W0_rate[i] < 1e-15)
377                 W0_rate[i] = 1e-15;
378             /*if (W0_rate[i] > .01)
379                 W0_rate[i] = .01;*/
380             W0_oldgrad[i] = W0_grad[i];
381             W0_old2[i] = W0_old[i];
382             W0_old[i] = W0[i];
383             W0[i] += W0_grad[i]*W0_rate[i];
384         }
385         for (i=0;i<W1_size;i++)
386         {
387             if (W1_oldgrad[i]*W1_grad[i] > 0)
388                 W1_rate[i] *= 1.01;
389             else if (W1_oldgrad[i]*W1_grad[i] < 0)
390                 W1_rate[i] *= .9;
391             mean_rate += W1_rate[i];
392             if (W1_rate[i] < min_rate)
393                 min_rate = W1_rate[i];
394             if (W1_rate[i] < 1e-15)
395                 W1_rate[i] = 1e-15;
396             W1_oldgrad[i] = W1_grad[i];
397             W1_old2[i] = W1_old[i];
398             W1_old[i] = W1[i];
399             W1[i] += W1_grad[i]*W1_rate[i];
400         }
401         mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
402         fprintf (stderr, "%g %d", mean_rate, e);
403         if (count_retries)
404             fprintf(stderr, " %d", count_retries);
405         fprintf(stderr, "\n");
406         if (stopped)
407             break;
408     }
409     for (i=0;i<NB_THREADS;i++)
410     {
411         args[i].done = 1;
412         sem_post(&sem_begin[i]);
413         pthread_join(thread[i], NULL);
414         fprintf (stderr, "joined %d\n", i);
415     }
416     free(W0_old);
417     free(W1_old);
418     free(W0_grad);
419     free(W1_grad);
420     free(W0_rate);
421     free(W1_rate);
422     return best_rms;
423 }
424
425 int main(int argc, char **argv)
426 {
427     int i, j;
428     int nbInputs;
429     int nbOutputs;
430     int nbHidden;
431     int nbSamples;
432     int nbEpoch;
433     int nbRealInputs;
434     unsigned int seed;
435     int ret;
436     float rms;
437     float *inputs;
438     float *outputs;
439     if (argc!=6)
440     {
441         fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n");
442         return 1;
443     }
444     nbInputs = atoi(argv[1]);
445     nbHidden = atoi(argv[2]);
446     nbOutputs = atoi(argv[3]);
447     nbSamples = atoi(argv[4]);
448     nbEpoch = atoi(argv[5]);
449     nbRealInputs = nbInputs;
450     inputs = malloc(nbInputs*nbSamples*sizeof(*inputs));
451     outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs));
452
453     seed = time(NULL);
454     /*seed = 1452209040;*/
455     fprintf (stderr, "Seed is %u\n", seed);
456     srand(seed);
457     build_tansig_table();
458     signal(SIGTERM, handler);
459     signal(SIGINT, handler);
460     signal(SIGHUP, handler);
461     for (i=0;i<nbSamples;i++)
462     {
463         for (j=0;j<nbRealInputs;j++)
464             ret = scanf(" %f", &inputs[i*nbInputs+j]);
465         for (j=0;j<nbOutputs;j++)
466             ret = scanf(" %f", &outputs[i*nbOutputs+j]);
467         if (feof(stdin))
468         {
469             nbSamples = i;
470             break;
471         }
472     }
473     int topo[3] = {nbInputs, nbHidden, nbOutputs};
474     MLPTrain *net;
475
476     fprintf (stderr, "Got %d samples\n", nbSamples);
477     net = mlp_init(topo, 3, inputs, outputs, nbSamples);
478     rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1);
479     printf ("#include \"mlp.h\"\n\n");
480     printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed);
481     printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]);
482     printf ("\n/* hidden layer */\n");
483     for (i=0;i<(topo[0]+1)*topo[1];i++)
484     {
485         printf ("%gf, ", net->weights[0][i]);
486         if (i%5==4)
487             printf("\n");
488     }
489     printf ("\n/* output layer */\n");
490     for (i=0;i<(topo[1]+1)*topo[2];i++)
491     {
492         printf ("%g, ", net->weights[1][i]);
493         if (i%5==4)
494             printf("\n");
495     }
496     printf ("};\n\n");
497     printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]);
498     printf ("const MLP net = {\n");
499     printf ("\t3,\n");
500     printf ("\ttopo,\n");
501     printf ("\tweights\n};\n");
502     return 0;
503 }