jittable pytree flattening/unflattening #13473
-
I have a problem-specific optimizer that solves a least-squares problem at each step using
Here `matrix(params)` is a user-provided, positive semi-definite, matrix-valued function of the network parameters, of the same form as below.
where here As you can see, I rely on
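Here is a rough sketch of the kind of step described. It assumes the least-squares solve reduces to the damped linear system (M + λI)δ = g, where M = `matrix(params)` and g is the flattened gradient. `jax.flatten_util.ravel_pytree` can be called inside a `jax.jit`-compiled step to move between the parameter pytree and a flat vector. `loss`, `matrix_fn`, and `DAMPING` are hypothetical stand-ins, not the original code. The PSD matrix here is just an outer product plus the identity.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

DAMPING = 1e-3  # hypothetical Tikhonov damping added to the diagonal


def loss(params):
    # Hypothetical loss: half the squared norm of all parameters.
    flat, _ = ravel_pytree(params)
    return 0.5 * jnp.sum(flat ** 2)


def matrix_fn(params):
    # Hypothetical PSD matrix-valued function of the parameters:
    # an outer product (PSD) plus the identity (makes it positive definite).
    flat, _ = ravel_pytree(params)
    return jnp.outer(flat, flat) + jnp.eye(flat.size)


@jax.jit
def step(params):
    grads = jax.grad(loss)(params)
    # Flatten the gradient pytree; `unravel` maps a flat vector back.
    g_flat, unravel = ravel_pytree(grads)
    m = matrix_fn(params)
    # Solve the damped system (M + DAMPING * I) delta = g in flat space.
    delta = jnp.linalg.solve(m + DAMPING * jnp.eye(g_flat.size), g_flat)
    # Restore the pytree structure and take the update.
    return jax.tree_util.tree_map(lambda p, d: p - d, params, unravel(delta))
```

Under tracing, the flattened length is static. That is why `jnp.eye(g_flat.size)` and the `unravel` closure both work inside the jitted function.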
-
I have exactly the same use case in a package I develop.
I ended up using a slightly different version of the ravel function; see it here.
You can quickly check whether it works for you by simply using `nk.jax.tree_ravel` instead of the JAX-provided function. If it does, you can extract it from there (note: if you are not using complex dtypes, you can safely replace `nkjax.vjp` with `jax.vjp`).
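Below is a minimal sketch of a vjp-based ravel, under stated assumptions. It is not NetKet's actual code, and `tree_ravel_sketch` and `_ravel_list` are hypothetical names. Ravelling a pytree is a linear map, so its VJP (its transpose) splits a flat vector back into leaves with the original shapes. That gives an unravel function built only from JAX primitives, and you can call it inside `jax.jit`. The sketch assumes at least one leaf and that every leaf shares one real floating dtype.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten


def _ravel_list(*leaves):
    # Flatten every leaf to 1-D and concatenate into a single vector.
    return jnp.concatenate([jnp.ravel(x) for x in leaves])


def tree_ravel_sketch(pytree):
    """Hypothetical helper: ravel a pytree to 1-D and return an unravel fn.

    Assumes at least one leaf, all with the same real floating dtype.
    """
    leaves, treedef = tree_flatten(pytree)
    # The pullback of a linear map is its transpose: slicing + reshaping.
    flat, vjp_fn = jax.vjp(_ravel_list, *leaves)

    def unravel(flat_vec):
        # vjp_fn returns one cotangent per input leaf, shaped like that leaf.
        return tree_unflatten(treedef, list(vjp_fn(flat_vec)))

    return flat, unravel


@jax.jit
def roundtrip(p):
    # Ravel, scale in flat space, and unravel, all inside jit.
    flat, unravel = tree_ravel_sketch(p)
    return flat, unravel(flat * 2.0)
```

Leaves that mix real and complex dtypes would be promoted to complex by the concatenation. Transposing that promotion with plain `jax.vjp` is where things get subtle. That may be why the answer recommends a complex-aware vjp in that case.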