Calling a mixture of GPU / CPU ops within a jit'ed set of functions #5866
adam-hartshorne asked this question in Q&A
I recently came across this excellent article, which explains how to write your own custom ops for JAX:
https://github.com/dfm/extending-jax
I come from a TensorFlow (TF) background, where custom ops can be written in a similar way (https://www.tensorflow.org/guide/create_op). If you write a CPU-only op and run in graph mode, TF automatically copies the data from the GPU to the CPU and back whenever the custom op is called, even though most of the data and computation live on the GPU, and you don't have to specify that this should happen.
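For comparison, the closest JAX mechanism to TF's automatic host round trip is probably `jax.pure_callback` (available in recent JAX releases): a jitted function calls back into host Python code, and XLA handles the device-to-host and host-to-device copies. The sketch below assumes that route. Note that it is a different mechanism from the XLA custom calls used in extending-jax: the callback is ordinary Python, so a CPU-only C++ library would still need Python bindings, the output shape and dtype must be declared up front, and JAX cannot differentiate through the callback on its own. `host_only_op` and `model` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_only_op(x):
    # Hypothetical stand-in for a CPU-only C++ routine: it runs on the
    # host and must return a NumPy array matching the declared shape/dtype.
    return np.cumsum(np.asarray(x) ** 2).astype(np.float32)


@jax.jit
def model(x):
    y = jnp.sin(x)  # runs on the default device (GPU if one is present)
    # pure_callback copies `y` to the host, calls the Python function
    # there, and copies the result back into the compiled computation.
    z = jax.pure_callback(
        host_only_op,
        jax.ShapeDtypeStruct(y.shape, jnp.float32),
        y,
    )
    return z * 2.0  # continues on the default device


x = jnp.linspace(0.0, 1.0, 5, dtype=jnp.float32)
print(model(x))
```

Because the callback is opaque to XLA, it cannot be fused or optimized with the surrounding GPU code, and every call pays for the two copies.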
After building the sample code and briefly playing around with it, it seems (naively) that to use the CPU version of the Kepler op I have to either switch JAX to CPU mode, i.e. jax.config.update('jax_platform_name', 'cpu'), or jit the op with the CPU backend, e.g. jax.jit(func_name, backend='cpu').
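The two options above can also be expressed per call rather than per process. The sketch below assumes a recent JAX release and shows two device-placement alternatives to the global config flag and the `backend=` argument: a scoped `jax.default_device` block, and committing the input to the CPU with `jax.device_put` so the jitted function runs where its input lives. `kepler_like` is a hypothetical stand-in for the CPU build of the Kepler op, not the op from extending-jax.

```python
import jax
import jax.numpy as jnp

# The host CPU backend is always present, even when a GPU is the default.
cpu = jax.devices("cpu")[0]


def kepler_like(x):
    # Hypothetical stand-in for the CPU build of the Kepler op.
    return jnp.sin(x) + jnp.cos(x)


kepler_jit = jax.jit(kepler_like)
x = jnp.arange(4.0)

# Option A: make the CPU the default device only inside this block,
# instead of switching the whole process with jax_platform_name.
with jax.default_device(cpu):
    y = kepler_jit(jnp.arange(4.0))

# Option B: commit the input to the CPU; a jitted function runs on the
# device its committed inputs live on.
y2 = kepler_jit(jax.device_put(x, cpu))

print(y.devices(), y2.devices())
```

Either way, the CPU placement applies to the whole jitted function, so this alone does not mix GPU and CPU work inside one compiled computation.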
In my current use case, the vast majority of my code is pure JAX running on the GPU (without the GPU it would run far too slowly to be useful), but I would like to port over some specific CPU ops that I previously wrote for TF. I can't make GPU versions of them, as they depend on third-party C++ libraries that are CPU only.
Is it possible to do something along these lines, where a custom CPU-only op is called from inside a jitted set of functions that run on the GPU, e.g. via jax.jit(cpuOpName, backend='cpu')? Or do you have to jax.device_put the data from the GPU to the CPU, call the op, and then move the result back again? Or is this simply impossible to do in JAX?
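The manual round trip described in the second question can be sketched as follows, assuming a recent JAX release. The pipeline is split into three separately jitted pieces, and the GPU-to-CPU and CPU-to-GPU copies happen between them via `jax.device_put`. `gpu_part`, `cpu_only_op`, `gpu_part_2` and `pipeline` are hypothetical names; `cpu_only_op` stands in for the custom CPU-only op. On a machine without a GPU, `accel` is simply the CPU and the copies are no-ops.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
accel = jax.devices()[0]  # GPU if available, otherwise the CPU


@jax.jit
def gpu_part(x):
    return jnp.tanh(x) * 3.0


@jax.jit
def cpu_only_op(x):
    # Hypothetical stand-in for the custom CPU-only op.
    return jnp.cumsum(x)


@jax.jit
def gpu_part_2(x):
    return x ** 2


def pipeline(x):
    # Not jitted as a whole: each stage is compiled separately.
    h = gpu_part(jax.device_put(x, accel))
    # Explicit GPU -> CPU copy; the jitted op follows its committed input.
    h_cpu = cpu_only_op(jax.device_put(h, cpu))
    # Explicit CPU -> GPU copy before continuing on the accelerator.
    return gpu_part_2(jax.device_put(h_cpu, accel))


x = jnp.linspace(-1.0, 1.0, 6)
out = pipeline(x)
print(out.devices())
```

The trade-off is that `pipeline` itself cannot be wrapped in a single `jax.jit`, so XLA cannot optimize across the three stages.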