/* Copyright (c) 2008-2011 Octasic Inc.
   Written by Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/


#include "mlp_train.h"
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <semaphore.h>
#include <pthread.h>
#include <time.h>
#include <signal.h>

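/* Global stop flag, set asynchronously from the signal handler below so the
   training loop can finish the current epoch and exit cleanly. */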
volatile sig_atomic_t stopped = 0;

void handler(int sig)
{
    stopped = 1;
    /* Re-install the handler, for platforms with System V signal semantics. */
    signal(sig, handler);
}

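/* Allocate a fully-connected network with the given topology and initialize
   its weights from the training data: hidden-layer weights are drawn with
   randn() at a scale set by each input's variance, hidden biases are chosen
   to cancel the input means, and output biases start at the mean of each
   target.  Per-input learning-rate scales (in_rate) are derived from the
   inputs' second moments. */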
MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples)
{
    int i, j, k;
    MLPTrain *net;
    int inDim, outDim;
    net = malloc(sizeof(*net));
    net->topo = malloc(nbLayers*sizeof(net->topo[0]));
    for (i=0;i<nbLayers;i++)
        net->topo[i] = topo[i];
    inDim = topo[0];
    outDim = topo[nbLayers-1];
    net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0]));
    net->weights = malloc((nbLayers-1)*sizeof(net->weights[0]));
    net->best_weights = malloc((nbLayers-1)*sizeof(net->best_weights[0]));
    for (i=0;i<nbLayers-1;i++)
    {
        net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
        net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0]));
    }
    double inMean[inDim];
    for (j=0;j<inDim;j++)
    {
        double std=0;
        inMean[j] = 0;
        for (i=0;i<nbSamples;i++)
        {
            inMean[j] += inputs[i*inDim+j];
            std += inputs[i*inDim+j]*inputs[i*inDim+j];
        }
        inMean[j] /= nbSamples;
        std /= nbSamples;
        /* Learning-rate scale from the input's (uncentered) second moment. */
        net->in_rate[1+j] = .5/(.0001+std);
        std = std-inMean[j]*inMean[j];
        if (std<.001)
            std = .001;
        std = 1/sqrt(inDim*std);
        for (k=0;k<topo[1];k++)
            net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
    }
    net->in_rate[0] = 1;
    /* Hidden biases cancel the mean input, so hidden units start near zero. */
    for (j=0;j<topo[1];j++)
    {
        double sum = 0;
        for (k=0;k<inDim;k++)
            sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1];
        net->weights[0][j*(topo[0]+1)] = -sum;
    }
    for (j=0;j<outDim;j++)
    {
        double mean = 0;
        double std;
        for (i=0;i<nbSamples;i++)
            mean += outputs[i*outDim+j];
        mean /= nbSamples;
        std = 1/sqrt(topo[nbLayers-2]);
        /* Output bias starts at the target mean; other weights stay small. */
        net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean;
        for (k=0;k<topo[nbLayers-2];k++)
            net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std);
    }
    return net;
}

#define MAX_NEURONS 100
#define MAX_OUT 10

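/* Run a forward pass over nbSamples samples of a one-hidden-layer tanh
   network and accumulate the back-propagated gradients into W0_grad and
   W1_grad.  Returns the summed squared error over all samples and outputs;
   error_rate[] counts, per output, how many samples had an absolute error
   above 1. */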
double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate)
{
    int i,j;
    int s;
    int inDim, outDim, hiddenDim;
    int *topo;
    double *W0, *W1;
    double rms=0;
    int W0_size, W1_size;
    double hidden[MAX_NEURONS];
    double netOut[MAX_NEURONS];
    double error[MAX_NEURONS];

    topo = net->topo;
    inDim = net->topo[0];
    hiddenDim = net->topo[1];
    outDim = net->topo[2];
    W0_size = (topo[0]+1)*topo[1];
    W1_size = (topo[1]+1)*topo[2];
    W0 = net->weights[0];
    W1 = net->weights[1];
    memset(W0_grad, 0, W0_size*sizeof(double));
    memset(W1_grad, 0, W1_size*sizeof(double));
    for (i=0;i<outDim;i++)
        netOut[i] = outputs[i];
    for (i=0;i<outDim;i++)
        error_rate[i] = 0;
    for (s=0;s<nbSamples;s++)
    {
        float *in, *out;
        in = inputs+s*inDim;
        out = outputs + s*outDim;
        /* Forward pass: input -> hidden. */
        for (i=0;i<hiddenDim;i++)
        {
            double sum = W0[i*(inDim+1)];
            for (j=0;j<inDim;j++)
                sum += W0[i*(inDim+1)+j+1]*in[j];
            hidden[i] = tansig_approx(sum);
        }
        /* Forward pass: hidden -> output, with error accumulation. */
        for (i=0;i<outDim;i++)
        {
            double sum = W1[i*(hiddenDim+1)];
            for (j=0;j<hiddenDim;j++)
                sum += W1[i*(hiddenDim+1)+j+1]*hidden[j];
            netOut[i] = tansig_approx(sum);
            error[i] = out[i] - netOut[i];
            rms += error[i]*error[i];
            error_rate[i] += fabs(error[i])>1;
            /*error[i] = error[i]/(1+fabs(error[i]));*/
        }
        /* Back-propagate error */
        for (i=0;i<outDim;i++)
        {
            float grad = 1-netOut[i]*netOut[i];
            W1_grad[i*(hiddenDim+1)] += error[i]*grad;
            for (j=0;j<hiddenDim;j++)
                W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j];
        }
        for (i=0;i<hiddenDim;i++)
        {
            double grad;
            grad = 0;
            for (j=0;j<outDim;j++)
                grad += error[j]*W1[j*(hiddenDim+1)+i+1];
            grad *= 1-hidden[i]*hidden[i];
            W0_grad[i*(inDim+1)] += grad;
            for (j=0;j<inDim;j++)
                W0_grad[i*(inDim+1)+j+1] += grad*in[j];
        }
    }
    return rms;
}

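/* Training is split across NB_THREADS worker threads.  Each epoch, the main
   thread posts sem_begin[] to start the workers on their slice of the data,
   then waits on sem_end[] and sums the partial gradients they computed. */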
#define NB_THREADS 8

sem_t sem_begin[NB_THREADS];
sem_t sem_end[NB_THREADS];

struct GradientArg {
    int id;
    int done;
    MLPTrain *net;
    float *inputs;
    float *outputs;
    int nbSamples;
    double *W0_grad;
    double *W1_grad;
    double rms;
    double error_rate[MAX_OUT];
};

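/* Worker loop.  The per-thread gradient buffers are stack VLAs published to
   the main thread through arg->W0_grad/W1_grad; they remain valid because
   the thread only exits (when arg->done is set) after the main thread has
   stopped reading them. */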
void *gradient_thread_process(void *_arg)
{
    int W0_size, W1_size;
    struct GradientArg *arg = _arg;
    int *topo = arg->net->topo;
    W0_size = (topo[0]+1)*topo[1];
    W1_size = (topo[1]+1)*topo[2];
    double W0_grad[W0_size];
    double W1_grad[W1_size];
    arg->W0_grad = W0_grad;
    arg->W1_grad = W1_grad;
    while (1)
    {
        sem_wait(&sem_begin[arg->id]);
        if (arg->done)
            break;
        arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate);
        sem_post(&sem_end[arg->id]);
    }
    fprintf(stderr, "done\n");
    return NULL;
}

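/* Batch gradient descent with a per-weight adaptive step size, in the spirit
   of Rprop: a weight's rate grows by 1% while the sign of its gradient is
   stable and shrinks by 10% when it flips.  The best weights seen so far are
   kept; after more than 30 epochs without improvement the network is rolled
   back to them with reduced rates, and training aborts after more than 10
   such retries.  Returns the lowest error reached (named "rms", though no
   square root is taken). */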
float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate)
{
    int i, j;
    int e;
    float best_rms = 1e10;
    int inDim, outDim, hiddenDim;
    int *topo;
    double *W0, *W1, *best_W0, *best_W1;
    double *W0_old, *W1_old;
    double *W0_old2, *W1_old2;
    double *W0_grad, *W1_grad;
    double *W0_oldgrad, *W1_oldgrad;
    double *W0_rate, *W1_rate;
    double *best_W0_rate, *best_W1_rate;
    int W0_size, W1_size;
    topo = net->topo;
    W0_size = (topo[0]+1)*topo[1];
    W1_size = (topo[1]+1)*topo[2];
    struct GradientArg args[NB_THREADS];
    pthread_t thread[NB_THREADS];
    int samplePerPart = nbSamples/NB_THREADS;
    int count_worse=0;
    int count_retries=0;

    inDim = net->topo[0];
    hiddenDim = net->topo[1];
    outDim = net->topo[2];
    W0 = net->weights[0];
    W1 = net->weights[1];
    best_W0 = net->best_weights[0];
    best_W1 = net->best_weights[1];
    W0_old = malloc(W0_size*sizeof(double));
    W1_old = malloc(W1_size*sizeof(double));
    W0_old2 = malloc(W0_size*sizeof(double));
    W1_old2 = malloc(W1_size*sizeof(double));
    W0_grad = malloc(W0_size*sizeof(double));
    W1_grad = malloc(W1_size*sizeof(double));
    W0_oldgrad = malloc(W0_size*sizeof(double));
    W1_oldgrad = malloc(W1_size*sizeof(double));
    W0_rate = malloc(W0_size*sizeof(double));
    W1_rate = malloc(W1_size*sizeof(double));
    best_W0_rate = malloc(W0_size*sizeof(double));
    best_W1_rate = malloc(W1_size*sizeof(double));
    memcpy(W0_old, W0, W0_size*sizeof(double));
    memcpy(W0_old2, W0, W0_size*sizeof(double));
    memset(W0_grad, 0, W0_size*sizeof(double));
    memset(W0_oldgrad, 0, W0_size*sizeof(double));
    memcpy(W1_old, W1, W1_size*sizeof(double));
    memcpy(W1_old2, W1, W1_size*sizeof(double));
    memset(W1_grad, 0, W1_size*sizeof(double));
    memset(W1_oldgrad, 0, W1_size*sizeof(double));

    /* Gradients are summed over all samples, so scale the rate accordingly;
       input-layer rates are further scaled per input by in_rate[]. */
    rate /= nbSamples;
    for (i=0;i<hiddenDim;i++)
        for (j=0;j<inDim+1;j++)
            W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j];
    for (i=0;i<W1_size;i++)
        W1_rate[i] = rate;

    for (i=0;i<NB_THREADS;i++)
    {
        args[i].net = net;
        args[i].inputs = inputs+i*samplePerPart*inDim;
        args[i].outputs = outputs+i*samplePerPart*outDim;
        args[i].nbSamples = samplePerPart;
        args[i].id = i;
        args[i].done = 0;
        sem_init(&sem_begin[i], 0, 0);
        sem_init(&sem_end[i], 0, 0);
        pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]);
    }
    for (e=0;e<nbEpoch;e++)
    {
        double rms=0;
        double error_rate[2] = {0,0};
        /* Kick off all workers, then gather and sum their partial results. */
        for (i=0;i<NB_THREADS;i++)
        {
            sem_post(&sem_begin[i]);
        }
        memset(W0_grad, 0, W0_size*sizeof(double));
        memset(W1_grad, 0, W1_size*sizeof(double));
        for (i=0;i<NB_THREADS;i++)
        {
            sem_wait(&sem_end[i]);
            rms += args[i].rms;
            error_rate[0] += args[i].error_rate[0];
            error_rate[1] += args[i].error_rate[1];
            for (j=0;j<W0_size;j++)
                W0_grad[j] += args[i].W0_grad[j];
            for (j=0;j<W1_size;j++)
                W1_grad[j] += args[i].W1_grad[j];
        }

        float mean_rate = 0, min_rate = 1e10;
        rms = (rms/(outDim*nbSamples));
        error_rate[0] = (error_rate[0]/(nbSamples));
        error_rate[1] = (error_rate[1]/(nbSamples));
        fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms);
        if (rms < best_rms)
        {
            /* New best: snapshot the weights and their rates. */
            best_rms = rms;
            for (i=0;i<W0_size;i++)
            {
                best_W0[i] = W0[i];
                best_W0_rate[i] = W0_rate[i];
            }
            for (i=0;i<W1_size;i++)
            {
                best_W1[i] = W1[i];
                best_W1_rate[i] = W1_rate[i];
            }
            count_worse=0;
            count_retries=0;
        } else {
            count_worse++;
            if (count_worse>30)
            {
                /* Stuck: roll back to the best weights with reduced rates. */
                count_retries++;
                count_worse=0;
                for (i=0;i<W0_size;i++)
                {
                    W0[i] = best_W0[i];
                    best_W0_rate[i] *= .7;
                    if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15;
                    W0_rate[i] = best_W0_rate[i];
                    W0_grad[i] = 0;
                }
                for (i=0;i<W1_size;i++)
                {
                    W1[i] = best_W1[i];
                    best_W1_rate[i] *= .8;
                    if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
                    W1_rate[i] = best_W1_rate[i];
                    W1_grad[i] = 0;
                }
            }
        }
        if (count_retries>10)
            break;
        for (i=0;i<W0_size;i++)
        {
            /* Grow the step while the gradient sign is stable, shrink it on
               a sign flip, then take the update. */
            if (W0_oldgrad[i]*W0_grad[i] > 0)
                W0_rate[i] *= 1.01;
            else if (W0_oldgrad[i]*W0_grad[i] < 0)
                W0_rate[i] *= .9;
            mean_rate += W0_rate[i];
            if (W0_rate[i] < min_rate)
                min_rate = W0_rate[i];
            if (W0_rate[i] < 1e-15)
                W0_rate[i] = 1e-15;
            /*if (W0_rate[i] > .01)
                W0_rate[i] = .01;*/
            W0_oldgrad[i] = W0_grad[i];
            W0_old2[i] = W0_old[i];
            W0_old[i] = W0[i];
            W0[i] += W0_grad[i]*W0_rate[i];
        }
        for (i=0;i<W1_size;i++)
        {
            if (W1_oldgrad[i]*W1_grad[i] > 0)
                W1_rate[i] *= 1.01;
            else if (W1_oldgrad[i]*W1_grad[i] < 0)
                W1_rate[i] *= .9;
            mean_rate += W1_rate[i];
            if (W1_rate[i] < min_rate)
                min_rate = W1_rate[i];
            if (W1_rate[i] < 1e-15)
                W1_rate[i] = 1e-15;
            W1_oldgrad[i] = W1_grad[i];
            W1_old2[i] = W1_old[i];
            W1_old[i] = W1[i];
            W1[i] += W1_grad[i]*W1_rate[i];
        }
        mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
        fprintf (stderr, "%g %d", mean_rate, e);
        if (count_retries)
            fprintf(stderr, " %d", count_retries);
        fprintf(stderr, "\n");
        if (stopped)
            break;
    }
    for (i=0;i<NB_THREADS;i++)
    {
        args[i].done = 1;
        sem_post(&sem_begin[i]);
        pthread_join(thread[i], NULL);
        fprintf (stderr, "joined %d\n", i);
    }
    free(W0_old);
    free(W1_old);
    free(W0_old2);
    free(W1_old2);
    free(W0_grad);
    free(W1_grad);
    free(W0_oldgrad);
    free(W1_oldgrad);
    free(W0_rate);
    free(W1_rate);
    free(best_W0_rate);
    free(best_W1_rate);
    return best_rms;
}

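/* Reads whitespace-separated training samples (nbInputs features followed by
   nbOutputs targets per sample) from stdin and writes the trained network to
   stdout as C source for mlp.h.  A hypothetical invocation, with sizes chosen
   purely for illustration:

       ./mlp_train 25 16 2 50000 2000 < training_data.txt > mlp_data.c
*/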
int main(int argc, char **argv)
{
    int i, j;
    int nbInputs;
    int nbOutputs;
    int nbHidden;
    int nbSamples;
    int nbEpoch;
    int nbRealInputs;
    unsigned int seed;
    int ret;
    float rms;
    float *inputs;
    float *outputs;
    if (argc!=6)
    {
        fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n");
        return 1;
    }
    nbInputs = atoi(argv[1]);
    nbHidden = atoi(argv[2]);
    nbOutputs = atoi(argv[3]);
    nbSamples = atoi(argv[4]);
    nbEpoch = atoi(argv[5]);
    nbRealInputs = nbInputs;
    inputs = malloc(nbInputs*nbSamples*sizeof(*inputs));
    outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs));

    seed = time(NULL);
    /*seed = 1361480659;*/
    fprintf (stderr, "Seed is %u\n", seed);
    srand(seed);
    build_tansig_table();
    signal(SIGTERM, handler);
    signal(SIGINT, handler);
    signal(SIGHUP, handler);
    for (i=0;i<nbSamples;i++)
    {
        for (j=0;j<nbRealInputs;j++)
            ret = scanf(" %f", &inputs[i*nbInputs+j]);
        for (j=0;j<nbOutputs;j++)
            ret = scanf(" %f", &outputs[i*nbOutputs+j]);
        if (feof(stdin))
        {
            nbSamples = i;
            break;
        }
    }
    int topo[3] = {nbInputs, nbHidden, nbOutputs};
    MLPTrain *net;

    fprintf (stderr, "Got %d samples\n", nbSamples);
    net = mlp_init(topo, 3, inputs, outputs, nbSamples);
    rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1);
    printf ("#include \"mlp.h\"\n\n");
    printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed);
    printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]);
    printf ("\n/* hidden layer */\n");
    for (i=0;i<(topo[0]+1)*topo[1];i++)
    {
        printf ("%gf, ", net->weights[0][i]);
        if (i%5==4)
            printf("\n");
    }
    printf ("\n/* output layer */\n");
    for (i=0;i<(topo[1]+1)*topo[2];i++)
    {
        printf ("%gf, ", net->weights[1][i]);
        if (i%5==4)
            printf("\n");
    }
    printf ("};\n\n");
    printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]);
    printf ("const MLP net = {\n");
    printf ("\t3,\n");
    printf ("\ttopo,\n");
    printf ("\tweights\n};\n");
    return 0;
}