training_helper.py
import time

import tensorflow as tf

def train_model(model, dataset, params, ckpt, ckpt_manager, out_file):
    """Run the training loop, logging the loss to `out_file` and saving
    checkpoints periodically."""
    optimizer = tf.keras.optimizers.Adagrad(
        params['learning_rate'],
        initial_accumulator_value=params['adagrad_init_acc'],
        clipnorm=params['max_grad_norm'])
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False, reduction='none')

    def loss_function(real, pred):
        # Mask out padding tokens (id 1) so they do not contribute to the loss.
        mask = tf.math.logical_not(tf.math.equal(real, 1))
        # Number of real (non-padding) tokens per example. Make sure no empty
        # abstract is fed in, otherwise dec_lens can be zero and the division
        # below produces NaNs.
        dec_lens = tf.reduce_sum(tf.cast(mask, dtype=tf.float32), axis=-1)
        loss_ = loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        # Average over the true target length, then over the batch.
        loss_ = tf.reduce_sum(loss_, axis=-1) / dec_lens
        return tf.reduce_mean(loss_)
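    # Worked example (illustrative numbers, not taken from the code): with pad
    # id 1, real = [[5, 7, 1, 1]] gives mask = [[1, 1, 0, 0]] and
    # dec_lens = [2.], so only the first two token losses count and their sum
    # is divided by 2 rather than by the padded length of 4.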
    # Fixing the shapes and dtypes via input_signature makes tf.function trace
    # the graph once instead of retracing whenever a batch shape changes.
    @tf.function(input_signature=(
            tf.TensorSpec(shape=[params["batch_size"], None], dtype=tf.int32),
            tf.TensorSpec(shape=[params["batch_size"], None], dtype=tf.int32),
            tf.TensorSpec(shape=[params["batch_size"], params["max_dec_len"]], dtype=tf.int32),
            tf.TensorSpec(shape=[params["batch_size"], params["max_dec_len"]], dtype=tf.int32),
            tf.TensorSpec(shape=[], dtype=tf.int32)))
    def train_step(enc_inp, enc_extended_inp, dec_inp, dec_tar, batch_oov_len):
        with tf.GradientTape() as tape:
            enc_hidden, enc_output = model.call_encoder(enc_inp)
            predictions, _ = model(enc_output, enc_hidden, enc_inp,
                                   enc_extended_inp, dec_inp, batch_oov_len)
            loss = loss_function(dec_tar, predictions)
        variables = (model.encoder.trainable_variables
                     + model.attention.trainable_variables
                     + model.decoder.trainable_variables
                     + model.pointer.trainable_variables)
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
        return loss
    try:
        with open(out_file, "a+") as f:
            for batch in dataset:
                t0 = time.time()
                loss = train_step(batch[0]["enc_input"],
                                  batch[0]["extended_enc_input"],
                                  batch[1]["dec_input"],
                                  batch[1]["dec_target"],
                                  batch[0]["max_oov_len"])
                # Format the log line once so stdout and the log file agree.
                msg = 'Step {}, time {:.4f}, Loss {:.4f}'.format(
                    int(ckpt.step), time.time() - t0, loss.numpy())
                print(msg)
                f.write(msg + '\n')
                if int(ckpt.step) == params["max_steps"]:
                    ckpt_manager.save(checkpoint_number=int(ckpt.step))
                    print("Saved checkpoint for step {}".format(int(ckpt.step)))
                    break
                if int(ckpt.step) % params["checkpoints_save_steps"] == 0:
                    ckpt_manager.save(checkpoint_number=int(ckpt.step))
                    print("Saved checkpoint for step {}".format(int(ckpt.step)))
                ckpt.step.assign_add(1)
    except KeyboardInterrupt:
        # Save progress before exiting when training is interrupted manually.
        ckpt_manager.save(checkpoint_number=int(ckpt.step))
        print("Saved checkpoint for step {}".format(int(ckpt.step)))
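
# Example usage (a minimal sketch: `PGN`, `batcher`, and the hyperparameter
# values below are assumptions for illustration, not part of this module):
#
#   params = {"learning_rate": 0.15, "adagrad_init_acc": 0.1,
#             "max_grad_norm": 2.0, "batch_size": 16, "max_dec_len": 40,
#             "max_steps": 10000, "checkpoints_save_steps": 1000}
#   model = PGN(params)
#   dataset = batcher(params)  # yields (encoder dict, decoder dict) batch pairs
#   ckpt = tf.train.Checkpoint(step=tf.Variable(0), model=model)
#   ckpt_manager = tf.train.CheckpointManager(ckpt, "checkpoints", max_to_keep=5)
#   train_model(model, dataset, params, ckpt, ckpt_manager, "train_log.txt")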