Wrong implementation of l2 attention #8

Open
PeterL1n opened this issue Jun 13, 2023 · 0 comments

PeterL1n commented Jun 13, 2023

https://github.com/mlpc-ucsd/ViTGAN/blob/d57078e49a1f6f8a0588f3a53b26c95c6c12cd1f/models/gan/stylegan2/vit_common.py#LL144C54-L144C54

At this line, why are there still separate `q` and `k`? Shouldn't they be completely tied?
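
When queries and keys are tied, a single projection produces both. The pairwise distance matrix is then symmetric, and every token is at distance zero from itself. With separate projections, neither property holds. The sketch below illustrates this. The layer names `to_qk`, `to_q`, and `to_k` are hypothetical and are not taken from the linked file.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, head_dim, tokens = 16, 8, 5
x = torch.randn(tokens, dim)

# Tied: one (hypothetical) projection produces both queries and keys.
to_qk = nn.Linear(dim, head_dim, bias=False)
qk = to_qk(x)
dist_tied = torch.cdist(qk, qk) ** 2  # squared L2 distance between every pair of tokens

# Untied: separate (hypothetical) projections for q and k, the setup the question objects to.
to_q = nn.Linear(dim, head_dim, bias=False)
to_k = nn.Linear(dim, head_dim, bias=False)
dist_untied = torch.cdist(to_q(x), to_k(x)) ** 2

print(torch.allclose(dist_tied, dist_tied.T))      # True: symmetric
print(dist_tied.diagonal().abs().max().item())     # 0 up to rounding: zero self-distance
print(torch.allclose(dist_untied, dist_untied.T))  # False: not symmetric in general
```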

I also think the squared norms should be computed with a sum rather than a mean. The L2 distance also needs to be negated, so that nearer keys receive higher attention weights after the softmax.
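
Both points can be checked numerically. The expansion ||q_i||^2 - 2 q_i·k_j + ||k_j||^2 equals the true squared L2 distance only when the norms are summed over the feature dimension. A mean divides the norm terms by `head_dim` but leaves the cross term unchanged. Softmax gives the largest weight to the largest logit, so without negation each token would attend most to its farthest token. The sketch below uses random tensors; the shapes are arbitrary assumptions.

```python
import torch

torch.manual_seed(0)
tokens, head_dim = 6, 4
qk = torch.randn(tokens, head_dim)  # tied query/key for a single head

cross = qk @ qk.T                            # <q_i, k_j>
sq_sum = qk.pow(2).sum(-1, keepdim=True)     # ||q_i||^2 (sum over features)
sq_mean = qk.pow(2).mean(-1, keepdim=True)   # ||q_i||^2 / head_dim (mean over features)

dist_sum = sq_sum - 2 * cross + sq_sum.T
dist_mean = sq_mean - 2 * cross + sq_mean.T
reference = torch.cdist(qk, qk) ** 2         # exact squared L2 distances

print(torch.allclose(dist_sum, reference, atol=1e-5))   # True
print(torch.allclose(dist_mean, reference, atol=1e-5))  # False

# Negating the distance makes the nearest key (the token itself) the largest logit.
attn_neg = torch.softmax(-dist_sum, dim=-1)
attn_pos = torch.softmax(dist_sum, dim=-1)
print(attn_neg.argmax(-1))  # each token attends most to itself
print(attn_pos.argmax(-1))  # each token attends most to some other, farther token
```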

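This is the correct implementation:

```python
import torch

# qk: tied query/key tensor, e.g. (batch, heads, tokens, head_dim); self.scale is the module's scale factor
AB = torch.matmul(qk, qk.transpose(-1, -2))  # dot products <q_i, k_j>
AA = torch.sum(qk ** 2, -1, keepdim=True)  # squared norms ||q_i||^2 (sum, not mean)
BB = AA.transpose(-1, -2)    # Since query and key are tied.
attn = -(AA - 2 * AB + BB)  # negated squared L2 distance
attn = attn.mul(self.scale).softmax(-1)
```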
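
The snippet above can be wrapped in a self-contained multi-head module and checked against a direct distance computation. This is a minimal sketch, not the ViTGAN code. The class name `TiedL2Attention`, the single `to_qk` projection, and the `head_dim ** -0.5` scale are all illustrative assumptions.

```python
import torch
import torch.nn as nn


class TiedL2Attention(nn.Module):
    """Multi-head L2 attention with one projection shared by queries and keys (illustrative sketch)."""

    def __init__(self, dim: int, heads: int = 4):
        super().__init__()
        assert dim % heads == 0, "dim must be divisible by heads"
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5  # assumed 1/sqrt(head_dim); ViTGAN's value may differ
        self.to_qk = nn.Linear(dim, dim, bias=False)  # tied: one projection for both q and k
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.proj = nn.Linear(dim, dim)

    def split_heads(self, t: torch.Tensor) -> torch.Tensor:
        # (batch, tokens, dim) -> (batch, heads, tokens, head_dim)
        b, n, _ = t.shape
        return t.view(b, n, self.heads, self.head_dim).transpose(1, 2)

    def attention_weights(self, x: torch.Tensor) -> torch.Tensor:
        qk = self.split_heads(self.to_qk(x))
        AB = torch.matmul(qk, qk.transpose(-1, -2))  # <q_i, k_j>
        AA = torch.sum(qk ** 2, -1, keepdim=True)    # ||q_i||^2, summed over features
        BB = AA.transpose(-1, -2)                    # ||k_j||^2 == ||q_j||^2 because q and k are tied
        attn = -(AA - 2 * AB + BB)                   # negated squared L2 distance
        return attn.mul(self.scale).softmax(-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        attn = self.attention_weights(x)             # (batch, heads, tokens, tokens)
        v = self.split_heads(self.to_v(x))
        out = torch.matmul(attn, v)                  # (batch, heads, tokens, head_dim)
        out = out.transpose(1, 2).reshape(b, n, d)   # merge heads back into dim
        return self.proj(out)


torch.manual_seed(0)
layer = TiedL2Attention(dim=32, heads=4)
x = torch.randn(2, 10, 32)
out = layer(x)
weights = layer.attention_weights(x)

# Cross-check the expanded formula against torch.cdist on the same tied projections.
qk = layer.split_heads(layer.to_qk(x))                 # (2, 4, 10, 8)
flat = qk.reshape(-1, 10, layer.head_dim)              # fold batch and heads for cdist
expected = torch.softmax(-(torch.cdist(flat, flat) ** 2) * layer.scale, dim=-1)
expected = expected.reshape(weights.shape)
print(out.shape)                                       # torch.Size([2, 10, 32])
print(torch.allclose(weights, expected, atol=1e-5))    # True
```

The `AA - 2 * AB + BB` expansion keeps the whole computation as batched matrix multiplies. Here `torch.cdist` serves only as an independent check of that expansion.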

PeterL1n changed the title from "Bug in tying QK for L2 attention" to "Wrong implementation of l2 attention" on Jun 13, 2023.