Explicitly pass the temperature parameter from (adaptive) tempered SMC to the MCMC kernel #495

Open
maxhinne opened this issue Feb 22, 2023 · 2 comments

@maxhinne

Current behavior

Currently, the (adaptive) tempered SMC kernel samples (as desired) from lmbda * loglikelihood + logprior, where loglikelihood and logprior are log-densities provided by the user. The temperature lmbda is then increased from 0 to 1 so that the particles eventually target the posterior. This works as intended when combined with other Blackjax kernels that take a (posterior) logdensity as argument, such as RMH, HMC and NUTS.
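For concreteness, here is a minimal sketch of the tempered target described above. It is plain Python on a scalar toy model rather than Blackjax code: the Gaussian logprior and loglikelihood and the helper make_tempered_logposterior are assumptions made up for illustration, and only the formula lmbda * loglikelihood + logprior comes from the description.

```python
import math


def logprior(x):
    # Toy standard-normal prior (an assumption for this example).
    return -0.5 * x**2 - 0.5 * math.log(2 * math.pi)


def loglikelihood(x, y=1.0):
    # Toy Gaussian likelihood of one observation y with unit variance.
    return -0.5 * (y - x) ** 2 - 0.5 * math.log(2 * math.pi)


def make_tempered_logposterior(lmbda):
    # Hypothetical helper: closes over lmbda, so a kernel that receives
    # the returned function never sees the temperature itself.
    def tempered_logposterior_fn(x):
        return logprior(x) + lmbda * loglikelihood(x)

    return tempered_logposterior_fn


at_prior = make_tempered_logposterior(0.0)      # lmbda = 0: the prior
at_posterior = make_tempered_logposterior(1.0)  # lmbda = 1: the posterior
```

Because lmbda only lives inside the closure, a kernel that wants the likelihood and the temperature separately has no way to recover them from tempered_logposterior_fn.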

However, other kernels, such as elliptical_slice or mgrad_gaussian, take only the loglikelihood as an argument, because here the prior is built into the model. The same typically applies to custom Gibbs kernels, especially if they describe a hierarchical model: we'd need to construct several target densities within the Gibbs kernel, some that require the temperature and others that do not. As a result, SMC cannot be meaningfully combined with these kernels. Strictly speaking it can, but the target distribution would not effectively be tempered: we would always sample from the posterior and use the tempering only when reweighting the SMC particles. This hampers exploration, the key benefit of SMC.

The relevant code is on lines 128 and 132 in https://github.com/blackjax-devs/blackjax/blob/main/blackjax/smc/tempered.py:

state = mcmc_init_fn(position, tempered_logposterior_fn)

and

new_state, info = mcmc_step_fn(
    rng_key, state, tempered_logposterior_fn, **mcmc_parameters
)

In both cases, the MCMC-within-SMC kernel is called with the tempered_logposterior_fn, in which lmbda has already been incorporated.

Desired behavior

Since states are NamedTuples, which are immutable, but users have to define their own mcmc_init_fn anyway, lmbda could be passed as an argument to mcmc_init_fn. This leaves it up to the user to decide whether to store the temperature in the MCMC state object. In our own use case, we'd take the temperature from the MCMC state and use it to compute something like:

tempered_loglikelihood = lambda state: state.lmbda*loglikelihood_fn(state)
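A slightly fuller sketch of this idea, under stated assumptions: TemperedState, its fields, and the two-argument mcmc_init_fn signature are hypothetical and do not exist in Blackjax today, and the toy loglikelihood_fn takes a position rather than a whole state.

```python
from typing import NamedTuple


class TemperedState(NamedTuple):
    # Hypothetical user-defined MCMC state that also stores the temperature.
    position: float
    lmbda: float


def mcmc_init_fn(position, lmbda):
    # Proposed signature (an assumption): the SMC kernel would pass lmbda
    # here and the user decides whether to keep it in the state.
    return TemperedState(position=position, lmbda=lmbda)


def loglikelihood_fn(position, y=1.0):
    # Toy Gaussian log-likelihood, up to an additive constant.
    return -0.5 * (y - position) ** 2


def tempered_loglikelihood(state):
    # The user's kernel reads the temperature back from its own state.
    return state.lmbda * loglikelihood_fn(state.position)


state = mcmc_init_fn(position=0.0, lmbda=0.25)
value = tempered_loglikelihood(state)
```

The state stays immutable: moving to a new temperature means building a new state, for example with state._replace(lmbda=0.5) or by calling mcmc_init_fn again.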

I hope this enhancement is possible; it would greatly ease our way of working with Blackjax! :-)

@AdrienCorenflos
Contributor

I agree with the above. We have had several chats with @rlouf in the past on this specific point.

Food for thought:

The "only" way I see to do this is to abstract away the concept of the prior/loglikelihood into a form of mcmc_factory which would be user defined. It is however not clear how to do this in a clean way, given that the log-likelihood is needed to compute the importance weights at each tempering step.
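To illustrate why the untempered log-likelihood has to stay accessible to the SMC step, here is a minimal sketch of the standard incremental weights used when moving from lmbda_old to lmbda_new, where each particle's log-weight is (lmbda_new - lmbda_old) * loglikelihood(x). The toy Gaussian loglikelihood and the helper name tempering_weights are assumptions for this example, not Blackjax functions.

```python
import math


def loglikelihood(x, y=1.0):
    # Toy Gaussian log-likelihood, up to an additive constant.
    return -0.5 * (y - x) ** 2


def tempering_weights(particles, lmbda_old, lmbda_new):
    # Incremental log-weights: these need the raw log-likelihood,
    # not a function in which lmbda has already been folded in.
    log_w = [(lmbda_new - lmbda_old) * loglikelihood(x) for x in particles]
    # Subtract the maximum before exponentiating for numerical stability.
    max_log_w = max(log_w)
    unnormalised = [math.exp(lw - max_log_w) for lw in log_w]
    total = sum(unnormalised)
    return [w / total for w in unnormalised]


weights = tempering_weights([-1.0, 0.0, 1.0], lmbda_old=0.2, lmbda_new=0.5)
```

Any user-defined factory would therefore have to expose the log-likelihood to the SMC step in addition to whatever it hands to the MCMC kernel.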

There is also the fact that the log-likelihood is not allowed to change in the definition of the elliptical slice sampler or the Gaussian samplers. Being able to combine both would likely require some refactoring on this end too.

@AdrienCorenflos
Contributor

AdrienCorenflos commented Feb 23, 2023

An additional point to note: at the moment, at least the Gaussian sampler is parametrised in a way that reduces the need for computation:

Changing the tempering parameter from under it without updating the U_grad_x of the state would result in an invalid algorithm. In practice, it is easy to do, but we can't expect the user (even an expert one!) to know low-level implementation details to this extent. I'm sure there are other examples of this in the library, where changing the target is very non-obvious.
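A toy illustration of this failure mode, under loud assumptions: CachedGradState and its cached_grad field are hypothetical stand-ins for a state that stores a precomputed gradient (playing a role analogous to U_grad_x), and this is not the actual layout of any Blackjax state.

```python
from typing import NamedTuple

Y_OBS = 1.0  # toy observation


def grad_loglikelihood(x):
    # Gradient of the toy log-likelihood -0.5 * (Y_OBS - x) ** 2.
    return Y_OBS - x


class CachedGradState(NamedTuple):
    # Hypothetical state that caches the tempered gradient at `position`.
    position: float
    lmbda: float
    cached_grad: float


def init(position, lmbda):
    # Computes the cache consistently with the given temperature.
    return CachedGradState(position, lmbda, lmbda * grad_loglikelihood(position))


state = init(position=0.0, lmbda=0.5)

# Changing only the temperature leaves the cached gradient stale.
stale = state._replace(lmbda=1.0)

# Re-initialising at the new temperature keeps the cache consistent.
fresh = init(state.position, 1.0)
```

Here stale still carries the gradient computed at lmbda = 0.5, so a kernel that trusts the cache would silently target the wrong distribution; fresh avoids this at the cost of one extra gradient evaluation.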
