Pulling an intermediate variable outside a function #16572
astanziola asked this question in Ideas
I believe Oryx implements this idea (see here), but when I tried to install it with pip, it looked like it might not be compatible with the most recent JAX version.
One feature that would be quite nice to have is the ability to retrieve an intermediate variable from a jitted function, together with its outputs, without having to modify the function's signature. This would be especially useful for nested functions, where the variable of interest may live several calls deep.
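For comparison, one mechanism available in recent JAX versions, `jax.debug.callback`, can already send an intermediate value from inside a jitted function to host Python without touching the function's signature. It works as a side effect rather than as a return value, so the value arrives on the host as a NumPy array and cannot be differentiated through. A minimal sketch; `inner`, `outer`, `record`, and the `captured` dict are illustrative names, not part of any API:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Host-side storage for captured values (illustrative, not a JAX API).
captured = {}


def record(name):
    """Build a host callback that stores its (concrete NumPy) argument under `name`."""

    def _store(value):
        captured[name] = np.asarray(value)

    return _store


def inner(x):
    y = jnp.sin(x) * 2.0
    # Ship `y` to the host when the compiled code runs; `inner`'s signature is unchanged.
    jax.debug.callback(record("qux"), y)
    return y + 1.0


@jax.jit
def outer(x):
    return inner(x) ** 2


result = outer(jnp.arange(3.0))
result.block_until_ready()
jax.effects_barrier()  # wait for any pending host callbacks
print(captured["qux"])
```

The explicit `jax.effects_barrier()` matters because jitted calls are dispatched asynchronously, so the callback may not have run yet when `outer` returns.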
My thoughts are along these lines. Take, for instance, the functions:

I would like to have some method, perhaps called `pull_variable`, that could extract a `named_variable` like so:

Now, I would be able to execute:

where `variables = {"qux": y}`.

While I don't see a reason why this shouldn't be feasible, is there currently a way to achieve this in JAX?
If not, this could indeed be a great feature, and I can foresee quite a few potential use cases for it, such as:
On a tangential note, there may also be complementary arguments for a `push_variable` decorator that modifies the input signature of a function and allows injecting a value in place of a variable in the computational graph, but that's partially beyond the scope of this question 😄