Hi there! I'm hoping JAX can do what other frameworks cannot. I need to compute just a subset of the input-output Jacobian (J) of some function f. Since f is large (many parameters), I would like to treat J as a block matrix and extract only the "A" block. Is this possible with a custom VJP rule? If I use a JVP, I can get "A" by masking out the other blocks, but I still need to represent V as a large matrix. In addition, with JVP/VJP solutions the output is still the full size of J, which could be difficult to store in memory. Thanks for any help!
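For concreteness, here is a toy sketch of the JVP-with-masking idea I have in mind (the small `f`, `x`, and block size `k` are invented purely for illustration):

```python
import jax
import jax.numpy as jnp

def f(x):
    # toy stand-in for my large function
    return jnp.tanh(x)

x = jnp.arange(6.0)
k = 2  # size of the "A" block

# One JVP per basis vector e_j recovers column j of J;
# keeping only the first k outputs gives J[:k, j].
basis = jnp.eye(x.size)[:k]
cols = jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1][:k])(basis)
A = cols.T  # shape (k, k): the top-left block of J
```

The issue is that `basis` is still a large set of tangent vectors when `x` is large.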
Have you tried using standard slicing to return the portion of the Jacobian matrix you're interested in? Under JIT, the intermediate objects are not actually computed, and the XLA compiler is able to trim the parts of the computation that are unnecessary. For example:

```python
import jax
import jax.numpy as jnp
import numpy as np

np.random.seed(1701)
N = 1000
f_mat = np.array(np.random.rand(N, N))

def f(x):
    return jnp.sqrt(f_mat @ x / N)

x = np.array(np.random.rand(N))

# -----------
# Full Jacobian
f1 = jax.jit(lambda x: jax.jacfwd(f)(x))
J_f1 = f1(x)
print(J_f1.shape)
# (1000, 1000)
%timeit f1(x)
# 100 loops, best of 5: 6.96 ms per loop

# ----------
# Partial Jacobian
f2 = jax.jit(lambda x: jax.jacfwd(f)(x)[:5, :5])
J_f2 = f2(x)
print(J_f2.shape)
# (5, 5)
%timeit f2(x)
# 1000 loops, best of 5: 214 µs per loop
```
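If you also want a guarantee that the full Jacobian is never materialized (for example, when running without `jit`), another option is to differentiate a restricted function whose Jacobian is exactly the "A" block. A minimal sketch, reusing the same `f`, `x`, and `N` as above (the helper name `g` and the block size `k` are just for illustration):

```python
import jax
import jax.numpy as jnp
import numpy as np

np.random.seed(1701)
N = 1000
f_mat = np.array(np.random.rand(N, N))

def f(x):
    return jnp.sqrt(f_mat @ x / N)

x = np.array(np.random.rand(N))
k = 5

# Restrict f to its first k inputs and outputs, holding the
# remaining inputs fixed at their values in x. The Jacobian of
# g at x[:k] is then exactly the (k, k) "A" block of J.
def g(x_head):
    full = jnp.concatenate([x_head, x[k:]])
    return f(full)[:k]

A = jax.jit(jax.jacfwd(g))(x[:k])
print(A.shape)
# (5, 5)
```

Here forward-mode only ever pushes k tangent vectors through f, so neither the tangents nor the output are ever the full size of J.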