Skip to content

Commit

Permalink
Update test_linfa.py
Browse files Browse the repository at this point in the history
Update RCR example with new annealing
  • Loading branch information
ercobian authored Dec 29, 2023
1 parent 2329edb commit 7c6995f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions linfa/tests/test_linfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def rcr_nofas_adaann_example(self):
exp.surrogate.surrogate_load()

# Define log density
def log_density(x, model, surrogate, transform):
def log_density(x, model, surrogate, transform, t=1):

# Compute transformation log Jacobian
adjust = transform.compute_log_jacob_func(x)
Expand All @@ -708,13 +708,13 @@ def log_density(x, model, surrogate, transform):
for i in range(3):
ll3 += - 0.5 * torch.sum(((modelOut[:, i].unsqueeze(1) - Data[i, :].unsqueeze(0)) / stds[0, i]) ** 2, dim=1)
negLL = -(ll1 + ll2 + ll3)
res = -negLL.reshape(x.size(0), 1) + adjust
res = -t*negLL.reshape(x.size(0), 1) + adjust

# Return LL
return res

# Assign logdensity model
exp.model_logdensity = lambda x: log_density(x, model, exp.surrogate, trsf)
exp.model_logdensity = lambda x, t: log_density(x, model, exp.surrogate, trsf, t)

# Run VI
exp.run()
Expand Down

0 comments on commit 7c6995f

Please sign in to comment.