packages/jax/_src/tree_util.py", line 343, in <listcomp>
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Expected dict, got State({
'bn': {
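The `ValueError` is raised by `treedef.flatten_up_to` inside `tree_map`. Every extra tree passed to `tree_map` must use the same container types as the first one, so a plain `dict` of labels cannot be matched against an NNX `State`, even when the keys are identical. Below is a minimal pure-JAX sketch of that failure. It uses a hypothetical `Box` class, registered as a pytree, as a stand-in for `State`; it is not the real NNX type.

```python
import jax


@jax.tree_util.register_pytree_node_class
class Box:
    """Hypothetical dict-like pytree container standing in for nnx.State."""

    def __init__(self, data):
        self.data = dict(data)

    def tree_flatten(self):
        keys = tuple(sorted(self.data))
        return [self.data[k] for k in keys], keys

    @classmethod
    def tree_unflatten(cls, keys, children):
        return cls(zip(keys, children))


params = Box({"bn": 1.0, "linear": 2.0})


def describe(label, param):
    return f"{label}:{param}"


# Labels as a plain dict: same keys, but a different container type than params.
labels_dict = {"bn": "frozen", "linear": "trainable"}
try:
    jax.tree_util.tree_map(describe, labels_dict, params)
except ValueError as err:
    print("ValueError:", err)  # "Expected dict, got ..." as in the traceback

# Labels wrapped in the same container type as the parameters: this succeeds.
labels_box = Box(labels_dict)
print(jax.tree_util.tree_map(describe, labels_box, params).data)
```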
import jax
from flax import nnx
model_state = nnx.state(model, nnx.Param)  # this is the actual pytree whose structure you have to match
is_param = lambda x: isinstance(x, nnx.Param)
# extract the name of the parent layer
name_map_values = [k[0].key for k, _ in jax.tree_util.tree_flatten_with_path(model_state, is_leaf=is_param)[0]]
name_map = jax.tree.unflatten(jax.tree.structure(model_state, is_leaf=is_param), name_map_values)
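The same construction can be sketched without Flax. This assumes a plain nested dict of arrays stands in for `nnx.state(model, nnx.Param)`. The `bn`/`linear` layer names and the `frozen`/`trainable` labels are hypothetical. Rebuilding the labels with the parameters' own treedef is what guarantees that the two trees have matching structures. With NNX, the snippet above also passes `is_leaf=is_param` so that each `nnx.Param` counts as one leaf; plain arrays need no `is_leaf`.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree; with NNX this would be nnx.state(model, nnx.Param).
params = {
    "bn": {"scale": jnp.ones(3), "bias": jnp.zeros(3)},
    "linear": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros(3)},
}

path_leaves, treedef = jax.tree_util.tree_flatten_with_path(params)
# The first entry of each key path is the top-level (layer) key, e.g. DictKey(key='bn').
layer_names = [path[0].key for path, _ in path_leaves]
# Rebuild with the parameters' own treedef so the container types match exactly.
name_map = jax.tree_util.tree_unflatten(treedef, layer_names)

# Hypothetical rule: freeze the BatchNorm layer, train everything else.
labels = jax.tree_util.tree_map(
    lambda name: "frozen" if name == "bn" else "trainable", name_map
)
print(labels)
```

The resulting `labels` tree can then be passed as the `param_labels` argument of `optax.multi_transform`, together with a dict that maps `"frozen"` and `"trainable"` to transformations such as `optax.set_to_zero()` and `optax.adam(1e-3)`.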
The `multi_transform` optimizers do not work with the Flax NNX way of handling optimizers. Here is a minimal example showcasing the problem:
This code produces the following error:
where the relevant part is: