multitransform does not work with flax nnx #1148

Open

MiladInk opened this issue Dec 3, 2024 · 1 comment

MiladInk commented Dec 3, 2024

The optax.multi_transform optimizer does not work with the Flax NNX way of handling optimizers. Here is a minimal example showcasing the problem:

import optax
import jax
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

@nnx.jit  # Automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # In place updates.

  return loss


if __name__ == '__main__':
    # An MLP with 2 `nnx.Linear` layers, 1 `nnx.Dropout` layer, and 1 `nnx.BatchNorm` layer.
    model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
    # optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # this works
    optimizers = {
        'linear1': optax.adam(1e-3),
        'linear2': optax.adam(2e-3),
        'bn': optax.adam(3e-3),
    }

    name_map = {
        'linear1': 'linear1',
        'linear2': 'linear2',
        'bn': 'bn',
    }
    
    tx = optax.multi_transform(optimizers, name_map)
    optimizer = nnx.Optimizer(model, tx)

    x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
    loss = train_step(model, optimizer, x, y)


    print(f'{loss = }')
    print(f'{optimizer.step.value = }')
    print(model)

Running this code produces the following error:

Traceback (most recent call last):
  File "/Users/miladaghajohari/Applications/PyCharm Professional.app/Contents/plugins/python-ce/helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
  File "<input>", line 1, in <module>
  File "/Users/miladaghajohari/Applications/PyCharm Professional.app/Contents/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Users/miladaghajohari/Applications/PyCharm Professional.app/Contents/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/src/outputs/multi_transform_bug.py", line 46, in <module>
    optimizer = nnx.Optimizer(model, tx)
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/.venv/lib/python3.10/site-packages/flax/nnx/object.py", line 79, in __call__
    return _graph_node_meta_call(cls, *args, **kwargs)
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/.venv/lib/python3.10/site-packages/flax/nnx/object.py", line 88, in _graph_node_meta_call
    cls._object_meta_construct(node, *args, **kwargs)
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/.venv/lib/python3.10/site-packages/flax/nnx/object.py", line 82, in _object_meta_construct
    self.__init__(*args, **kwargs)
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/.venv/lib/python3.10/site-packages/flax/nnx/training/optimizer.py", line 193, in __init__
    self.opt_state = _wrap_optimizer_state(tx.init(nnx.state(model, wrt)))
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/.venv/lib/python3.10/site-packages/optax/transforms/_combining.py", line 243, in init_fn
    inner_states = {
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/.venv/lib/python3.10/site-packages/optax/transforms/_combining.py", line 248, in <dictcomp>
    ).init(params)
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/.venv/lib/python3.10/site-packages/optax/transforms/_masking.py", line 128, in init_fn
    masked_params = mask_pytree(params, mask_tree)
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/.venv/lib/python3.10/site-packages/optax/transforms/_masking.py", line 91, in mask_pytree
    return jax.tree.map(
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/.venv/lib/python3.10/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/.venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 343, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/Users/miladaghajohari/PycharmProjects/meltingpot-moonshot/.venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 343, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Expected dict, got State({
  'bn': {
    'bias': VariableState(
      type=Param,
      value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)
    ),
    'scale': VariableState(
      type=Param,
      value=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32)
    )
  },
  'linear1': {
    'bias': VariableState(
      type=Param,
      value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=Array([[ 0.72217995,  0.3850245 , -0.07669721, -0.00685275,  0.5362586 ,
              -0.38415453,  0.41223425, -1.1759619 , -0.13179159, -1.2721819 ,
               0.9267716 ,  0.94057745, -0.45282063, -1.5622699 ,  0.9774794 ,
              -1.0354867 ],
             [-0.13776538, -0.24806416, -0.63165766,  0.4321943 ,  0.5578346 ,
              -0.0922389 , -0.6438374 ,  1.0601913 , -0.19511075,  0.35218215,
              -1.0710531 , -0.785865  , -0.688127  ,  0.1734108 ,  0.6501992 ,
              -0.23442917]], dtype=float32)
    )
  },
  'linear2': {
    'bias': VariableState(
      type=Param,
      value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=Array([[-0.02864332, -0.4359587 , -0.3435134 , -0.191117  , -0.2358688 ,
               0.32216027, -0.10704373,  0.17006359,  0.02259357, -0.2991183 ],
             [-0.17418525, -0.3475435 ,  0.14287646,  0.15190478,  0.04712677,
               0.22665717, -0.12826608, -0.10540091, -0.3031107 ,  0.13475087],
             [-0.05568802,  0.06233315,  0.26388815,  0.14147666, -0.4953701 ,
               0.48909268,  0.17546323,  0.32143998, -0.21842146, -0.04013054],
             [ 0.06601496, -0.19542927,  0.06449655, -0.29464233,  0.088284  ,
               0.42847374,  0.0058595 ,  0.15576684,  0.0455053 , -0.00616307],
             [-0.01459054,  0.06947349, -0.17197794, -0.03559725, -0.10898678,
              -0.1344615 ,  0.43058196,  0.3249984 , -0.02944539,  0.17368062],
             [-0.35230708, -0.3854319 , -0.3465927 ,  0.11094601,  0.15047713,
               0.30810842, -0.26587242, -0.42608866,  0.39287725,  0.01910183],
             [-0.08720911, -0.24712713,  0.19909504, -0.12760757,  0.09218101,
               0.29698348,  0.1696361 ,  0.28325173,  0.05396844,  0.26612124],
             [-0.2581897 ,  0.3014514 , -0.2643698 ,  0.16643463, -0.17577483,
              -0.26752108, -0.00084232,  0.25919583, -0.53392863, -0.36726063],
             [-0.14639883,  0.14041387, -0.26908892,  0.5115992 , -0.3799397 ,
               0.1986312 , -0.16673851, -0.02079232, -0.15477388, -0.163966  ],
             [-0.17034872, -0.07494676,  0.10665798,  0.19908334,  0.33362478,
              -0.16310433,  0.16489765,  0.13541739,  0.15069027,  0.06568305],
             [ 0.09099597,  0.05862452, -0.04828012,  0.01189236,  0.01344032,
               0.20035312,  0.22156866, -0.11468177, -0.54855186,  0.07578899],
             [-0.33152425, -0.29136464,  0.10423858, -0.12863392, -0.28326088,
               0.33248442, -0.11456756, -0.05140362,  0.15094203, -0.1147663 ],
             [-0.28723633,  0.27467206,  0.38115257, -0.05806991,  0.18336216,
               0.27663368, -0.09170966,  0.00422179,  0.00296012,  0.19378984],
             [ 0.48130342,  0.13028473, -0.08717275, -0.2596676 , -0.00527761,
               0.22398742,  0.49893376, -0.25161025,  0.28161395,  0.19897749],
             [-0.16675583, -0.15079561,  0.02887268,  0.00925025, -0.15791203,
               0.19928937, -0.2496463 , -0.28188944,  0.04672124, -0.19195534],
             [ 0.34721088,  0.1112446 ,  0.44359654,  0.12015927,  0.32238033,
              -0.1884635 ,  0.01119952, -0.2487719 , -0.44758505, -0.39491665]],      dtype=float32)
    )
  }
}).

where the relevant part is:

packages/jax/_src/tree_util.py", line 343, in <listcomp>
  all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Expected dict, got State({
'bn': {
rdyro (Collaborator) commented Dec 3, 2024

Hey, good question!

name_map must correspond to the pytree structure of the model's parameters. Your model has 3 labeled layers but 6 parameters, so the name_map must have 6 leaves too.
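
To see this concretely, you can print the parameter pytree's structure (a quick sketch, using the MLP from the snippet above):

print(jax.tree.structure(nnx.state(model, nnx.Param)))  # a PyTreeDef with 6 leaves: kernel/bias for each Linear, scale/bias for BatchNorm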

One way to automate constructing the name_map is to make use of jax.tree_util.tree_flatten_with_path like so:

model_state = nnx.state(model, nnx.Param)  # the actual pytree whose structure you have to match
is_param = lambda x: isinstance(x, nnx.Param)
# extract the name of the parent layer for every parameter leaf
flat_with_paths, _ = jax.tree_util.tree_flatten_with_path(model_state, is_leaf=is_param)
name_map_values = [path[0].key for path, _ in flat_with_paths]
# rebuild a pytree with the same structure as model_state, with the layer names as leaves
name_map = jax.tree.unflatten(jax.tree.structure(model_state, is_leaf=is_param), name_map_values)
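
With that name_map, the rest of the original script should run unchanged, e.g.:

tx = optax.multi_transform(optimizers, name_map)
optimizer = nnx.Optimizer(model, tx)
loss = train_step(model, optimizer, x, y)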
