
Where can I obtain the validation reconstruction loss? #4

Open
Nianzhen-GU opened this issue Dec 6, 2021 · 1 comment

Comments

@Nianzhen-GU

Hi,

In the paper, I see there is a comparison between different models. Where can I find the corresponding code that calculates the validation reconstruction loss? Thank you!
[Screenshot: model comparison table from the paper]


BariscanBozkurt commented Dec 14, 2021

I am not an expert on this model, but here is my take. I think you can calculate it with the following line of code:

((g - x)**2).sum() / x.shape[0]

where x is the ground-truth batch and g is the output of the model (after feeding the negative gradient back in as the latent). Accordingly, I use the following piece of code to compute the summed squared error (SSE) and mean squared error (MSE) over the test set:

import torch

epoch_mse_loss_tst = 0.0  # running mean squared error
epoch_sse_loss_tst = 0.0  # running summed squared error
for j, (x, t) in enumerate(tst_loader):  # iterate over the test set
    x = x.to(device)

    # GON inference: start the latent at the origin and take one
    # gradient step on the inner reconstruction loss
    z = torch.zeros(batch_size, nz, 1, 1).to(device).requires_grad_()
    g = F(z)
    L_inner = ((g - x)**2).sum(1).mean()
    # create_graph/retain_graph are only needed when training through
    # this step; they are kept here to match the training code
    grad = torch.autograd.grad(L_inner, [z], create_graph=True, retain_graph=True)[0]
    z = -grad

    # decode from the negative gradient and score the reconstruction
    g = F(z)
    L_outer = ((g - x)**2).sum(1).mean()   # batch mean squared error
    sse = ((g - x)**2).sum() / x.shape[0]  # batch summed squared error (per sample)
    epoch_mse_loss_tst += L_outer.item()   # update epoch MSE loss
    epoch_sse_loss_tst += sse.item()       # update epoch SSE loss

    # zero any parameter grads left over from inference so they do not
    # leak into training; guard against grads that were never populated
    for w in F.parameters():
        if w.grad is not None:
            w.grad.zero_()

# average over the number of batches (enumerate starts at 0, so dividing
# by j would drop one batch)
epoch_sse_loss_tst /= len(tst_loader)
epoch_mse_loss_tst /= len(tst_loader)
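If you only need the evaluation numbers, the inference step can also be factored into a small helper. This is just a sketch based on the loop above (the name gon_reconstruct is mine, and F, nz, and device are assumed to be defined as in the training script); since nothing is backpropagated at evaluation time, create_graph can be left off here:

import torch

def gon_reconstruct(F, x, nz, device):
    # Sketch of GON-style inference: start the latent at the origin,
    # take one gradient step on the reconstruction loss, decode again
    z = torch.zeros(x.shape[0], nz, 1, 1, device=device, requires_grad=True)
    g = F(z)
    inner_loss = ((g - x)**2).sum(1).mean()
    grad = torch.autograd.grad(inner_loss, [z])[0]
    return F(-grad)

Note that the latent step itself needs gradients, so this must run with autograd enabled (i.e., not inside torch.no_grad()).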

I hope it is helpful.
