-
Notifications
You must be signed in to change notification settings - Fork 17
/
train_asr.py
225 lines (191 loc) · 9.61 KB
/
train_asr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
#! python
# -*- coding: utf-8 -*-
# Author: kun
# @Time: 2019-10-29 20:36
import torch
from core.solver import BaseSolver
from core.asr import ASR
from core.optim import Optimizer
from core.data import load_dataset
from core.util import human_format, cal_er, feat_to_fig
class Solver(BaseSolver):
''' Solver for training'''
def __init__(self, config, paras, mode):
super().__init__(config, paras, mode)
# Logger settings
self.best_wer = {'att': 3.0, 'ctc': 3.0}
# Curriculum learning affects data loader
self.curriculum = self.config['hparas']['curriculum']
def fetch_data(self, data):
''' Move data to device and compute text seq. length'''
_, feat, feat_len, txt = data
feat = feat.to(self.device)
feat_len = feat_len.to(self.device)
txt = txt.to(self.device)
txt_len = torch.sum(txt != 0, dim=-1)
return feat, feat_len, txt, txt_len
def load_data(self):
print("Load data for training/validation, store tokenizer and input/output shape")
self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
self.curriculum > 0, **self.config['data'])
self.verbose(msg)
def set_model(self):
print("Setup ASR model and optimizer ")
# Model
self.model = ASR(self.feat_dim, self.vocab_size, **
self.config['model']).to(self.device)
self.verbose(self.model.create_msg())
model_paras = [{'params': self.model.parameters()}]
print("# Losses")
self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
print("# Note: zero_infinity=False is unstable?")
self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)
print("Plug-ins")
self.emb_fuse = False
self.emb_reg = ('emb' in self.config) and (
self.config['emb']['enable'])
if self.emb_reg:
from core.plugin import EmbeddingRegularizer
self.emb_decoder = EmbeddingRegularizer(
self.tokenizer, self.model.dec_dim, **self.config['emb']).to(self.device)
model_paras.append({'params': self.emb_decoder.parameters()})
self.emb_fuse = self.emb_decoder.apply_fuse
if self.emb_fuse:
self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
self.verbose(self.emb_decoder.create_msg())
print("# Optimizer")
self.optimizer = Optimizer(model_paras, **self.config['hparas'])
self.verbose(self.optimizer.create_msg())
print("# Enable AMP if needed")
self.enable_apex()
# Automatically load pre-trained model if self.paras.load is given
self.load_ckpt()
# ToDo: other training methods
def exec(self):
print("Training End-to-end ASR system")
self.verbose('Total training steps {}.'.format(
human_format(self.max_step)))
ctc_loss, att_loss, emb_loss = None, None, None
n_epochs = 0
self.timer.set()
while self.step < self.max_step:
# Renew dataloader to enable random sampling
if self.curriculum > 0 and n_epochs == self.curriculum:
self.verbose(
'Curriculum learning ends after {} epochs, starting random sampling.'.format(n_epochs))
self.tr_set, _, _, _, _, _ = \
load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
False, **self.config['data'])
print("self.tr_set: {}".format(len(self.tr_set)))
for data in self.tr_set:
# print("Pre-step : update tf_rate/lr_rate and do zero_grad")
tf_rate = self.optimizer.pre_step(self.step)
total_loss = 0
# Fetch data
feat, feat_len, txt, txt_len = self.fetch_data(data)
self.timer.cnt('rd')
# Forward model
# Note: txt should NOT start w/ <sos>
ctc_output, encode_len, att_output, att_align, dec_state = \
self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
teacher=txt, get_dec_state=self.emb_reg)
# Plugins
if self.emb_reg:
emb_loss, fuse_output = self.emb_decoder(
dec_state, att_output, label=txt)
total_loss += self.emb_decoder.weight * emb_loss
# Compute all objectives
if ctc_output is not None:
if self.paras.cudnn_ctc:
ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
txt.to_sparse().values().to(device='cpu', dtype=torch.int32),
[ctc_output.shape[1]] *
len(ctc_output),
txt_len.cpu().tolist())
else:
ctc_loss = self.ctc_loss(ctc_output.transpose(
0, 1), txt, encode_len, txt_len)
total_loss += ctc_loss * self.model.ctc_weight
if att_output is not None:
b, t, _ = att_output.shape
att_output = fuse_output if self.emb_fuse else att_output
att_loss = self.seq_loss(
att_output.view(b * t, -1), txt.view(-1))
total_loss += att_loss * (1 - self.model.ctc_weight)
self.timer.cnt('fw')
# Backprop
grad_norm = self.backward(total_loss)
self.step += 1
# Logger
if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
.format(total_loss.cpu().item(), grad_norm, self.timer.show()))
self.write_log(
'loss', {'tr_ctc': ctc_loss, 'tr_att': att_loss})
self.write_log('emb_loss', {'tr': emb_loss})
self.write_log('wer', {'tr_att': cal_er(self.tokenizer, att_output, txt),
'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
if self.emb_fuse:
if self.emb_decoder.fuse_learnable:
self.write_log('fuse_lambda', {
'emb': self.emb_decoder.get_weight()})
self.write_log(
'fuse_temp', {'temp': self.emb_decoder.get_temp()})
# Validation
if (self.step == 1) or (self.step % self.valid_step == 0):
self.validate()
# End of step
# https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
torch.cuda.empty_cache()
self.timer.set()
if self.step > self.max_step:
break
n_epochs += 1
self.log.close()
def validate(self):
# Eval mode
self.model.eval()
if self.emb_decoder is not None:
self.emb_decoder.eval()
dev_wer = {'att': [], 'ctc': []}
for i, data in enumerate(self.dv_set):
self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
# Fetch data
feat, feat_len, txt, txt_len = self.fetch_data(data)
# Forward model
with torch.no_grad():
ctc_output, encode_len, att_output, att_align, dec_state = \
self.model(feat, feat_len, int(max(txt_len) * self.DEV_STEP_RATIO),
emb_decoder=self.emb_decoder)
dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt))
dev_wer['ctc'].append(
cal_er(self.tokenizer, ctc_output, txt, ctc=True))
# Show some example on tensorboard
if i == len(self.dv_set) // 2:
for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
if self.step == 1:
self.write_log('true_text{}'.format(
i), self.tokenizer.decode(txt[i].tolist()))
if att_output is not None:
self.write_log('att_align{}'.format(i), feat_to_fig(
att_align[i, 0, :, :].cpu().detach()))
self.write_log('att_text{}'.format(i), self.tokenizer.decode(
att_output[i].argmax(dim=-1).tolist()))
if ctc_output is not None:
self.write_log('ctc_text{}'.format(i), self.tokenizer.decode(ctc_output[i].argmax(dim=-1).tolist(),
ignore_repeat=True))
# Ckpt if performance improves
self.save_checkpoint('latest.pth', 'wer',
dev_wer['att'], show_msg=False)
for task in ['att', 'ctc']:
dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
if dev_wer[task] < self.best_wer[task]:
self.best_wer[task] = dev_wer[task]
self.save_checkpoint('best_{}.pth'.format(
task), 'wer', dev_wer[task])
self.write_log('wer', {'dv_' + task: dev_wer[task]})
# Resume training
self.model.train()
if self.emb_decoder is not None:
self.emb_decoder.train()