Troubleshooting jit-ing with methods. #25591
Hello, I'm working through the Sutton & Barto RL book while also learning JAX (probably too many degrees of freedom to learn efficiently!) and ran into trouble jit-ing methods. Specifically, the method works for:
But it fails for a second object initialized with the same parameters. I'm not sure whether the root cause is incorrectly:
```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d  # assumed import; not shown in the post

# Methods of the gridworld class (the class statement itself is not shown)

def _tree_flatten(self):
    children = (
        self.R,
        self.P,
        self.actions,
        self.v_init,
    )  # arrays and dynamic values
    # static values (non-arrays)
    aux_data = {
        "special_states": self.special_states,
        "special_states_prime": self.special_states_prime,
        "special_states_rewards": self.special_states_rewards,
    }
    return (children, aux_data)

@classmethod
def _tree_unflatten(cls, aux_data, children):
    grid = cls(
        aux_data["special_states"],
        aux_data["special_states_prime"],
        aux_data["special_states_rewards"],
        R=children[0],
        P=children[1],
    )
    # Note: in _tree_flatten, children[2] is self.actions and children[3] is self.v_init
    grid.v_init = children[2]
    return grid

def state_value(
    self,
    v,
    R,
    P,
    special_states,
    special_states_prime,
    special_states_rewards,
    discount: float = 0.9,
):
    """"""
    # Update interior grid
    vp = (
        R
        + convolve2d(
            jnp.pad(v, pad_width=(1, 1), constant_values=0),
            P,
            mode="valid",
        )
        * discount
    )
    # Update edges except for corners
    vp = vp.at[1:-1, 0].add(v[1:-1, 0] * discount * 0.25)
    vp = vp.at[1:-1, -1].add(v[1:-1, -1] * discount * 0.25)
    vp = vp.at[0, 1:-1].add(v[0, 1:-1] * discount * 0.25)
    vp = vp.at[-1, 1:-1].add(v[-1, 1:-1] * discount * 0.25)
    # Update corners
    vp = vp.at[0, 0].add(v[0, 0] * 2 * discount * 0.25)
    vp = vp.at[0, -1].add(v[0, -1] * 2 * discount * 0.25)
    vp = vp.at[-1, -1].add(v[-1, -1] * 2 * discount * 0.25)
    vp = vp.at[-1, 0].add(v[-1, 0] * 2 * discount * 0.25)
    # Update special states
    vp = vp.at[special_states[0], special_states[1]].set(
        v[special_states_prime[0], special_states_prime[1]] * discount
        + special_states_rewards
    )
    return vp

def estimate_state_value(self, iter=1000):
    """"""
    v = self.v_init
    for _ in range(iter):
        v = self.state_value(
            v,
            self.R,
            self.P,
            self.special_states,
            self.special_states_prime,
            self.special_states_rewards,
        )
    return v
```
Link to full code.
The issue is that your `tree_flatten` method includes the array `special_states_rewards` in `aux_data`, and arrays are not allowed in `aux_data` because they are not hashable and don't have simple equality semantics. From https://jax.readthedocs.io/en/latest/pytrees.html:
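One plausible reading of the "works for the first object, fails for the second" symptom: `aux_data` becomes part of the tree structure that `jit` uses when looking up its cache, so a second object's structure gets compared with the cached one using `==`. Dicts compare their values with `==` as well, and `==` on arrays is elementwise, so the comparison has no single truth value. Below is a minimal sketch of the fix under stated assumptions: `Grid` is a hypothetical, stripped-down stand-in for the question's class, the convolution update is replaced by a placeholder `R + discount * v`, the reward array moves into `children`, and only tuples of Python ints stay in `aux_data` (a tuple rather than a dict, as a conservative choice, since it is itself hashable). Registration uses `jax.tree_util.register_pytree_node` with `@jax.jit` on the method, so `self` is passed to `jit` as a pytree argument.

```python
import jax
import jax.numpy as jnp


class Grid:
    """Hypothetical, stripped-down stand-in for the gridworld class in the question."""

    def __init__(self, special_states, special_states_prime, special_states_rewards, R):
        # Index data stays as tuples of Python ints: hashable, with plain == semantics,
        # so it is safe to keep in aux_data.
        self.special_states = tuple(tuple(int(i) for i in idx) for idx in special_states)
        self.special_states_prime = tuple(
            tuple(int(i) for i in idx) for idx in special_states_prime
        )
        # Numeric data, including the rewards, travels as arrays in children.
        self.special_states_rewards = jnp.asarray(special_states_rewards, dtype=jnp.float32)
        self.R = jnp.asarray(R, dtype=jnp.float32)

    def _tree_flatten(self):
        children = (self.R, self.special_states_rewards)  # arrays / dynamic values
        aux_data = (self.special_states, self.special_states_prime)  # hashable, static
        return children, aux_data

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so leaves passed in by JAX (e.g. tracers) are stored as-is
        # instead of being re-converted.
        obj = object.__new__(cls)
        obj.special_states, obj.special_states_prime = aux_data
        obj.R, obj.special_states_rewards = children
        return obj

    @jax.jit
    def state_value(self, v, discount=0.9):
        # Placeholder update (NOT the convolution from the question): R + discount * v.
        vp = self.R + discount * v
        rows, cols = (jnp.asarray(idx) for idx in self.special_states)
        rows_p, cols_p = (jnp.asarray(idx) for idx in self.special_states_prime)
        # Special states take the discounted value of their "prime" state plus a reward.
        return vp.at[rows, cols].set(v[rows_p, cols_p] * discount + self.special_states_rewards)


jax.tree_util.register_pytree_node(Grid, Grid._tree_flatten, Grid._tree_unflatten)
```

With this layout, two separately constructed `Grid` objects with the same configuration have tree structures that compare equal, so calling `state_value` on the second one should no longer trip over the `aux_data` comparison.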