Hello, I have a question about combining custom rules, in this case when I need, for example, a gradient rule, a transpose rule and a vmap rule:

```python
from jax import custom_jvp

@custom_jvp
@custom_transpose
def solve(A, b):
    """find x such that A @ x = b"""
    return long_iterative_solve(A, b)

@solve.def_transpose
def solve_transpose(A, tb):
    return solve(A.T, tb)

@solve.defjvp
def solve_jvp(primals, tangents):
    A, b = primals
    tA, tb = tangents
    x = solve(A, b)
    tx = solve(A, tb - tA @ x)  # automatically transposed for VJP
    return x, tx
```

I saw this example, which stacks multiple decorators, but how does this fit if I want to use a vmap rule as well? Does the vmap rule go after or before the VJP? Thank you!
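For reference, here is a minimal runnable sketch of only the `custom_jvp` half of that example. It assumes `jnp.linalg.solve` as a stand-in for the hypothetical `long_iterative_solve` and leaves `custom_transpose` out entirely, so reverse mode relies on JAX transposing the linear tangent computation in the JVP rule. That works here because the dense solve is itself transposable; an iterative solver built on `lax.while_loop` generally would not be, which is presumably why the example above adds a transpose rule.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def solve(A, b):
    """Find x such that A @ x = b."""
    # Assumption: a dense solve stands in for the hypothetical long_iterative_solve.
    return jnp.linalg.solve(A, b)


@solve.defjvp
def solve_jvp(primals, tangents):
    A, b = primals
    tA, tb = tangents
    x = solve(A, b)
    # Differentiating A @ x = b gives tA @ x + A @ tx = tb, so tx = A^{-1} (tb - tA @ x).
    tx = solve(A, tb - tA @ x)
    return x, tx


A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 2.0])

# Reverse mode: JAX transposes the linear part of solve_jvp to get the VJP.
grads = jax.grad(lambda A, b: jnp.sum(solve(A, b) ** 2), argnums=(0, 1))(A, b)
```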
@dfm I am closing this discussion because I think the answer is not yet implemented in JAX. But in my opinion, a mix of #22457 and #24726 is exactly what I need, in addition to supporting `custom_partitioning`, for example with a …
Good question! The dream is (was?) to generally support mixing and matching these decorators, but (as you've probably worked out!) there are some rough edges with the current APIs, and such customization hasn't been widely exercised. For now, you're probably on the right track, but hopefully this will be easier in the future!
As for your specific question about `custom_vmap`: you typically want `custom_vmap` on the inside when it is combined with `custom_vjp` (see the new `custom_vmap` docstring for an example), which is the most common use case. I expect that you'll want something similar in your case.
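A minimal sketch of that ordering, assuming `jax.custom_batching.custom_vmap` wraps the primal function and `jax.custom_vjp` wraps the result. The function `f`, its batching rule and its residuals are made up for illustration and are not the docstring's example. Wrapping explicitly instead of stacking decorators keeps the nesting visible: the batching rule is attached to the inner object and the VJP rules to the outer one.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


# Inner layer: custom_vmap supplies the batching rule for the primal computation.
@custom_vmap
def f_inner(x, y):
    return jnp.sin(x) * y


@f_inner.def_vmap
def f_inner_vmap_rule(axis_size, in_batched, x, y):
    x_batched, y_batched = in_batched
    # Batched arguments arrive with their batch axis first; broadcast the
    # unbatched ones so every argument has a leading axis of length axis_size.
    if not x_batched:
        x = jnp.broadcast_to(x, (axis_size,) + jnp.shape(x))
    if not y_batched:
        y = jnp.broadcast_to(y, (axis_size,) + jnp.shape(y))
    out = jnp.sin(x) * y
    return out, True  # the output is batched along axis 0


# Outer layer: custom_vjp wraps the custom_vmap object.
f = jax.custom_vjp(f_inner)


def f_fwd(x, y):
    # Calling f_inner here means the forward pass also uses the batching rule under vmap.
    return f_inner(x, y), (x, y)


def f_bwd(res, g):
    x, y = res
    return jnp.cos(x) * y * g, jnp.sin(x) * g


f.defvjp(f_fwd, f_bwd)
```

Because `custom_vjp` is the outer layer, `jax.grad` always goes through `f_fwd`/`f_bwd` and never differentiates `f_inner` itself, while `jax.vmap` (alone or around `jax.grad`) reaches `f_inner` and dispatches to `f_inner_vmap_rule`.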