#!/usr/bin/python3
'''Sequence-to-sequence grammar check.

Trains the DeepProof encoder/decoder model on character-encoded sentence
pairs read from one or more HDF5 files passed on the command line.
'''
from __future__ import print_function
import math
from keras.models import Model
from keras.layers import Input, LSTM, CuDNNLSTM, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D
from keras.optimizers import Adam
from keras import backend as K
import numpy as np
import h5py
import sys
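# Project-local modules: the character encoding tables and the DeepProof model definition.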
import encoding
import deepproof_model
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
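# TF 1.x session setup: cap this process at a fraction of the GPU memory.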
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.44
set_session(tf.Session(config=config))
batch_size = 128  # Batch size for training.
epochs = 1  # Number of epochs per fit() round; the script runs four rounds below.
# Build the DeepProof encoder, decoder and combined training model.
encoder_model, decoder_model, model = deepproof_model.create(True)
input_text = None
output_text = None
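# Load the character-encoded input/output sentence pairs from every HDF5 file
# given on the command line and concatenate them into one training set.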
for file in sys.argv[1:]:
    with h5py.File(file, 'r') as hf:
        if input_text is None:
            input_text = hf['input'][:]
            output_text = hf['output'][:]
        else:
            input_text = np.concatenate([input_text, hf['input'][:]])
            output_text = np.concatenate([output_text, hf['output'][:]])
#input_text = input_text[0:8000, :]
#output_text = output_text[0:8000, :]
input_data = np.reshape(input_text, (input_text.shape[0], input_text.shape[1], 1))
decoder_target_data = np.reshape(output_text, (output_text.shape[0], output_text.shape[1], 1))
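# Teacher forcing: the decoder input is the target sequence shifted right by one
# character, so each output character is predicted from the previous target character.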
decoder_input_data = np.zeros((input_text.shape[0], input_text.shape[1], 1), dtype='uint8')
decoder_input_data[:,1:,:] = decoder_target_data[:,:-1,:]
max_decoder_seq_length = input_text.shape[1]
num_encoder_tokens = len(encoding.char_list)
print("Number of sentences: ", input_text.shape[0])
print("Sentence length: ", input_text.shape[1])
print("Number of chars: ", num_encoder_tokens)
# Run training
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
#model.load_weights('proof7c.h5')
model.summary()
model.fit([input_data[:,:,0:1], decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
# Save a checkpoint, then keep training with an explicit (lower) learning rate.
model.save('proof8b.h5')
model.compile(optimizer=Adam(0.0003), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
model.fit([input_data[:,:,0:1], decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
model.save('proof8b2.h5')
model.fit([input_data[:,:,0:1], decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
model.save('proof8b3.h5')
model.fit([input_data[:,:,0:1], decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
model.save('proof8b4.h5')
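# After the last round, spot-check the model by decoding sentences from the
# held-out part of the data.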
start = int(.9*input_text.shape[0])
for seq_index in range(start, start+1000):
    # Take one sequence for trying out decoding: compare decode_sequence()
    # with the beam-search variant and with the ground truth.
    input_seq = input_data[seq_index: seq_index + 1]
    decoded_sentence0 = deepproof_model.decode_sequence([encoder_model, decoder_model], input_seq)
    decoded_sentence = deepproof_model.beam_decode_sequence([encoder_model, decoder_model], input_seq)
    deepproof_model.decode_ground_truth([encoder_model, decoder_model], input_seq, output_text[seq_index,:])
    print('-')
    print('Input sentence: ', encoding.decode_string(input_text[seq_index,:]))
    print('Decoded sentence0:', decoded_sentence0)
    print('Decoded sentence: ', decoded_sentence)
    print('Original sentence:', encoding.decode_string(output_text[seq_index,:]))