In the paper, I see there is a comparison between different models. Where can I find the corresponding code that calculates the validation reconstruction loss? Thank you!
I am not an expert on this model, but I'll share my own take. I think you can calculate it with the following line of code:
((g - x)**2).sum() / x.shape[0]
where x is the ground-truth batch and g is the output of the model (after feeding the negative gradient back in as the latent). I therefore use the following piece of code to calculate the test-set summed squared error and mean squared error:
epoch_mse_loss_tst = 0  # running mean-squared-error total, initially zero
epoch_sse_loss_tst = 0  # running summed-squared-error total, initially zero
for j, (x, t) in enumerate(tst_loader):  # run over the test set
    x = x.to(device)
    # use x.size(0) rather than batch_size so the last (possibly smaller) batch works
    z = torch.zeros(x.size(0), nz, 1, 1).to(device).requires_grad_()
    g = F(z)
    L_inner = ((g - x) ** 2).sum(1).mean()
    grad = torch.autograd.grad(L_inner, [z], create_graph=True, retain_graph=True)[0]
    z = -grad  # the negative gradient becomes the latent
    g = F(z)
    L_outer = ((g - x) ** 2).sum(1).mean()   # batch mean squared error
    sse = ((g - x) ** 2).sum() / x.shape[0]  # batch summed squared error (per sample)
    epoch_mse_loss_tst += L_outer.item()     # update epoch MSE loss
    epoch_sse_loss_tst += sse.item()         # update epoch SSE loss
    for w in F.parameters():  # zero any accumulated grads after inference (they do not contribute to training)
        if w.grad is not None:  # grads may be unpopulated here, so guard against None
            w.grad.data.zero_()
# divide by the number of batches, not the last index j (which is one less)
epoch_sse_loss_tst = epoch_sse_loss_tst / len(tst_loader)
epoch_mse_loss_tst = epoch_mse_loss_tst / len(tst_loader)
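To make the two normalisations in the snippet concrete, here is a minimal NumPy sketch (toy values I made up, not from the repo): the "SSE" divides the total squared error only by the batch size, while an element-wise MSE averages over every entry, so the two generally differ by a factor of the number of features per sample.

```python
import numpy as np

# Hypothetical ground truth x and reconstruction g: batch of 2 samples, 2 features each
x = np.array([[1.0, 2.0], [3.0, 4.0]])
g = np.array([[1.5, 2.0], [2.0, 4.0]])

sq_err = (g - x) ** 2            # element-wise squared error

# "SSE" as in the snippet: total squared error divided by the batch size
sse = sq_err.sum() / x.shape[0]  # (0.25 + 0.0 + 1.0 + 0.0) / 2 = 0.625

# element-wise MSE: average over every entry
mse = sq_err.mean()              # 1.25 / 4 = 0.3125

print(sse, mse)
```

So when comparing numbers against a paper, check which normalisation it reports; here the SSE is exactly `n_features` times the element-wise MSE.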