-
I did play around with this a bit at one point, but I'm not aware of any worked examples. I'd recommend taking a look at JAX's new foreign function interface (FFI). I expect the easiest approach would be to write a C++ shim that uses the FFI to dispatch to Rust, rather than trying to reimplement the FFI in Rust directly. Using variadic arguments and results, you could probably write a fairly general-purpose interface. Hope this helps!
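A minimal sketch of what the Rust side of that approach could look like. It assumes the C++ shim (not shown) has already unpacked the FFI buffers into raw pointers and forwards them across a C ABI, together with an operation ID, as a variable number of inputs and outputs. Every name here (`rust_dispatch`, `OP_ADD`, `OP_DOUBLE_AND_NEGATE`, the status codes) is hypothetical and not part of JAX, XLA, or any crate. A real build would compile this as a `cdylib` and mark the function `#[no_mangle]` (`#[unsafe(no_mangle)]` in edition 2024) so the shim can link against it. The attribute is left out so the sketch compiles unchanged on any edition.

```rust
use std::slice;

// Hypothetical status codes returned across the C ABI.
pub const OK: i32 = 0;
pub const ERR_BAD_ARITY: i32 = 1;
pub const ERR_UNKNOWN_OP: i32 = 2;
pub const ERR_NULL_POINTER: i32 = 3;

// Hypothetical operation IDs the C++ shim would forward (e.g. from an attribute).
pub const OP_ADD: u32 = 0; // 2 inputs -> 1 output
pub const OP_DOUBLE_AND_NEGATE: u32 = 1; // 1 input -> 2 outputs

/// Generic entry point: `n_inputs` input buffers and `n_outputs` output
/// buffers, each holding `len` contiguous f32 values.
///
/// # Safety
/// Every buffer pointer must be valid for `len` elements, and output
/// buffers must not alias inputs or each other.
pub unsafe extern "C" fn rust_dispatch(
    op: u32,
    inputs: *const *const f32,
    n_inputs: usize,
    outputs: *const *mut f32,
    n_outputs: usize,
    len: usize,
) -> i32 {
    if inputs.is_null() || outputs.is_null() {
        return ERR_NULL_POINTER;
    }
    // View the pointer arrays as slices of buffer pointers.
    let in_ptrs = unsafe { slice::from_raw_parts(inputs, n_inputs) };
    let out_ptrs = unsafe { slice::from_raw_parts(outputs, n_outputs) };
    if in_ptrs.iter().any(|p| p.is_null()) || out_ptrs.iter().any(|p| p.is_null()) {
        return ERR_NULL_POINTER;
    }

    match op {
        OP_ADD => {
            if in_ptrs.len() != 2 || out_ptrs.len() != 1 {
                return ERR_BAD_ARITY;
            }
            let (a, b, out) = unsafe {
                (
                    slice::from_raw_parts(in_ptrs[0], len),
                    slice::from_raw_parts(in_ptrs[1], len),
                    slice::from_raw_parts_mut(out_ptrs[0], len),
                )
            };
            // Elementwise sum: out = a + b.
            for ((o, &x), &y) in out.iter_mut().zip(a).zip(b) {
                *o = x + y;
            }
            OK
        }
        OP_DOUBLE_AND_NEGATE => {
            if in_ptrs.len() != 1 || out_ptrs.len() != 2 {
                return ERR_BAD_ARITY;
            }
            let (x, doubled, negated) = unsafe {
                (
                    slice::from_raw_parts(in_ptrs[0], len),
                    slice::from_raw_parts_mut(out_ptrs[0], len),
                    slice::from_raw_parts_mut(out_ptrs[1], len),
                )
            };
            // Two results from one input: 2 * x and -x.
            for ((d, n), &v) in doubled.iter_mut().zip(negated.iter_mut()).zip(x) {
                *d = 2.0 * v;
                *n = -v;
            }
            OK
        }
        _ => ERR_UNKNOWN_OP,
    }
}
```

Dispatching on an ID keeps the C ABI down to a single function, so adding another Rust kernel does not change the shim's signature. The trade-off is that shape and dtype checking has to live somewhere: this sketch only checks arity and assumes every buffer holds `len` f32 values.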
-
This is something I'm also very interested in, though mostly the other way round: I would like to be able to call a JAX-compiled function from Rust. Right now I just call a Python function that calls JAX, but this adds a lot of overhead (including running into the GIL).
-
Hi everyone!
I really enjoy JAX, and I also really appreciate programming in Rust, so I would love to be able to implement some performance-critical functions in Rust and export them to Python.
Currently, I do it using PyO3 and rust-numpy, but I lose all the benefits of XLA (mainly automatic differentiation, in my case), so I end up having to convert NumPy arrays to JAX arrays in Python first.
I am aware of Extending JAX with custom C++ and CUDA code and Custom operations for GPUs with C++ and CUDA, which are two great tutorials, but I feel there should be something easier when it comes to writing a JAX extension in Rust, especially when tools and libraries like PyO3, maturin, rust-numpy, and xla-rs exist.
Have any of you already tried implementing a JAX (or XLA) extension in Rust?
I was considering giving it a try myself, but I am of course interested to see if other people have tried before me.