Introducing padding_mask to RetNet #85
Comments
In inference, the padding token doesn't influence the subsequent encoding, so maybe just skipping it is enough?
Thank you for the quick reply. My reasoning for the parallel code was so that the decay would start from the first non-pad token instead of an arbitrary position. Would you be interested in merging this code into the torchscale package? I will fork the repo with the changes if that's the case. Thank you for the help nonetheless :)
As opposed to the other architectures in this package, RetNet doesn't have support for padding as far as I'm aware. I was thinking the best place to introduce it was along with the positional (decay) mask. Here we don't have the luxury of the softmax, so we can't simply mask the relevant positions with negative infinity.
From my attempt, the parallel code would be something along the following (assuming left padding and a padding_mask of shape (bsz, seq_len)):
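A minimal sketch of the idea for the parallel path, assuming a decay mask `D[n, m] = gamma**(n - m)` (lower triangular) as computed by RetNet's relative position embedding, and a `padding_mask` that is 1 at real tokens and 0 at pads; `apply_padding_to_decay` is a hypothetical helper, not torchscale API:

```python
import torch

def apply_padding_to_decay(decay_mask, padding_mask):
    """Zero out retention decay at padded positions (hypothetical helper).

    decay_mask:   (num_heads, seq_len, seq_len), lower-triangular decay
                  with D[n, m] = gamma**(n - m) for n >= m
    padding_mask: (bsz, seq_len), 1 at real tokens, 0 at pads (left padding)
    Returns:      (bsz, num_heads, seq_len, seq_len) mask with pad
                  rows (queries) and columns (keys) zeroed.
    """
    pad = padding_mask.to(decay_mask.dtype)                 # (bsz, seq_len)
    # zero key (column) positions that are padding ...
    mask = decay_mask.unsqueeze(0) * pad[:, None, None, :]
    # ... and query (row) positions as well
    mask = mask * pad[:, None, :, None]
    return mask
```

With the pad columns zeroed, the effective decay for each real token starts from the first non-pad token rather than from position 0, which is the behavior described above.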
This would imply expanding the mask here instead of broadcasting it in the forward method.
In the recurrent formulation, perhaps masking the scaling factor accordingly would work?
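One way to read "masking the scaling factor" in the recurrent path is to carry the state through pad steps unchanged, so a pad token neither decays the state nor adds its key/value outer product. A sketch under that assumption (`recurrent_step` is a hypothetical single-token step, not the torchscale implementation):

```python
import torch

def recurrent_step(state, q, k, v, gamma, is_pad):
    """One RetNet recurrent step with a hypothetical padding guard.

    state:  (num_heads, d_k, d_v), running retention state S_{n-1}
    q, k:   (num_heads, d_k) for the current token
    v:      (num_heads, d_v) for the current token
    gamma:  per-head decay, shape (num_heads, 1, 1)
    is_pad: when True, the state is returned unchanged and the output
            is zero, i.e. the pad token is skipped entirely.
    """
    if is_pad:
        return state, torch.zeros_like(v)
    # S_n = gamma * S_{n-1} + k^T v  (outer-product state update)
    new_state = gamma * state + k.unsqueeze(-1) * v.unsqueeze(-2)
    # o_n = q S_n
    out = torch.einsum('hk,hkv->hv', q, new_state)
    return new_state, out
```

This matches the earlier observation that in inference the pad token can simply be skipped; the guard just makes that explicit at the state-update level.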
I would like some help on this, perhaps the authors have a better approach? @donglixp @sunyt32