Skip to content

Commit

Permalink
Update optim file
Browse files Browse the repository at this point in the history
  • Loading branch information
chahak13 committed Aug 17, 2023
1 parent 54e21c9 commit a0ce1ce
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions optim_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,13 @@
def compute_loss(params, *, solver, target_vel, config):
# material = init_simple({"E": params, "density": 1, "id": -1})
material = init_linear_elastic(
{"youngs_modulus": params, "density": 1, "poisson_ratio": 0, "id": -1}
{
"youngs_modulus": params["ym"],
"density": 1,
"poisson_ratio": 0,
"id": -1,
}
)
# breakpoint()
particles_ = [
init_particle_state(
config.parsed_config["particles"][0].loc,
Expand Down Expand Up @@ -102,11 +106,11 @@ def compute_loss(params, *, solver, target_vel, config):

def optax_adam(params, niter, mpm, target_vel, config):
# Initialize parameters of the model + optimizer.
start_learning_rate = 1
start_learning_rate = 4
optimizer = optax.adam(start_learning_rate)
opt_state = optimizer.init(params)

param_list = []
param_list = {"ym": [], "pr": []}
loss_list = []
# A simple update loop.
t = tqdm(range(niter), desc=f"E: {params}")
Expand All @@ -115,16 +119,23 @@ def optax_adam(params, niter, mpm, target_vel, config):
lo, grads = jax.value_and_grad(partial_f, argnums=0)(params)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
t.set_description(f"YM: {params}")
param_list.append(params)
t.set_description(f"YM: {params['ym']:.2f}")
param_list["ym"].append(params["ym"])
# param_list["pr"].append(params["pr"])
loss_list.append(lo)
return param_list, loss_list


params = 900.5
# params = {"pr": 0.4}
params = {"ym": 1101.0}
# material = init_simple({"E": params, "density": 1, "id": -1})
material = init_linear_elastic(
{"youngs_modulus": params, "density": 1, "poisson_ratio": 0, "id": -1}
{
"youngs_modulus": params["ym"],
"density": 1,
"poisson_ratio": 0,
"id": -1,
}
)
particles = [
init_particle_state(
Expand All @@ -142,11 +153,11 @@ def optax_adam(params, niter, mpm, target_vel, config):
}
)
param_list, loss_list = optax_adam(
params, 100, solver, true_vel, config
params, 200, solver, true_vel, config
) # ADAM optimizer

fig, ax = plt.subplots(1, 2, figsize=(16, 6))
ax[0].plot(param_list, "ko", markersize=2, label="E")
ax[0].plot(param_list["ym"], "ko", markersize=2, label="E")
ax[0].grid()
ax[0].legend()
ax[1].plot(loss_list, "ko", markersize=2, label="Loss")
Expand Down

0 comments on commit a0ce1ce

Please sign in to comment.