-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
42 lines (32 loc) · 1.17 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import sys
import getopt
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.layers import Masking
from keras.callbacks import ModelCheckpoint
from yaoai import setup
def main(argv, *, sentence_length=40, input_path='./input.txt',
         lstm_units=256, lstm_layers=3, dropout_rate=0.2, epochs=200,
         batch_size=None, checkpoint_pattern='word-LSTM-{epoch:02d}.hdf5',
         model_path='model.hdf5'):
    """Train a stacked word-level LSTM on a text corpus and save the model.

    The corpus is prepared by ``yaoai.setup``. The network is a Masking layer,
    then ``lstm_layers`` x (LSTM + Dropout), then a softmax Dense layer. It is
    compiled with categorical cross-entropy and the Adam optimizer.

    Every keyword argument defaults to the value this script originally
    hard-coded, so ``main(sys.argv)`` behaves exactly as before.

    Args:
        argv: Command-line arguments. Currently unused; accepted so the script
            entry point can pass ``sys.argv``.
        sentence_length: Sequence length passed to ``setup``.
        input_path: Path of the training text passed to ``setup``.
        lstm_units: Number of units in each LSTM layer.
        lstm_layers: Number of stacked LSTM + Dropout pairs (must be >= 1).
        dropout_rate: Dropout rate applied after each LSTM layer.
        epochs: Number of training epochs.
        batch_size: Training batch size. ``None`` keeps the original
            behaviour of using ``sentence_length`` as the batch size.
        checkpoint_pattern: Filename pattern for checkpoints. A checkpoint is
            written only when the training loss reaches a new minimum.
        model_path: Path where the final model is saved after training.

    Returns:
        The trained Keras model, so callers can reuse it without reloading
        it from disk.

    Raises:
        ValueError: If ``lstm_layers`` is less than 1.
    """
    # Validate before the potentially expensive data preparation.
    if lstm_layers < 1:
        raise ValueError('lstm_layers must be >= 1, got {!r}'.format(lstm_layers))
    if batch_size is None:
        # The original script used sentence_length as the batch size. Keep
        # that as the default so existing runs stay reproducible.
        batch_size = sentence_length

    # setup returns a 5-tuple; only the training inputs/targets are used here.
    data_x, data_y, _story, _word_to_int, _int_to_word = setup(
        sentence_length, input_path)
    print(data_x)

    # Masking masks timesteps whose features all equal 0.0 (presumably the
    # padding produced by setup -- TODO confirm).
    layers = [Masking(mask_value=0.0, input_shape=data_x.shape[1:])]
    # Every LSTM returns full sequences, so the final Dense layer is applied
    # at each timestep. Its width is taken from data_y.shape[2].
    for _ in range(lstm_layers):
        layers.append(LSTM(lstm_units, return_sequences=True))
        layers.append(Dropout(dropout_rate))
    layers.append(Dense(data_y.shape[2], activation='softmax'))

    model = Sequential(layers)
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    # Write a checkpoint only when the monitored training loss improves.
    checkpoint = ModelCheckpoint(checkpoint_pattern, monitor='loss', verbose=1,
                                 save_best_only=True, mode='min')
    model.fit(data_x, data_y, epochs=epochs, batch_size=batch_size,
              callbacks=[checkpoint])
    model.save(model_path)
    return model
if __name__ == '__main__':
    # Script entry point: hand the raw command line (program name included)
    # to main().
    main(argv=sys.argv)