Memory issues with jit? #5862
tomasgeffner asked this question in Q&A (unanswered)
Hi all! I'm posting this in the discussion section because I'm not sure whether this is a bug or a simple mistake on my part (I also think I found a workaround), but I'd like to understand what is going on. In short, I'm optimizing a function and jitting the function that computes its gradient. As the optimization proceeds, memory usage grows steadily and substantially. If I don't use jit at all, memory stays flat. The code I'm running is quite large, but here is a "small" example:
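A hypothetical stand-in for this kind of loop (not the example from the question) might look like the sketch below: it jits a `jax.value_and_grad` function, takes plain gradient-descent steps, appends each loss to a list, and logs the process's peak resident memory. The name `gf` follows the question; the quadratic `loss_fn`, the parameter size, step size, step count, and the `peak_rss_mb` helper are all assumptions. The `resource` module is Unix-only, which covers both CentOS and macOS.

```python
import resource
import sys

import jax
import jax.numpy as jnp


def loss_fn(params):
    # Hypothetical objective standing in for the real (much larger) model.
    return jnp.sum((params - 3.0) ** 2)


# Jit the function that computes the gradient; value_and_grad returns
# (loss, grads) from a single call.
gf = jax.jit(jax.value_and_grad(loss_fn))


def peak_rss_mb():
    # Peak resident set size of this process so far. ru_maxrss is reported
    # in kilobytes on Linux and in bytes on macOS.
    rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    scale = 1024 * 1024 if sys.platform == "darwin" else 1024
    return rss / scale


def optimize(num_steps=2000, lr=0.1, log_every=500):
    params = jnp.zeros(1000)
    losses = []
    for step in range(num_steps):
        loss, grads = gf(params)
        params = params - lr * grads  # plain gradient-descent step
        losses.append(loss)  # stores the jax array itself, as in the question
        if step % log_every == 0:
            print(f"step {step}: peak RSS {peak_rss_mb():.1f} MB")
    return params, losses


if __name__ == "__main__":
    optimize()
```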
The possible workarounds I found:
1. As mentioned above, if I don't use jit when I define gf, memory is fine. However, this is not ideal because everything becomes slow.
2. If I use jit but replace the line
losses.append(loss)
with
losses.append(loss.item())
memory also appears to be fine (at least in the examples I tried).
This makes me think I may be missing something simple. I am running this on a cluster, and to my surprise I cannot reproduce the memory growth when running locally on my computer: there, memory always stays low, whether or not I use jit and whether or not I add .item().
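One way to narrow down why `.item()` helps (a diagnostic sketch, not a confirmed explanation): `.item()` changes two things at once. It blocks until the computation has finished, and it replaces the stored jax array with a plain Python float. JAX dispatches jitted calls asynchronously, so without any blocking call the Python loop can run ahead of the computation. Recording the loss with `block_until_ready()` keeps the array but forces synchronization, which separates the two effects. The objective, the `run` helper, the `compare.py` file name, and the `LOSS_STORE_MODE` environment variable are hypothetical.

```python
import os
import resource
import sys

import jax
import jax.numpy as jnp


def loss_fn(params):
    # Hypothetical stand-in for the real objective.
    return jnp.sum((params - 3.0) ** 2)


gf = jax.jit(jax.value_and_grad(loss_fn))

# Three ways of recording the loss at each step.
STORE = {
    # Keep the jax array returned by gf; nothing waits for the step to finish.
    "array": lambda loss: loss,
    # Keep the jax array, but block until its value has been computed.
    "array_synced": lambda loss: loss.block_until_ready(),
    # Block until the value is ready and copy it into a plain Python float.
    "item": lambda loss: loss.item(),
}


def run(mode, num_steps=2000, lr=0.1):
    params = jnp.zeros(1000)
    losses = []
    for _ in range(num_steps):
        loss, grads = gf(params)
        params = params - lr * grads
        losses.append(STORE[mode](loss))
    return losses


def peak_rss_mb():
    # ru_maxrss is in kilobytes on Linux and in bytes on macOS.
    rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return rss / (1024 * 1024 if sys.platform == "darwin" else 1024)


if __name__ == "__main__":
    # Peak RSS never decreases within a process, so run one mode per process,
    # e.g. LOSS_STORE_MODE=item python compare.py (hypothetical names).
    mode = os.environ.get("LOSS_STORE_MODE", "array")
    losses = run(mode)
    print(f"{mode}: final loss {float(losses[-1]):.6f}, peak RSS {peak_rss_mb():.1f} MB")
```

Running each mode in its own process on the cluster and comparing peak RSS gives a rough signal: if `array_synced` stays flat like `item`, unsynchronized dispatch is the more likely factor; if only `item` stays flat, keeping the returned arrays alive is the more likely one. Either way this is a heuristic, not a diagnosis.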
I am using JAX 0.2.9 and jaxlib 0.1.61, both on the cluster and locally. The Python versions differ slightly (3.9.1 on the cluster, 3.9.2 locally), as do the operating systems (CentOS Linux 7 on the cluster, macOS 10.14.5 locally).
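The two machines may differ in more than the OS and the Python patch version. For example (an assumption worth checking), a cluster node might expose a GPU, in which case jax would run on a different XLA backend entirely. A small report like the sketch below, run on both machines, makes such differences easy to compare; the `environment_report` helper is hypothetical, and it assumes jaxlib was installed as a regular package so `importlib.metadata` can find its version. Newer JAX releases also provide `jax.print_environment_info()`, which may not exist in 0.2.9.

```python
import platform
from importlib.metadata import version

import jax


def environment_report():
    # Collect the details that differ (or might differ) between the cluster
    # and the local machine, so the two outputs can be compared side by side.
    devices = jax.devices()
    return {
        "python": platform.python_version(),
        "os": platform.platform(),
        "jax": jax.__version__,
        "jaxlib": version("jaxlib"),
        # Which XLA backend jax picked up, e.g. "cpu" or "gpu".
        "backend": devices[0].platform,
        "devices": [str(d) for d in devices],
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```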
Any suggestions about potential causes would be highly appreciated!
Thanks!