-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough #562
base: master
Are you sure you want to change the base?
WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough #562
Conversation
…more reading on relaxed_categorical and transformations of distributions instead
…pansion where needed
Hi @daydreamt , thanks for the PR! I think the main blocker of your work would be to define custom derivative rules for some of your operators. I'll update the repo to the latest JAX version today to unblock your work. |
@daydreamt FYI, I think @tbsexton only needs RelaxedOneHotCategorical (or GumbelSoftmax) in his feature request because he wanted to use MCMC (instead of SVI) to draw samples from the relaxed distribution. @tbsexton could you confirm that StraightThrough is not required? |
@fehiepsi I was originally only using HMC, though planning to test out SVI as well. As long as not having access to a backward pass doesn't preclude using NUTS for inference of latent variables, should work! This is in practice a work-around for not having discrete latent variables; see my original example problem here. |
Thanks, @tbsexton! In your model, you want to infer each ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
x0 = ny.sample("x0", dist.Categorical(ϕ))
infectious, hist = spread_jax(s_ij, x0, 5) by ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
infectious, hist = spread_jax(s_ij, ϕ, 5) Or if you want the prior for ϕ = ny.sample("ϕ", dist.RelaxedOneHotCategorical(temporature, logits=np.ones(n_nodes))))
infectious, hist = spread_jax(s_ij, ϕ, 5) The reason is with If you want something like straight through, you can simply use
by defining "straight-through"
. You can use |
@fehiepsi much appreciated! I think I should update the model there to reflect som local changes, but primarily I think it makes more sense to pull the dirichlet out of the plates: def diff_kg(infections):
n_cascades, n_nodes = infections.shape
n_edges = n_nodes*(n_nodes-1)//2 # complete graph
# node initial infection, relative probability
ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
# beta hyperpriors
u = ny.sample("u", dist.Uniform(np.zeros(n_edges),
np.ones(n_edges)))
v = ny.sample("v", dist.Gamma(np.ones(n_edges),
20*np.ones(n_edges)))
Λ = ny.sample("Λ", dist.Beta(u*v, (1-u)*v))
s_ij = jax_squareform(Λ) # adjacency matrix to recover via inference
with ny.plate("n_cascades", n_cascades):
# infer infection source node
x0 = ny.sample("x0", dist.Categorical(ϕ))
# simulate ode and realize
infectious, hist = spread_jax(s_ij, x0, 5)
numpyro.sample("obs", dist.Bernoulli(probs=infectious),
obs=infections) The main idea being that certain nodes in general have a tendency to be "sources", represented by the dirichlet prior, and those manifest as conditional probabilities that each node was the source (given any individual observed infection cascade). That should be realized as one node for the Maybe that dirichlet prior is unnecessary partial pooling? I will definitely give the new relaxed categorical a try. @daydreamt would it be helpful if I tested things out before the PR gets merged? |
Agree that this makes more sense. With this model, you can define RelaxedOneHotCategorical for |
Hey @daydreamt , any progress on this? |
Hi @dirmeier, not really, please feel free to take over or supersede with another MR. |
Hi all, since it's been a while I thought I should maybe give a sign of life and continue from here. This tries to implement #559.
There are still some things I haven't figured out myself yet, so I was planning to only request the review when I'm more ready, but of course feel free to take a look if you want already.