BYOL loss function #1058

Open
vishnu-dev opened this issue Jul 16, 2023 · 1 comment
Labels: bug (Something isn't working), help wanted (Extra attention is needed)

Comments

vishnu-dev (Contributor) commented Jul 16, 2023

🐛 Bug

The BYOL loss computed in `calculate_loss` doesn't seem to match the one defined in the BYOL paper.

To Reproduce

Run the BYOL training.

Code sample

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def calculate_loss(self, v_online: Tensor, v_target: Tensor) -> Tensor:
    """Calculates the similarity loss between the online network's prediction and the target network's projection.

    Args:
        v_online (Tensor): Online network view
        v_target (Tensor): Target network view
    """
    # Online branch: projection z1, then prediction h1
    _, z1 = self.online_network(v_online)
    h1 = self.predictor(z1)
    # Target branch: projection z2, computed without gradients
    with torch.no_grad():
        _, z2 = self.target_network(v_target)
    return -2 * F.cosine_similarity(h1, z2).mean()
```
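For reference, the paper defines the loss as the mean squared error between the L2-normalized online prediction and the L2-normalized target projection, which for unit vectors works out to `2 - 2 * cos`. The sketch below writes that formulation as a standalone function. It is a minimal illustration under stated assumptions, not the library's code: the name `byol_regression_loss` is hypothetical, inputs are assumed to be `(batch, dim)` tensors, and the target is assumed to be computed without gradients, as in `calculate_loss` above.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def byol_regression_loss(h: Tensor, z: Tensor) -> Tensor:
    """Hypothetical helper: paper-style BYOL loss, averaged over the batch.

    h: online prediction, shape (batch, dim)
    z: target projection, shape (batch, dim), assumed computed under no_grad
    """
    h = F.normalize(h, dim=-1)  # unit-normalize each prediction vector
    z = F.normalize(z, dim=-1)  # unit-normalize each target vector
    # Squared L2 distance of unit vectors: ||h - z||^2 = 2 - 2 * <h, z>
    return (h - z).pow(2).sum(dim=-1).mean()


# Usage with random stand-in tensors
torch.manual_seed(0)
h1 = torch.randn(4, 8)
z2 = torch.randn(4, 8)
loss = byol_regression_loss(h1, z2)
# Agrees with 2 - 2 * cosine similarity up to floating-point error
print(torch.allclose(loss, 2 - 2 * F.cosine_similarity(h1, z2).mean()))
```

Written this way, the per-sample value is bounded in [0, 4] and is 0 only when prediction and target point in the same direction.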

Expected behavior

The loss should be `2 - 2 * F.cosine_similarity(h1, z2).mean()`, as defined in the paper?

[Screenshot attached: 2023-07-16 at 3:22:50 PM]
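One detail worth checking when evaluating the fix: the current expression and the expected one differ only by the constant 2, so they should yield identical gradients and leave parameter updates unaffected. What changes is the reported loss value, which currently falls in [-2, 2] rather than [0, 4]. The sketch below checks this with random stand-in tensors; `h1` and `z2` here are illustrative placeholders, not outputs of the actual networks.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Random stand-ins for the online prediction h1 and the target projection z2
h1 = torch.randn(8, 16, requires_grad=True)
z2 = torch.randn(8, 16)

# Loss as currently written in calculate_loss
current = -2 * F.cosine_similarity(h1, z2).mean()
# Loss as proposed in the issue, matching the paper
expected = 2 - 2 * F.cosine_similarity(h1, z2).mean()

(grad_current,) = torch.autograd.grad(current, h1)
(grad_expected,) = torch.autograd.grad(expected, h1)

# The values differ by the constant offset 2 ...
print(torch.allclose(expected - current, torch.tensor(2.0)))
# ... so the gradients with respect to h1 are identical
print(torch.allclose(grad_current, grad_expected))
```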
vishnu-dev added the bug (Something isn't working) and help wanted (Extra attention is needed) labels Jul 16, 2023
Borda (Member) commented Aug 31, 2023

Thank you, @vishnu-dev! Mind sending a PR with the fix? 🐰

Development

No branches or pull requests

2 participants