Explicitly pass the temperature parameter from (adaptive) tempered SMC to the MCMC kernel #495
I agree with the above. We have had several chats with @rlouf in the past on this specific point. Food for thought: the "only" way I see to do this is to abstract away the concept of the prior/loglikelihood into a form of

There is also the fact that the log-likelihood is not allowed to change in the definition of the elliptical slice sampler or the Gaussian samplers. Being able to combine both would likely require some refactoring on this end too.
An additional point to note: at the moment, at least the Gaussian sampler is parametrised in a way that reduces the need for computation:
Changing the tempering parameter from under it without updating the
Current behavior
Currently, the (adaptive) tempered SMC kernel samples (as desired) from `lmbda * loglikelihood + logprior`, where `loglikelihood` and `logprior` are log-densities provided by the user. The temperature `lmbda` is then increased from 0 to 1, so that the sampler eventually targets the posterior. This works as intended when combined with other Blackjax kernels that take a (posterior) logdensity as argument, such as RMH, HMC and NUTS.

However, other kernels, such as `elliptical_slice` or `mgrad_gaussian`, take only the loglikelihood as an argument, as here the prior is built into the model. The same typically applies to custom Gibbs kernels, especially if they describe a hierarchical model: we'd need to construct several target densities within the Gibbs kernel, some that require the temperature and others that do not. As a result, SMC cannot be combined with these kernels. (Technically it can, but this would not effectively temper the target distribution: we would always sample from the posterior and use the tempering only in reweighting the SMC particles. This hampers exploration, the key benefit of SMC.)

The relevant code is on lines 128 and 132 in https://github.com/blackjax-devs/blackjax/blob/main/blackjax/smc/tempered.py:
In both cases, the MCMC-within-SMC kernel is called with `tempered_logposterior_fn`, in which `lmbda` has already been incorporated.

Desired behavior
Since states are NamedTuples, which are immutable, but users have to define their own `mcmc_init_fn` anyway, `lmbda` could be passed as an argument to `mcmc_init_fn`, which leaves it up to the user to decide whether it is relevant to store this temperature in the MCMC state object. In our own use case, we'd take this temperature from the MCMC state and use it to determine something like:

I hope this enhancement is possible; it would greatly ease our way of working with Blackjax! :-)
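For illustration, the difference between the current and the requested behavior could be sketched roughly as follows. This is a minimal sketch under stated assumptions, not Blackjax code: `TemperedState`, `tempered_loglikelihood` and the two-argument signature of `mcmc_init_fn` are hypothetical, the toy Gaussian densities are made up, and plain Python floats stand in for JAX arrays to keep the example dependency-free.

```python
from typing import Callable, NamedTuple


def logprior(x: float) -> float:
    # Toy standard normal prior, up to an additive constant.
    return -0.5 * x**2


def loglikelihood(x: float) -> float:
    # Toy Gaussian likelihood centred at 3 with unit variance, up to a constant.
    return -0.5 * (x - 3.0) ** 2


def tempered_logposterior(lmbda: float) -> Callable[[float], float]:
    # Current behavior: lmbda is folded into one log-density, so a kernel
    # receiving this closure cannot separate the likelihood from the prior.
    return lambda x: lmbda * loglikelihood(x) + logprior(x)


class TemperedState(NamedTuple):
    # Hypothetical MCMC state that carries the temperature with the position.
    position: float
    lmbda: float


def mcmc_init_fn(position: float, lmbda: float) -> TemperedState:
    # Requested behavior (assumed signature): the SMC step passes lmbda to the
    # user-defined init function, which decides whether to store it.
    return TemperedState(position, lmbda)


def tempered_loglikelihood(state: TemperedState) -> Callable[[float], float]:
    # A likelihood-only kernel (prior built into the model) can now temper
    # its own target using the temperature stored in the state.
    return lambda x: state.lmbda * loglikelihood(x)


state = mcmc_init_fn(position=1.0, lmbda=0.25)

# Log-density as the SMC kernel currently builds it.
baked_in = tempered_logposterior(0.25)(state.position)

# The same value, rebuilt from the temperature stored in the state.
rebuilt = tempered_loglikelihood(state)(state.position) + logprior(state.position)
```

Both `baked_in` and `rebuilt` evaluate the same tempered target, but only the second form exposes `lmbda` separately from the prior, which is what a likelihood-only or Gibbs kernel would need.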