@null-a - I was looking into this. There is currently one issue: the device must be set before JAX is loaded, so doing this on a per-model basis doesn't seem possible at the moment. Otherwise, it is as simple as calling `numpyro.set_platform('gpu')` before running any model code. Please try that and see whether it makes any of your model computations faster. For us, a slow sparse regression problem ran about 10x faster on the GPU.
I think the right interface will involve using the currently experimental JAX feature (jax-ml/jax#1598) in NumPyro so that we can jit by device type. We can prioritize this for the next minor release.