-
Notifications
You must be signed in to change notification settings - Fork 304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] RNN-T + MBR training. #593
Conversation
The model structure is like the diagram below, it has two joiners, one is the joiner for regular RNN-T, the other is |
|
||
self.encoder_output_layer = ScaledLinear( | ||
d_model, num_classes, bias=True | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The transformer lm is actually an Embedding Layer
plus TransformerEncoder
that encode the symbols into text_embedding
.
dropout=dropout, | ||
layer_dropout=layer_dropout, | ||
) | ||
self.enhancer = TransformerDecoder(decoder_layer, num_layers) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The EmbeddingEnhancer
is a TransformerDecoder
that has self-attention from masked_encoder_output
and cross-attention from text_embedding
.
N, T, C = embedding.shape | ||
mask = torch.randn((N, T, C), device=embedding.device) | ||
mask = mask > mask_proportion | ||
masked_embedding = torch.masked_fill(embedding, ~mask, 0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I randomly mask the encoder output
here.
) | ||
return init_context | ||
|
||
def delta_wer( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function implements the sampling process.
+ l2_loss_scale * l2_loss | ||
+ delta_wer_scale * delta_wer_loss | ||
+ predictor_loss_scale * predictor_loss | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The losses are combined here.
@danpovey @yaozengwei @glynpu Would you please to have a look at this, if there is anything unclear, please let me know. Thanks! |
Sure. I will have a look. |
This PR depends on k2-fsa/k2#1057 in k2.