WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough #562

Open
wants to merge 14 commits into
base: master
from

Conversation

daydreamt
Contributor

@daydreamt daydreamt commented Apr 6, 2020

Hi all, since it's been a while I thought I should maybe give a sign of life and continue from here. This tries to implement #559.

There are still some things I haven't figured out myself yet, so I was planning to only request the review when I'm more ready, but of course feel free to take a look if you want already.

  • write tests
    • sampling GumbelSoftmax with low temperatures
    • sampling GumbelSoftmax with high temperatures
    • more tests
  • make working GumbelSoftmaxProbs
    • pass test_log_prob_gradient
    • log_prob in general
    • log_prob with prepended shapes (i.e. the failing test_distribution_constraints test)
    • sampling in general
    • discretize at the forward pass, not the backward pass.
  • mean, variance?
  • documentation
    • distributions.rst
    • every test
    • every docstring
  • figure out consistent and proper interface
  • big cleanup before review
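
For the two sampling items in the checklist, here is a minimal sketch of what Gumbel-softmax sampling does at low and high temperatures. It is written in plain JAX and is not the code from this PR; `sample_gumbel_softmax` is a hypothetical helper name, and the logits and temperatures are arbitrary illustration values.

```python
import jax
import jax.numpy as jnp


def sample_gumbel_softmax(key, logits, temperature):
    """Draw one relaxed one-hot sample on the simplex (illustrative helper)."""
    # Gumbel(0, 1) noise, one value per category.
    gumbel_noise = jax.random.gumbel(key, logits.shape)
    # Softmax of the perturbed logits; the temperature sharpens or flattens it.
    return jax.nn.softmax((logits + gumbel_noise) / temperature, axis=-1)


logits = jnp.log(jnp.array([0.1, 0.6, 0.3]))
keys = jax.random.split(jax.random.PRNGKey(0), 1000)

# Low temperature: samples sit close to the vertices of the simplex.
cold = jax.vmap(lambda k: sample_gumbel_softmax(k, logits, 0.01))(keys)
# High temperature: samples sit close to the uniform vector.
hot = jax.vmap(lambda k: sample_gumbel_softmax(k, logits, 100.0))(keys)

print("mean largest coordinate, T=0.01:", cold.max(-1).mean())
print("mean largest coordinate, T=100: ", hot.max(-1).mean())
```

A test along these lines could check that every sample sums to one and that the largest coordinate approaches 1 as the temperature drops and 1/3 (for three categories) as it grows.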

@daydreamt daydreamt changed the title WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough #559 WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough https://github.com/pyro-ppl/numpyro/issues/559 Apr 6, 2020
@daydreamt daydreamt changed the title WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough https://github.com/pyro-ppl/numpyro/issues/559 WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough Apr 6, 2020
@fehiepsi
Member

fehiepsi commented Apr 7, 2020

Hi @daydreamt , thanks for the PR! I think the main blocker of your work would be to define custom derivative rules for some of your operators. I'll update the repo to the latest JAX version today to unblock your work.

@fehiepsi fehiepsi mentioned this pull request Apr 7, 2020
2 tasks
@fehiepsi
Member

fehiepsi commented Apr 8, 2020

@daydreamt FYI, I think @tbsexton only needs RelaxedOneHotCategorical (or GumbelSoftmax) in his feature request because he wanted to use MCMC (instead of SVI) to draw samples from the relaxed distribution. @tbsexton could you confirm that StraightThrough is not required?

@rtbs-dev

rtbs-dev commented Apr 8, 2020

@fehiepsi I was originally only using HMC, though planning to test out SVI as well. As long as not having access to a backward pass doesn't preclude using NUTS for inference of latent variables, should work!

This is in practice a work-around for not having discrete latent variables; see my original example problem here.

@fehiepsi
Member

Thanks, @tbsexton! In your model, you want to infer each ϕ for each cascade, so I guess you can replace

ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))  
x0 = ny.sample("x0", dist.Categorical(ϕ))
infectious, hist = spread_jax(s_ij, x0, 5)

by

ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))  
infectious, hist = spread_jax(s_ij, ϕ, 5)

Or, if you want the prior for ϕ to be closer to discrete, you can choose a suitable temperature (or put a prior on it) and use RelaxedOneHotCategorical

ϕ = ny.sample("ϕ", dist.RelaxedOneHotCategorical(temperature, logits=np.ones(n_nodes)))
infectious, hist = spread_jax(s_ij, ϕ, 5)

The reason is that the support of RelaxedOneHotCategorical is the simplex, and there is a transform that maps a simplex to an "unconstrained" value, which HMC/NUTS requires. The support of RelaxedOneHotCategoricalStraightThrough is discrete, so no such transform exists.
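
To make the simplex-to-unconstrained point concrete, here is a hand-written stick-breaking map in JAX. Stick-breaking is one common choice for this kind of transform; the function below is only an illustration under that assumption, not NumPyro's own code, and `stick_breaking` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def stick_breaking(y):
    """Map an unconstrained vector in R^(K-1) to a point on the K-simplex."""
    n = y.shape[-1]
    # Offset chosen so that y = 0 maps to the uniform vector.
    offset = jnp.log(jnp.arange(n, 0, -1))
    # Fraction of the remaining stick taken by each of the first K-1 entries.
    z = jax.nn.sigmoid(y - offset)
    # Stick length still available before each break.
    remainder = jnp.concatenate([jnp.ones(1), jnp.cumprod(1 - z)])
    # The last entry takes whatever is left.
    return jnp.concatenate([z, jnp.ones(1)]) * remainder


uniform = stick_breaking(jnp.zeros(2))
p = stick_breaking(jnp.array([2.0, -1.0]))
print(uniform, p, p.sum())
```

Because every real vector maps to a valid point on the simplex, a gradient-based sampler can move freely in the unconstrained space. A discrete support offers no smooth map of this kind.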

If you want something like straight through, you can simply use

ϕ = ny.sample("ϕ", dist.RelaxedOneHotCategorical(temperature, logits=np.ones(n_nodes)))
ϕ_quantize = quantize(ϕ)

by defining a "straight-through" quantize operator, as in Pyro:

def quantize(x):
    return x + jax.lax.stop_gradient((x == np.max(x, -1, keepdims=True)) - x)

You can use numpyro.deterministic(...) to record those quantized values. I am happy to add new helpers to NumPyro for your convenience when you start using SVI.
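
As a small check of what the quantize operator above does, the sketch below restates it with `jax.numpy` and an explicit dtype cast, then inspects the forward value and the gradient. The input `x` and the weights `w` are made-up illustration values.

```python
import jax
import jax.numpy as jnp


def quantize(x):
    # One-hot vector marking the largest entry along the last axis.
    one_hot = (x == jnp.max(x, -1, keepdims=True)).astype(x.dtype)
    # Forward value is one_hot; the gradient is that of the identity in x,
    # because the correction term is hidden from differentiation.
    return x + jax.lax.stop_gradient(one_hot - x)


x = jnp.array([0.2, 0.7, 0.1])
w = jnp.array([1.0, 2.0, 3.0])

forward = quantize(x)
grad = jax.grad(lambda v: jnp.sum(w * quantize(v)))(x)
print(forward)  # approximately one-hot at the position of 0.7
print(grad)     # equal to w: the gradient passes straight through
```

The forward pass sees a (numerically) one-hot vector, while the backward pass behaves as if no quantization had happened.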

@rtbs-dev

@fehiepsi much appreciated! I think I should update the model there to reflect some local changes, but primarily I think it makes more sense to pull the Dirichlet out of the plates:

def diff_kg(infections):
    n_cascades, n_nodes  = infections.shape
    n_edges = n_nodes*(n_nodes-1)//2 # complete graph
        
    # node initial infection, relative probability
    ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes))) 
    
    # beta hyperpriors
    u = ny.sample("u", dist.Uniform(np.zeros(n_edges), 
                                         np.ones(n_edges)))
    v = ny.sample("v", dist.Gamma(np.ones(n_edges),
                                       20*np.ones(n_edges)))
    Λ = ny.sample("Λ", dist.Beta(u*v, (1-u)*v))
    s_ij = jax_squareform(Λ)  # adjacency matrix to recover via inference
    
    with ny.plate("n_cascades", n_cascades):
        # infer infection source node
        x0 = ny.sample("x0", dist.Categorical(ϕ))
        # simulate ode and realize
        infectious, hist = spread_jax(s_ij, x0, 5)
        ny.sample("obs", dist.Bernoulli(probs=infectious),
                  obs=infections)

The main idea being that certain nodes in general have a tendency to be "sources", represented by the dirichlet prior, and those manifest as conditional probabilities that each node was the source (given any individual observed infection cascade). That should be realized as one node for the spread_jax sim, or at least, very close to one node (therefore the [relaxed]categorical).

Maybe that dirichlet prior is unnecessary partial pooling? I will definitely give the new relaxed categorical a try. @daydreamt would it be helpful if I tested things out before the PR gets merged?

@fehiepsi
Member

pull the dirichlet out of the plates

Agree that this makes more sense. With this model, you can define RelaxedOneHotCategorical for x0. (FYI, in PyTorch, Categorical samples are 0, 1, 2, 3. If you want OneHot version, you can use RelaxedOneHotCat... or OneHotCat...)
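
To illustrate the difference between integer and one-hot samples, here is a plain JAX sketch. It does not use the PyTorch or NumPyro distribution classes; the logits are arbitrary illustration values.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.1, 0.6, 0.3]))

# Integer-valued sample: a category index in {0, 1, 2}.
index = jax.random.categorical(key, logits)
# One-hot version of the same sample: a vertex of the simplex.
one_hot = jax.nn.one_hot(index, logits.shape[-1])

print(index, one_hot)
```

A relaxed one-hot sample lives in the same space as `one_hot`, just in the interior of the simplex instead of at a vertex.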

@fehiepsi fehiepsi added the WIP label Jul 17, 2020
@dirmeier
Contributor

dirmeier commented Jun 15, 2022

Hey @daydreamt , any progress on this?

@daydreamt
Contributor Author

Hi @dirmeier, not really, please feel free to take over or supersede with another MR.


4 participants