
Troubleshooting jit-ing with methods. #25591

Answered by jakevdp
joseph-jnl asked this question in Q&A

The issue is that your tree_flatten method includes the array special_states_rewards in aux_data, and arrays are not allowed in aux_data because they are not hashable and don't have simple equality semantics. From https://jax.readthedocs.io/en/latest/pytrees.html:

When defining unflattening functions, in general children should contain all the dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while aux_data should contain all the static elements that will be rolled into the treedef structure. JAX sometimes needs to compare treedef for equality, or compute its hash for use in the JIT cache, and so care must be taken to ensure that the auxiliary data specified …
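To make the children/aux_data split concrete, here is a minimal sketch of a class whose array lives in children, so that a jit-compiled method can take `self` as an argument. Only the field name `special_states_rewards` comes from the question; the class name `Env`, the static field `n_states`, and the method `total_reward` are hypothetical stand-ins, since the original class is not shown here.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Env:
    """Hypothetical container; only `special_states_rewards` is named in the thread."""

    def __init__(self, special_states_rewards, n_states):
        self.special_states_rewards = special_states_rewards  # array: dynamic
        self.n_states = n_states  # plain Python int: static

    def tree_flatten(self):
        # Arrays go in children so JAX can trace them.
        children = (self.special_states_rewards,)
        # aux_data holds only hashable values with ordinary == semantics.
        aux_data = (self.n_states,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)

    @jax.jit
    def total_reward(self, scale):
        # `self` is flattened into its children when the jitted method is called.
        return scale * jnp.sum(self.special_states_rewards) / self.n_states


env = Env(jnp.array([1.0, 2.0, 3.0]), n_states=3)
result = env.total_reward(2.0)
```

With this layout the treedef carries only the integer `n_states`, which can be hashed and compared for the JIT cache, while the array is passed through as an ordinary traced leaf. Changing `n_states` would trigger a recompile; changing the array values would not.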

Answer selected by joseph-jnl