Add GPU support #14

Open
null-a opened this issue Sep 30, 2019 · 2 comments
Comments

@null-a
Collaborator

null-a commented Sep 30, 2019

No description provided.

@null-a
Collaborator Author

null-a commented Oct 11, 2019

For reference, the NumPyro backend is currently hardcoded to use the CPU here.

@neerajprad
Member

neerajprad commented Dec 3, 2019

@null-a - I was looking at this. There is currently one issue: the device has to be set before JAX is loaded, so doing this on a per-model basis doesn't seem possible at the moment. Otherwise, it is as simple as calling numpyro.set_platform('gpu') before running any model code. Please try that and see whether it makes any of your model computations faster. For us, it gave a 10x speedup on the GPU for a slow sparse regression problem.

I think the right interface will involve using the currently experimental JAX feature (jax-ml/jax#1598) in NumPyro, so that code can be jitted for a specific device type. We can prioritize this for the upcoming minor release.
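The experimental feature referenced above is not shown here. As a different, swapped-in technique, recent JAX releases let you steer where a computation runs by placing its inputs on a chosen device with `jax.device_put`. A jitted function then executes on the device that holds its committed inputs. This is a sketch assuming a recent JAX version; `pick_device` and `log_density` are hypothetical helpers, not NumPyro APIs.

```python
import jax
import jax.numpy as jnp


def pick_device(kind):
    """Return the first device of the given kind ("cpu" or "gpu"), or None."""
    try:
        devices = jax.devices(kind)
    except RuntimeError:
        # jax.devices raises RuntimeError when no backend of that kind exists.
        return None
    return devices[0] if devices else None


@jax.jit
def log_density(w, x, y):
    # Gaussian log-likelihood (up to a constant) of a linear model y ~ x @ w.
    resid = y - x @ w
    return -0.5 * jnp.sum(resid ** 2)


# Prefer a GPU if one is available, otherwise fall back to the CPU.
device = pick_device("gpu") or pick_device("cpu")

# Committing the inputs to `device` makes the jitted call run there.
x = jax.device_put(jnp.ones((100, 3)), device)
y = jax.device_put(jnp.zeros(100), device)
w = jax.device_put(jnp.array([0.1, 0.2, 0.3]), device)

result = log_density(w, x, y)
print(device, float(result))
```

Unlike `numpyro.set_platform`, this choice is made per computation, but it only controls array placement and does not change NumPyro's own default backend.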
