Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for issue # 813 modified the way kwargs are handled #863

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

tanishy7777
Copy link
Contributor

Solves issue #813

  • Make sure all test pass.
  • Make sure your code passes black.
  • Make sure your code passes pylint.

@tomicapretto
Copy link
Collaborator

Hi @tanishy7777 thanks for working on this, and sorry for the delay in the response. I think we should do one of the following

  • Omit a warning saying that some kwargs are not used by the JAX based sampler (as @GStechschulte mentions in the issue)
  • Do nothing, which is the current behavior.

I'm not sure if the proposed changes are what we want. If I understand correctly, it will set up kwargs["num_steps"] to draws when it's not None, but would leave it unset when it's None (using whatever default value applies for that sampler in the library that implements it).

What if we have the following example:

kwargs = {
        "num_draws": 500,
}

blackjax_nuts_idata = model.fit(draws=250, inference_method="blackjax_nuts", **kwargs)

The user is passing through kwargs["num_draws"] they want 500 draws but we then replace that with 250, which is the one passed to draws=250.

This example makes me think we should just omit a warning in that case, and respect the name required by the underlying sampler.


A separate comment, it's always good to add a test when one implements a change in behavior like this one. Just let me/us know if you would need help with that.

@GStechschulte
Copy link
Collaborator

GStechschulte commented Dec 16, 2024

Thanks @tomicapretto for catching this.

bayeux uses the argument names (often num_chains for example) from the underlying libraries (blackjax, numpyro, tfp, etc.). However, if a library does not use that argument name, bayeux will silently ignore unused kwargs. Therefore, we may need to add logic according to the library to ensure the correct names are passed down.

Below is a comment and code snippet from Colin

The only samplers are from blackjax, numpyro, nutpie, tfp, and flowMC. bayeux uses the argument names from underlying libraries

for method in model.mcmc.methods:
  sampler = getattr(model.mcmc, method)
  print(method)
  for v in ('chains', 'num_chains', 'num_draws', 'num_samples', 'num_results'):
    if any(k == v for step in sampler.get_kwargs().values() for k in step):
      print('\t', v)
tfp_hmc
	 num_chains
	 num_draws
tfp_nuts
	 num_chains
	 num_draws
tfp_snaper_hmc
	 num_chains
	 num_results
blackjax_hmc
	 num_chains
	 num_draws
blackjax_chees_hmc
	 num_chains
	 num_draws
blackjax_meads_hmc
	 num_chains
	 num_draws
blackjax_nuts
	 num_chains
	 num_draws
blackjax_hmc_pathfinder
	 num_chains
	 num_draws
blackjax_nuts_pathfinder
	 num_chains
	 num_draws
numpyro_hmc
	 num_chains
	 num_samples
numpyro_nuts
	 num_chains
	 num_samples

As an aside, we should make sure changes here are compatible with #855. I do not think there is any major conflict, but good to double check.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants