Fix comma that should have been a semicolon
[opus.git] / scripts / rnn_train.py
1 #!/usr/bin/python
2
3 from __future__ import print_function
4
5 from keras.models import Sequential
6 from keras.models import Model
7 from keras.layers import Input
8 from keras.layers import Dense
9 from keras.layers import LSTM
10 from keras.layers import GRU
11 from keras.layers import SimpleRNN
12 from keras.layers import Dropout
13 from keras import losses
14 import h5py
15
16 from keras import backend as K
17 import numpy as np
18
def binary_crossentrop2(y_true, y_pred):
    """Binary cross-entropy weighted by each label's distance from 0.5.

    Every element's cross-entropy is scaled by ``2 * |y_true - 0.5|``, so a
    label of exactly 0.5 contributes nothing to the loss, while labels of 0
    or 1 contribute with full weight.

    Args:
        y_true: Tensor of target values.
        y_pred: Tensor of predictions (the model in this file ends in a
            sigmoid layer, so these are values in [0, 1]).

    Returns:
        Tensor with the weighted loss averaged over the last axis.
    """
    # Pass the tensors by keyword: the positional order of
    # K.binary_crossentropy differs between Keras releases (older ones take
    # (output, target), newer ones (target, output)), whereas the parameter
    # names are the same in both. A positional call silently swaps labels
    # and predictions on one of the two.
    weight = 2 * K.abs(y_true - 0.5)
    crossentropy = K.binary_crossentropy(target=y_true, output=y_pred)
    return K.mean(weight * crossentropy, axis=-1)
21
print('Build model...')
# Network: per-frame Dense(16, tanh) -> GRU(12) over the sequence ->
# per-frame Dense(2, sigmoid).  An earlier Sequential variant used the same
# layer stack with both GRU dropout rates set to 0.0.

# Sequences of arbitrary length, 25 features per time step.
main_input = Input(shape=(None, 25), name='main_input')

dense_features = Dense(16, activation='tanh')(main_input)

gru_options = {
    'dropout': 0.1,
    'recurrent_dropout': 0.1,
    'activation': 'tanh',
    'recurrent_activation': 'sigmoid',
    # Emit an output for every time step rather than only the last one.
    'return_sequences': True,
}
gru_sequence = GRU(12, **gru_options)(dense_features)

predictions = Dense(2, activation='sigmoid')(gru_sequence)
model = Model(inputs=main_input, outputs=predictions)

# Number of sequences per gradient update during training.
batch_size = 64
35
print('Loading data...')
# Read the whole 'features' dataset into memory; the file is closed as soon
# as the copy has been made.
with h5py.File('features.h5', 'r') as hf:
    all_data = hf['features'][:]
print('done.')

# Number of consecutive rows grouped into one training sequence.
window_size = 1500

# Floor division keeps the count an int on Python 3 as well; a float here
# would be rejected as a slice bound and as a reshape dimension.
nb_sequences = len(all_data) // window_size
print(nb_sequences, ' sequences')

# Only whole windows are used; trailing rows that do not fill a window are
# dropped.
used_rows = nb_sequences * window_size

# Every column except the last two is an input feature (25 expected).
x_train = all_data[:used_rows, :-2]
x_train = np.reshape(x_train, (nb_sequences, window_size, 25))

# The last two columns are the training targets.
y_train = np.copy(all_data[:used_rows, -2:])
y_train = np.reshape(y_train, (nb_sequences, window_size, 2))

# Drop the reference to the full array so its memory can be reclaimed.
all_data = 0
x_train = x_train.astype('float32')
y_train = y_train.astype('float32')

print(len(x_train), 'train sequences. x shape =', x_train.shape, 'y shape = ', y_train.shape)
56
# Other optimizers / optimizer settings can be swapped in here.
compile_options = {
    'loss': binary_crossentrop2,
    'optimizer': 'adam',
    'metrics': ['binary_accuracy'],
}
model.compile(**compile_options)

print('Train...')
fit_options = {
    'batch_size': batch_size,
    'epochs': 200,
    # NOTE(review): the validation set is the training set itself, so the
    # reported validation metrics do not measure generalization.
    'validation_data': (x_train, y_train),
}
model.fit(x_train, y_train, **fit_options)

# Persist the trained model.
model.save("newweights.hdf5")