Replies: 4 comments
-
Thanks for the question!
My first thought is to wonder whether you could check that it's a type you do expect, rather than checking that it's not a JAX Tracer. That way you might avoid depending on JAX internals, which could change underneath you without warning. Could that make sense? In general, detecting leaks is tricky: a Tracer could be hidden in some container data structure that you don't recurse into, or could have been stashed away by a side effect. We just merged a leak checker that you might want to look at, but keep in mind it's currently more of a "debug mode" than an "always leave it enabled" kind of thing.
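A minimal sketch of the "check for what you expect" idea, assuming your library only ever wants to store concrete host values (NumPy arrays and Python scalars). The names EXPECTED_LEAF_TYPES, is_expected_leaf and all_leaves_expected are hypothetical. jax.Array is deliberately left out of the allow-list, on the assumption that tracers may also satisfy isinstance checks against JAX's array type in your JAX version.

```python
import jax
import numpy as np

# Hypothetical allow-list: the concrete value types this library chooses to accept.
EXPECTED_LEAF_TYPES = (np.ndarray, np.generic, bool, int, float, complex)

def is_expected_leaf(x):
    """Return True if x is one of the concrete types we accept."""
    return isinstance(x, EXPECTED_LEAF_TYPES)

def all_leaves_expected(tree):
    """Check every leaf of a pytree against the allow-list."""
    return all(is_expected_leaf(leaf) for leaf in jax.tree_util.tree_leaves(tree))

# Concrete NumPy values pass the check.
print(all_leaves_expected({"w": np.ones(3), "lr": 0.1}))  # True

checks = []

@jax.jit
def step(w):
    # This runs at trace time, when w is a tracer rather than a NumPy array.
    checks.append(all_leaves_expected({"w": w}))
    return w * 2

step(np.ones(3))
print(checks)  # [False]
```

The upside of an allow-list is that it never has to name a JAX type; the downside is that concrete jax.Array values are rejected too unless you decide to add them.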
As for extracting a value: I don't think so, not without breaking JAX semantics. As soon as a Tracer is unpacked by your own code instead of being processed by JAX core, you are likely breaking the semantics of whatever JAX transformation is being applied. What do you think?
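A small sketch of what that looks like in practice: under jax.jit the argument is an abstract tracer with no concrete value yet, so forcing it to a Python float raises jax.errors.ConcretizationTypeError instead of returning a number. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale_by_self(x):
    # Under jit, x is a tracer that only knows its shape and dtype.
    # float(x) asks for a concrete Python number, which does not exist yet.
    return x * float(x)

def extraction_succeeds(x):
    """Return True if scale_by_self(x) manages to read a concrete value."""
    try:
        scale_by_self(x)
        return True
    except jax.errors.ConcretizationTypeError:
        return False

print(extraction_succeeds(jnp.float32(2.0)))  # False
print(float(jnp.float32(2.0)))  # 2.0, fine outside any transformation
```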
-
Thanks for the info!
I am writing a library where the user can pass any pytree, so I was hoping to do something like any(jax.is_traced(x) for x in jax.tree_leaves(value)). This structure has to be a pytree or else
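As far as I know there is no jax.is_traced, so the snippet above is aspirational. A minimal sketch of the same helper, assuming it is acceptable to depend on jax.core.Tracer (which, as noted earlier in the thread, ties you to a JAX internal that may change). It uses jax.tree_util.tree_leaves, since the top-level jax.tree_leaves alias has been deprecated in newer releases. The names is_tracer and contains_tracer are hypothetical.

```python
import jax
import jax.numpy as jnp

def is_tracer(x):
    # Depends on jax.core.Tracer, an internal type that could change without warning.
    return isinstance(x, jax.core.Tracer)

def contains_tracer(tree):
    """Return True if any leaf of the pytree is a JAX tracer."""
    return any(is_tracer(leaf) for leaf in jax.tree_util.tree_leaves(tree))

seen = []

@jax.jit
def f(x):
    # Runs at trace time, so x is a tracer here.
    seen.append(contains_tracer({"a": x, "b": 1.0}))
    return x * 2

f(jnp.ones(3))
print(seen)                                 # [True]
print(contains_tracer({"a": jnp.ones(3)}))  # False
```

This only sees tracers that are actual pytree leaves; one tucked inside an object that is not a registered pytree node is treated as an opaque leaf and missed, which is the "hidden in some container" caveat from the first reply.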
The leak checker seems nice! I need something to check leakage during the
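For debugging, the leak checker can be switched on around a block of code with the jax.checking_leaks() context manager (it is also exposed as the jax_check_tracer_leaks config option). A minimal sketch; stash, leaky, clean and run_checked are hypothetical names. As said above, treat it as a debug mode rather than something to leave on.

```python
import jax
import jax.numpy as jnp

stash = []  # module-level list standing in for any side-effect storage

def leaky(x):
    stash.append(x)  # keeps a tracer alive after tracing ends: a leak
    return x * 2

def clean(x):
    return x * 2

def run_checked(fn, x):
    """Run jax.jit(fn)(x) with JAX's tracer-leak checker enabled."""
    with jax.checking_leaks():
        return jax.jit(fn)(x)

print(run_checked(clean, jnp.ones(3)))  # [2. 2. 2.]
```

Calling run_checked(leaky, jnp.ones(3)) should instead raise an exception pointing at the leaked tracer; the exact exception type and message vary between JAX versions, so I wouldn't match on them.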
-
The general problem I am trying to solve is creating a Module system that is transfer-learning friendly. The approach is to have parameters live inside the Module, so that if you create new modules from existing modules they compose naturally, without complex manipulation of the parameter pytrees to fit the new structure. The challenge is leakage during
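A minimal sketch of the "parameters live inside the Module" approach, assuming modules are registered as pytrees so that JAX transformations see the parameters as leaves. Linear and Stack are hypothetical classes, not an existing library's API. Because a composed module is itself a pytree, jax.grad returns gradients with the same module structure, with no manual reshuffling of parameter trees.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Linear:
    """Hypothetical layer that owns its parameters."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return x @ self.w + self.b

    def tree_flatten(self):
        # Parameters are the pytree children; there is no static metadata.
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@register_pytree_node_class
class Stack:
    """Hypothetical container that composes existing modules."""

    def __init__(self, layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def tree_flatten(self):
        # Each child is itself a module pytree, so nesting composes naturally.
        return tuple(self.layers), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children)

# Reuse a "pretrained" layer and add a new head on top of it.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
pretrained = Linear(jax.random.normal(k1, (3, 4)), jnp.zeros(4))
head = Linear(jax.random.normal(k2, (4, 2)), jnp.zeros(2))
model = Stack([pretrained, head])

def loss(model, x):
    return jnp.sum(model(x) ** 2)

# Gradients come back as a Stack of Linear modules with matching shapes.
grads = jax.grad(loss)(model, jnp.ones((5, 3)))
```

This is also where the leak risk sits: if a method assigned a traced value back onto self (say self.w = ...) inside jax.jit, the module would still hold a tracer after the transformation returns.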
-
Thanks for explaining. I also realized it's not easy to check for the non-Tracer types you expect, since … Maybe the best option is to check using …
-
Hey! I want to check whether an array is one of the traced objects used in jit, grad and friends, to make sure I don't leak one of these structures out. What is the correct way to do this? Bonus: is there a way to extract a value from them?