How to compute a partial Jacobian matrix #5904

Answered by jakevdp
neale asked this question in Q&A

Have you tried using standard slicing to return the portion of the Jacobian matrix you're interested in? Under JIT, intermediate arrays are not necessarily materialized, and the XLA compiler can eliminate parts of the computation whose results are never used. For example:

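import jax
import jax.numpy as jnp
import numpy as np

np.random.seed(1701)
N = 1000
f_mat = np.random.rand(N, N)  # fixed random N x N matrix (NumPy constant)

def f(x):
  # Elementwise sqrt of a scaled matrix-vector product: R^N -> R^N
  return jnp.sqrt(f_mat @ x / N)

x = np.random.rand(N)  # point at which the Jacobian is evaluated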

#-----------
# Full Jacobian
f1 = jax.jit(lambda x: jax.jacfwd(f)(x))
J_f1 = f1(x)

print(J_f1.shape)
# (1000, 1000)

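# %timeit is IPython/Jupyter magic. JAX dispatches asynchronously, so timing
# f1(x).block_until_ready() gives more reliable numbers.
%timeit f1(x)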
# 100 loops, best of 5: 6.96 ms per loop

#----------
# Partial Jacobian
f2 = jax.jit(lambda x: jax.j…
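
As an independent sketch of the slicing idea (not the original answer's code), here is one way to get a 10 x 10 block of the same Jacobian. The first version slices the full `jax.jacfwd` result inside `jax.jit`, as suggested above. The second is an alternative technique not taken from the thread: it differentiates a function that only exposes the chosen inputs and outputs. `ROWS`, `COLS`, and `f_restricted` are hypothetical names for illustration, and whether XLA actually prunes the unused work in the first version may depend on the computation and backend.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Same setup as the example above.
np.random.seed(1701)
N = 1000
f_mat = np.random.rand(N, N)
x = np.random.rand(N)

def f(x):
  return jnp.sqrt(f_mat @ x / N)

# Hypothetical choice of block: output rows 0..9 and input columns 0..9.
ROWS = slice(0, 10)
COLS = slice(0, 10)

# Version 1: slice the full Jacobian under jit; XLA may drop unused work.
partial_by_slicing = jax.jit(lambda x: jax.jacfwd(f)(x)[ROWS, COLS])

# Version 2 (alternative, hypothetical helper): expose only the chosen inputs
# as the differentiated argument and return only the chosen outputs.
def f_restricted(x_sub, x_full):
  # Overwrite the selected entries of x_full with x_sub (x_full is a JAX
  # array/tracer here, so .at[...] is available); x_full itself is treated
  # as a constant because jacfwd differentiates argument 0 only.
  x_new = x_full.at[COLS].set(x_sub)
  return f(x_new)[ROWS]

partial_direct = jax.jit(lambda x: jax.jacfwd(f_restricted)(x[COLS], x))

J_block = partial_by_slicing(x)
print(J_block.shape)  # (10, 10)
```

With `jax.jacfwd`, the second version pushes forward only 10 tangent vectors instead of N, so its cost no longer depends on slicing being optimized away. When only a few outputs are needed but many inputs, `jax.jacrev` on a row-restricted function is the analogous choice.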

Replies: 1 comment with 3 replies, from @neale, @tetterl, and @jakevdp
Answer selected by neale