-
Here is gist code showing this issue: https://gist.github.com/selym3/fa2f551c03cc5abd60c138af55d19c3b. The data itself is only 1.6 GB, and the "False" branch, which is coded in a different style, runs very fast without eating many GB of RAM.
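The "different style" in the gist has not been checked here. One plausible explanation is the difference between closing over the array and passing it in as an argument. Under `jax.jit`, a NumPy array captured by closure is not a traced input. JAX treats it as a compile-time constant and embeds its contents in the lowered program, which for a ~2 GB array can mean extra copies held by the compiler. A passed-in array only contributes its shape and dtype to the compiled program.

The sketch below contrasts the two styles. The names `data`, `lookup_closure` and `lookup_arg` are hypothetical, and the array is kept small so it runs anywhere:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the preloaded data (tiny here; ~1.6-2 GB in the real case).
data = np.arange(100_000, dtype=np.float32)

# Style 1: the array is captured by closure. It is not a traced input, so JAX
# treats it as a compile-time constant and embeds its contents in the program.
@jax.jit
def lookup_closure(idx):
    return jnp.asarray(data)[idx] * 2.0

# Style 2: the array is an explicit argument. The compiled program only knows
# its shape and dtype; the contents are supplied at call time.
@jax.jit
def lookup_arg(table, idx):
    return table[idx] * 2.0

idx = jnp.array([0, 10, 99_999])
table = jnp.asarray(data)  # converted once, outside the jitted function
print(lookup_closure(idx))
print(lookup_arg(table, idx))  # same values as the closure version
```

Both functions compute the same result. Only the second keeps the array's contents out of the compiled program, so its compile-time memory should not scale with the size of `data`.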
-
Hey! Did you find a way around this, or are you still relying on callbacks for those large arrays?
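For reference, the callback route described in the question can look roughly like the sketch below. It uses `jax.pure_callback` to read from an `np.memmap` on the host, so only the requested entries are loaded from disk. The file path, `table`, `_host_lookup` and `lookup` are hypothetical names for illustration.

Note that `pure_callback` has no differentiation rule by default. `jax.grad` through it raises an error unless it is wrapped in `jax.custom_jvp` or `jax.custom_vjp`.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical on-disk file standing in for the real ~2 GB dataset.
path = os.path.join(tempfile.mkdtemp(), "table.f32")
np.arange(1000, dtype=np.float32).tofile(path)

# Memory-mapped, read-only view; pages are read from disk only when touched.
table = np.memmap(path, dtype=np.float32, mode="r")

def _host_lookup(idx):
    # Runs on the host as ordinary NumPy, outside the compiled program.
    return np.asarray(table[np.asarray(idx)], dtype=np.float32)

@jax.jit
def lookup(idx):
    # The callback must declare the shape and dtype of what it returns.
    out_spec = jax.ShapeDtypeStruct(idx.shape, jnp.float32)
    return jax.pure_callback(_host_lookup, out_spec, idx)

print(lookup(jnp.array([0, 5, 999])))  # values 0.0, 5.0, 999.0
```

Because the memmap is only referenced from inside the Python callback, JAX never sees it as a constant. The trade-off is a host round trip on every call, which matches the slowness reported in the question.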
-
We are building a differentiable simulator, and as part of our state update we need to access an array holding ~2 GB of data. Normally, if we use `pure_callback()`, the array can be accessed, although the code runs slowly. However, if we `jit()` this process, RAM usage can soar drastically and the process gets killed. The data is preloaded from files before the part of the code that is under `jit`, and the process is only killed when the data is accessed inside `jit`. We have also tried using `np.memmap()` to access the data, with the same result: it works with `pure_callback` but causes excessive RAM use under `jit`.

Is there support for operations like this in JAX? The data should not use enough RAM to get the process killed on a machine with ~16 GB. We were wondering why running the code under `jit` could cause this. Right now we are only looking to run on CPU, and we want JAX's automatic differentiation features. Thanks!
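Since the goal is `jit` plus autodiff on CPU, one option is to pass the large array into the jitted function as an ordinary argument and differentiate only with respect to the state via `argnums=0`. Differentiating with respect to the table as well would produce a gradient the same size as the table (~2 GB).

This is a minimal sketch under assumptions. `jnp.interp` is a hypothetical stand-in for however the simulator actually reads the table, and `grid`, `table`, `step` and `loss` are illustrative names:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical lookup table: table = 2 * grid, small so the example runs anywhere.
grid = np.linspace(0.0, 10.0, 1001, dtype=np.float32)
table = (2.0 * grid).astype(np.float32)

@jax.jit
def step(state, grid, table):
    # grid and table are traced arguments, so their contents are not embedded
    # in the compiled program; only their shapes and dtypes are.
    return state + jnp.interp(state, grid, table)

def loss(state, grid, table):
    return jnp.sum(step(state, grid, table))

# Gradient with respect to the state only; grid and table are treated as data.
grad_fn = jax.jit(jax.grad(loss, argnums=0))

grid_j, table_j = jnp.asarray(grid), jnp.asarray(table)  # convert once, reuse
state = jnp.array([1.0, 2.5, 4.0], dtype=jnp.float32)
g = grad_fn(state, grid_j, table_j)
print(g)  # each entry is 1 (identity term) + 2 (table slope) = 3
```

Converting the table with `jnp.asarray` once and reusing it avoids repeating the host-to-device conversion on every step.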