Hi there, great initial work wrapping numpyro and pyro into a more user-friendly interface! I'm having an issue with a few simple models where the numpyro backend gives me the following error: Cannot find valid initial parameters. Please check your model again.
It seems to occur when there are many predictors in the formula.
Pyro sampling and SVI both work fine for this model with the default Cauchy beta priors. Any thoughts on better initializing numpyro NUTS with SVI or perhaps using maximum a posteriori estimates?
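To make the MAP suggestion concrete, here is a NumPy-only sketch that finds the posterior mode of a plain logistic regression with independent Cauchy priors on the coefficients, using gradient ascent with a step size of 1/L. It is an illustration under stated assumptions, not brmp's model. It ignores the `(1 | DID)` group-level term, and the prior scale of 2.5 is a hypothetical choice rather than brmp's default. The traceback shows that NumPyro's `MCMC.run` accepts an `init_params` argument, so a mode like this could in principle serve as a starting point. However, it would first have to be keyed by brmp's internal sample-site names, which are not shown here, and brmp's `nuts` wrapper calls `mcmc.run(rng, **data)` without passing `init_params` through.

```python
import numpy as np

def log_posterior(beta, X, y, scale):
    """Logistic log-likelihood plus independent Cauchy(0, scale) log-priors (up to a constant)."""
    eta = X @ beta
    loglik = np.sum(y * eta - np.logaddexp(0.0, eta))  # stable log(1 + exp(eta))
    logprior = np.sum(-np.log1p((beta / scale) ** 2))
    return loglik + logprior

def grad_log_posterior(beta, X, y, scale):
    """Gradient of log_posterior with respect to beta."""
    p = 0.5 * (1.0 + np.tanh(0.5 * (X @ beta)))  # numerically stable sigmoid
    return X.T @ (y - p) - 2.0 * beta / (scale ** 2 + beta ** 2)

def map_estimate(X, y, scale=2.5, num_steps=5000, tol=1e-8):
    """Gradient ascent from zero. `scale` is a hypothetical prior scale, not brmp's default."""
    beta = np.zeros(X.shape[1])
    # Upper bound on the curvature: 1/4 * largest eigenvalue of X^T X for the
    # likelihood, plus 2 / scale^2 for the Cauchy prior.
    lipschitz = 0.25 * np.linalg.norm(X, 2) ** 2 + 2.0 / scale ** 2
    step = 1.0 / lipschitz
    for _ in range(num_steps):
        g = grad_log_posterior(beta, X, y, scale)
        if np.linalg.norm(g) < tol:
            break
        beta = beta + step * g
    return beta
```

Plain gradient ascent is used only to keep the sketch dependency-free. In practice an optimizer (or SVI with a point-mass guide, as the question suggests) would serve the same purpose.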
It's tough to figure out what exactly is causing MCMC to immediately fail but I'm assuming it's the initial starting values. Full traceback:
```
RuntimeError Traceback (most recent call last)
<ipython-input-185-9da937b2eeba> in <module>
----> 1 fit = model.fit(backend=numpyro, seed=8877, iter=1000, warmup=500)
/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in fit(self, algo, **kwargs)
173 """
174 assert algo in ['prior', 'nuts', 'svi']
--> 175 return getattr(self, algo)(**kwargs)
176
177 def nuts(self, iter=10, warmup=None, num_chains=1, seed=None, backend=numpyro_backend):
/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in nuts(self, iter, warmup, num_chains, seed, backend)
200 """
201 warmup = iter // 2 if warmup is None else warmup
--> 202 return self.run_algo('nuts', backend, iter, warmup, num_chains, seed)
203
204 def svi(self, iter=10, num_samples=10, seed=None, backend=pyro_backend, **kwargs):
/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in run_algo(self, name, backend, df, *args, **kwargs)
154 data = self.model.encode(df) if df is not None else self.data
155 assets_wrapper = self.model.gen(backend)
--> 156 return assets_wrapper.run_algo(name, data_from_numpy(backend, data), *args, **kwargs)
157
158 def fit(self, algo='nuts', **kwargs):
/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in run_algo(self, name, data, *args, **kwargs)
75
76 def run_algo(self, name, data, *args, **kwargs):
---> 77 samples = getattr(self.backend, name)(data, self.assets, *args, **kwargs)
78 return Fit(self.model.formula, self.model.metadata,
79 self.model.contrasts, data,
/opt/conda/lib/python3.6/site-packages/brmp/numpyro_backend.py in nuts(data, assets, iter, warmup, num_chains, seed)
86 # `num_chains` > 1 to achieve parallel chains.
87 mcmc = MCMC(kernel, warmup, iter, num_chains=num_chains)
---> 88 mcmc.run(rng, **data)
89 samples = mcmc.get_samples(group_by_chain=True)
90
/opt/conda/lib/python3.6/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, collect_warmup, init_params, *args, **kwargs)
639 if self.num_chains == 1:
640 states_flat = self._single_chain_mcmc((rng_key, init_params), collect_fields, collect_warmup,
--> 641 args, kwargs)
642 states = tree_map(lambda x: x[np.newaxis, ...], states_flat)
643 else:
/opt/conda/lib/python3.6/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, collect_fields, collect_warmup, args, kwargs)
582 rng_key, init_params = init
583 init_state, constrain_fn = self.sampler.init(rng_key, self.num_warmup, init_params,
--> 584 model_args=args, model_kwargs=kwargs)
585 if self.constrain_fn is None:
586 constrain_fn = identity if constrain_fn is None else constrain_fn
/opt/conda/lib/python3.6/site-packages/numpyro/infer/mcmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
409 rng_key, rng_key_init_model = np.swapaxes(vmap(random.split)(rng_key), 0, 1)
410 init_params_, self.potential_fn, constrain_fn = initialize_model(
--> 411 rng_key_init_model, self.model, *model_args, init_strategy=self.init_strategy, **model_kwargs)
412 if init_params is None:
413 init_params = init_params_
/opt/conda/lib/python3.6/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, *model_args, **model_kwargs)
413 if not_jax_tracer(is_valid):
414 if device_get(~np.all(is_valid)):
--> 415 raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
416 return init_params, potential_fn, constrain_fun
417
RuntimeError: Cannot find valid initial parameters. Please check your model again.
```
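To probe the guess that the starting values are to blame, here is a NumPy-only sketch of the kind of test behind this error: NumPyro's initializer rejects a starting point whose log density (or its gradient) is not finite. Only the log density is checked here. The sketch assumes, without having checked brmp's generated code, that the Bernoulli likelihood is evaluated from probabilities in float32 (JAX's default precision). It draws coefficients uniformly in (-2, 2), assumed to match NumPyro's default uniform initializer. With large, uncentered predictor values the linear predictor becomes big enough that the probability rounds to exactly 0 or 1 and its log is -inf, whereas the logits-based form of the same likelihood stays finite. All function names are hypothetical.

```python
import numpy as np

def loglik_via_probs(eta, y):
    """Bernoulli log-likelihood computed from probabilities in float32.

    For large |eta|, exp() overflows or 1 + exp(-eta) rounds to 1, so the
    probability becomes exactly 0 or 1 and a log(0) = -inf appears.
    """
    eta = np.asarray(eta, dtype=np.float32)
    with np.errstate(over="ignore", divide="ignore"):
        p = 1.0 / (1.0 + np.exp(-eta))
        return float(np.sum(np.where(y == 1, np.log(p), np.log1p(-p))))

def loglik_via_logits(eta, y):
    """The same quantity written with logaddexp, which stays finite."""
    eta = np.asarray(eta, dtype=np.float32)
    return float(np.sum(np.where(y == 1, -np.logaddexp(0, -eta), -np.logaddexp(0, eta))))

def random_init_is_valid(X, y, loglik, radius=2.0, seed=0):
    """Draw coefficients uniformly in (-radius, radius); report whether the log-likelihood is finite."""
    rng = np.random.default_rng(seed)
    beta = rng.uniform(-radius, radius, size=X.shape[1])
    return bool(np.isfinite(loglik(X @ beta, y)))
```

With predictors scaled to roughly unit size the random draw is always accepted. With a column of values around 1000, almost every draw is rejected by the probability-based version and accepted by the logits-based one.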
Thanks for reporting this. I suspect you're correct, and that we may need to think more about initialization strategies at some point. If possible, could you share a simple example that reproduces the problem, so we can check that this isn't some other bug? Thanks.
I realized part of the problem here is that brms automatically centers the predictors, whereas here we need to either specify more appropriate priors or standardize the data before fitting with numpyro. Nevertheless, the pyro backend (both NUTS and SVI) does work with the uncentered columns in the example below. Agreed that it will be good to think through some initialization strategies.
```python
import numpy as np
import pandas as pd

import brmp
from brmp import brm
from brmp.numpyro_backend import backend as numpyro
from brmp.pyro_backend import backend as pyro

df = pd.read_csv('https://stats.idre.ucla.edu/stat/data/hdp.csv')
# Replace each non-numeric column with integer codes.
df = df.apply(lambda x: x if np.issubdtype(x.dtype, np.number) else pd.factorize(x)[0])
df['remission'] = df['remission'].astype(int)
df['DID'] = df['DID'].astype('category')  # grouping factor for the random intercept

model = brm('remission ~ IL6 + CRP + CancerStage + LengthofStay + Experience + FamilyHx'
            ' + SmokingHx + Sex + WBC + BMI + (1 | DID)',
            df=df, family=brmp.family.Binomial(num_trials=1))
fit = model.fit(backend=numpyro, seed=8877, iter=1000, warmup=500)  # fails
fit = model.fit(backend=pyro, seed=8877, iter=1000, warmup=500)  # works
fit = model.fit(algo='svi', seed=8877, iter=10000, num_samples=1000)  # works
```
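Following the suggestion to standardize the data before fitting with numpyro, here is a minimal pandas sketch that z-scores the numeric predictors. It centers them, as brms does, and also scales them to unit standard deviation. It leaves the response and any categorical columns such as `DID` untouched. The helper name `standardize_predictors` is hypothetical, and whether standardizing alone is enough for NUTS to initialize on this model has not been tested here.

```python
import pandas as pd

def standardize_predictors(df, exclude=()):
    """Return a copy of df with numeric columns (except those in `exclude`) centered and scaled.

    Non-numeric columns (e.g. pandas 'category' dtype) are left as they are,
    as are constant columns, which cannot be scaled.
    """
    out = df.copy()
    for col in out.columns:
        if col in exclude or not pd.api.types.is_numeric_dtype(out[col]):
            continue
        sd = out[col].std()
        if sd > 0:
            out[col] = (out[col] - out[col].mean()) / sd
    return out
```

For the example above this would be called as `standardize_predictors(df, exclude=['remission'])` before passing the result to `brm`. The integer codes produced by `pd.factorize` are numeric too, so columns such as `Sex` or `FamilyHx` would also be rescaled unless they are added to `exclude`.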