Fix batch norm #73

Merged
merged 12 commits into from Oct 6, 2023

Conversation

hlzl

@hlzl hlzl commented Sep 20, 2023

Trying to revive PR #71 and solve issue #70. These changes should complete the refactor and use the new state mechanism.
However, this currently causes the following error in the Sequential module:

File "/home/user/eqxvision/eqx_train.py", line 296, in forward
    pred_ys, state = batch_model(images, state, key=keys)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/eqxvision/models/classification/resnet.py", line 352, in __call__
    x, state = self.layer1(x, state, key=keys[1])
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_sequential.py", line 68, in __call__
    x = layer(x, key=key)
        ^^^^^^^^^^^^^^^^^
TypeError: _ResNetBasicBlock.__call__() missing 1 required positional argument: 'state'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Can be reproduced with:

import jax
import equinox as eqx
from eqxvision.models import resnet18

@eqx.filter_jit
def forward(model, state, images):
    keys = jax.random.split(jax.random.PRNGKey(0), images.shape[0])
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_ys, state = batch_model(images, state, key=keys)
    return pred_ys


net = resnet18()
state = eqx.nn.State(net)

images = jax.random.uniform(jax.random.PRNGKey(0), shape=(1, 3, 64, 64))
output = forward(net, state, images)

@patrick-kidger Any idea how to solve this? Should we get rid of the if/else in _sequential that checks for isinstance(layer, StatefulLayer) and causes this error?

@hlzl
Author

hlzl commented Sep 20, 2023

As a side note, not all models have been fixed. I don't use the following models and thus don't know a suitable fix. Might be interesting for someone later:

  • ShuffleNetV2: Not sure how to deal with states from different branches
  • VGG, ViT, DeepLabV3: Haven't looked at them

@hlzl
Author

hlzl commented Sep 20, 2023

Seems like we can fix the isinstance(layer, StatefulLayer) error in the Sequential module if we inherit from StatefulLayer in our layer-block classes (such as _ResNetBasicBlock(), _DenseLayer()). This way the Sequential module actually registers the stateful layers and the need to pass in the state.

This, however, doesn't seem to work straightforwardly for more nested models such as the EfficientNet implementation.

Maybe it would be a better idea to set some kind of flag if a sequential layer has any child that requires a state? @paganpasta

@patrick-kidger
Contributor

Seems like we can fix the isinstance(layer, StatefulLayer) error in the Sequential module if we inherit from StatefulLayer in our layer-block classes

Yup! This is the expected fix.

This, however, doesn't seem to work straight forwardly for more nested models such as the EfficientNet implementation. Maybe it would be a better idea to set some kind of flag if a sequential layer has any child that requires a state?

If you have a child stateful layer, then the parent itself is necessarily stateful as well -- it should subclass StatefulLayer, accept a state argument, return a state argument, and pipe the state through its child layers in between.
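
For reference, a minimal sketch of that pattern (the class name and channel arguments are illustrative, not eqxvision's actual block):

import equinox as eqx
import jax

class Block(eqx.nn.StatefulLayer):
    # Stateful, because it contains a BatchNorm child.
    conv: eqx.nn.Conv2d
    norm: eqx.nn.BatchNorm

    def __init__(self, channels, *, key):
        self.conv = eqx.nn.Conv2d(channels, channels, kernel_size=3, padding=1, key=key)
        self.norm = eqx.nn.BatchNorm(channels, axis_name="batch")

    def __call__(self, x, state, *, key=None):
        x = self.conv(x)
        x, state = self.norm(x, state)  # pipe the state through the child layer
        return jax.nn.relu(x), state    # and hand the updated state back out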

@hlzl
Author

hlzl commented Sep 21, 2023

Maybe I'm missing something here, but if I have a nested sequential layer, how should the top sequential layer know that inside its child sequential layer there is a stateful layer?
The if isinstance(layer, StatefulLayer) condition is not recursive and will just see the child layer as being of type Sequential, right?

At least that seems to be what happened when I tried to change the EfficientNet implementation by converting _MBConv and _FusedMBConv to inherit from StatefulLayer, causing the following error:

File "/home/user/eqxvision/eqx_train.py", line 12, in forward
    pred_ys, state = batch_model(images, state, key=keys)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/eqxvision/models/classification/efficientnet.py", line 405, in __call__
    x, state = self.features(x, state, key=keys[0])
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_sequential.py", line 68, in __call__
    x = layer(x, key=key)
        ^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_sequential.py", line 66, in __call__
    x, state = layer(x, state=state, key=key)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_batch_norm.py", line 155, in __call__
    first_time = state.get(self.first_time_index)
                 ^^^^^^^^^
AttributeError: 'object' object has no attribute 'get'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

@patrick-kidger
Contributor

patrick-kidger commented Sep 21, 2023

Ah, I take your point -- you'd like a way to programmatically propagate statefulness through, since the containing classes (whether they are a Sequential or a custom class) may-or-may-not be stateful, depending on their choice of sublayers.

I've just written patrick-kidger/equinox#505, which adds a StatefulLayer.is_stateful method that is called to check whether or not a layer is stateful. Sequential now inherits from StatefulLayer, with an implementation of

def is_stateful(self):
    return any(isinstance(x, StatefulLayer) and x.is_stateful() for x in self.layers)

Thus nested Sequentials should now work automatically. In addition, your own classes can inherit from StatefulLayer and implement is_stateful, so that custom layers are also handled statefully when placed inside a Sequential, if required.
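
For example, a custom container could delegate to its sublayers like this (a sketch against that branch; the ResidualBlock name is illustrative):

import equinox as eqx

class ResidualBlock(eqx.nn.StatefulLayer):
    inner: eqx.nn.Sequential

    def is_stateful(self) -> bool:
        # Stateful exactly when the wrapped Sequential contains a stateful layer.
        return self.inner.is_stateful()

    def __call__(self, x, state, *, key=None):
        y, state = self.inner(x, state, key=key)
        return x + y, state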

Does that seem like it would work for you? If you can install Equinox from that branch and check that it meets your needs, then I'll include it in the upcoming release.

@hlzl
Author

hlzl commented Sep 22, 2023

That's exactly what I meant, and your proposed solution sounds great!

I haven't been able to complete my testing as I ran into the Cannot assign methods in __init__ error introduced in v0.11.0 (gets thrown e.g. here for the ResNet).

Since it seems common for models in eqxvision to create layer blocks in __init__ with class methods (probably carried over from the equivalent PyTorch implementations), I was wondering what the best approach to refactoring this would be before I do.

We don't want to create the layer blocks anew every time we call them, so I thought about using functools.cached_property.
This still seems like a lot of additional code compared to the previous implementations, so I thought you might have had a better idea in mind when implementing this check that I haven't thought of yet.
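
For concreteness, a toy sketch of the kind of pattern that (as I understand it) trips the check, and one possible refactor, not necessarily the right one for eqxvision:

from typing import Callable

import equinox as eqx

class BadBlockFactory(eqx.Module):
    make_block: Callable

    def __init__(self):
        # Raises "Cannot assign methods in __init__" with Equinox >= v0.11.0,
        # because the bound method captures an out-of-date copy of `self`.
        self.make_block = self._make_block

    def _make_block(self, key):
        return eqx.nn.Linear(3, 3, key=key)

# One possible refactor: build blocks with a module-level helper instead, so no
# bound method ever needs to be stored on the instance.
def make_block(in_features, out_features, *, key):
    return eqx.nn.Linear(in_features, out_features, key=key)

class GoodBlockHolder(eqx.Module):
    block: eqx.nn.Linear

    def __init__(self, *, key):
        self.block = make_block(3, 3, key=key)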

@patrick-kidger
Contributor

Great, I'm glad it works. I've just merged that change into dev.

As for that error -- that's a bug in something I just wrote, whoops. I think patrick-kidger/equinox#508 should fix it.

@hlzl
Author

hlzl commented Sep 24, 2023

All models using BatchNorm should work now. Tested with equinox/dev commit d9b018a.

@paganpasta
Owner

@hlzl Hi, thanks for the PR and sorry for the late response.

I think the tests are currently failing because the Equinox updates are not yet packaged into a new release.

I'll test it locally, then update and merge accordingly soon. I'm currently tied down this week with a few deadlines.

@patrick-kidger
Contributor

I'm planning on doing the next Equinox release in the next week, by the way.

@hlzl
Author

hlzl commented Sep 26, 2023

No worries, thank you @paganpasta.

The models generally run, but it seems like they do not reproduce the same results as their PyTorch equivalent.

E.g., trying to reproduce results on CIFAR10 from a PyTorch ResNet18 with the equivalent Equinox implementation:
somehow, the Equinox implementation matches the torch implementation during the first (and second) epoch, then stops learning anything useful and eventually converges back to random guessing.

The gradients behave very strangely starting in the third epoch, but I can't tell what causes the issue and leads to the exploding loss, particularly since this setup should be identical to the PyTorch one, where the problem does not arise.

Examples to reproduce with current dev branch of equinox and torch==2.0:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from tqdm import tqdm


def accuracy(outputs, labels):
    _, predicted = torch.max(outputs.data, 1)
    correct = (predicted == labels).sum().item()
    total = labels.size(0)
    return correct / total


######################################################

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

batch_size = 256
num_epochs = 100

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

model = torchvision.models.resnet18(pretrained=False)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(num_epochs):
    running_loss = 0.0
    running_accuracy = 0.0
    for images, labels in tqdm(
        trainloader, leave=False, desc=f"Epoch {epoch+1}", total=len(trainloader)
    ):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_accuracy += accuracy(outputs, labels) * 100

    epoch_loss = running_loss / len(trainloader)
    epoch_accuracy = running_accuracy / len(trainloader)
    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}"
    )

vs.

import equinox as eqx
import eqxvision

import jax
import jax.numpy as jnp
import optax

import torch
import torchvision
from tqdm import tqdm


def xavier_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
    nin = weight.shape[1]
    return jax.random.normal(key, weight.shape) * jnp.sqrt(2.0 / nin)


def init_weights(model, init_fn, key):
    is_weight = lambda x: isinstance(x, (eqx.nn.Linear, eqx.nn.Conv))
    get_weights = lambda m: [
        x.weight
        for x in jax.tree_util.tree_leaves(m, is_leaf=is_weight)
        if is_weight(x)
    ]
    weights = get_weights(model)
    new_weights = [
        init_fn(weight, subkey)
        for weight, subkey in zip(weights, jax.random.split(key, len(weights)))
    ]
    new_model = eqx.tree_at(get_weights, model, new_weights)
    return new_model


@eqx.filter_jit
def loss(model, state, x, label):
    keys = jax.random.split(jax.random.PRNGKey(5678), x.shape[0])
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_ys, state = batch_model(x, state, key=keys)
    return optax.softmax_cross_entropy_with_integer_labels(pred_ys, label).mean(), state


@eqx.filter_jit
def make_step(model, state, opt_state, x, label):
    (val, state), grads = eqx.filter_value_and_grad(loss, has_aux=True)(
        model, state, x, label
    )
    updates, opt_state = opt.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, state, opt_state, val


@eqx.filter_jit
def inference(model, state, x):
    keys = jax.random.split(jax.random.PRNGKey(5678), x.shape[0])
    inference_model = eqx.Partial(eqx.tree_inference(model, value=True), state=state)
    return jax.vmap(inference_model)(x, key=keys)


@eqx.filter_jit
def accuracy(outputs, labels):
    predicted = jnp.argmax(outputs, 1)
    correct = (predicted == labels).sum()
    total = labels.size
    return correct / total


######################################################

batch_size = 256
num_epochs = 100

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)


model = eqxvision.models.resnet18(num_classes=10)
model = init_weights(model, xavier_init, key=jax.random.PRNGKey(5678))

opt = optax.sgd(learning_rate=0.01, momentum=0.9)

state = eqx.nn.State(model)
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))
for epoch in range(0, num_epochs):
    running_loss, running_accuracy = 0.0, 0.0
    for images, labels in tqdm(
        trainloader, leave=False, desc=f"Epoch {epoch+1}", total=len(trainloader)
    ):
        model, state, opt_state, loss_val = make_step(
            model,
            state,
            opt_state,
            images.numpy(),
            labels.numpy(),
        )

        out = inference(model, state, images.numpy())

        running_loss += loss_val
        running_accuracy += accuracy(out[0], labels.numpy()) * 100

    epoch_loss = running_loss / len(trainloader)
    epoch_accuracy = running_accuracy / len(trainloader)
    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}"
    )

Would be great if one of you could have a look here.

@patrick-kidger
Contributor

Hmm, two quick thoughts:

  • a key shouldn't be necessary at inference time, as things should be deterministic then. Possibly indicative of some larger error?
  • what does training look like without batch norm? I.e. just replace it with a dummy identity layer in both frameworks (a sketch for the Equinox side is below).

FWIW PyTorch and Equinox use slightly different batch norm implementations (there are many variants). But once you've removed batch norm, you could try initialising them with the same weights, and use the same training batches, and see if you can get close to bit-for-bit reproducibility.
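
For the Equinox side, one way to try that swap (a sketch only; _NoNorm and strip_batchnorm are made-up names, and it assumes the model always calls its norm layers as x, state = norm(x, state)):

import equinox as eqx
import jax
from eqxvision.models import resnet18

class _NoNorm(eqx.Module):
    # Drop-in no-op with BatchNorm's (x, state) -> (x, state) call signature.
    def __call__(self, x, state, *, key=None):
        return x, state

def strip_batchnorm(model):
    is_bn = lambda m: isinstance(m, eqx.nn.BatchNorm)
    get_bns = lambda m: [x for x in jax.tree_util.tree_leaves(m, is_leaf=is_bn) if is_bn(x)]
    # Replace every BatchNorm node in the model pytree with a _NoNorm instance.
    return eqx.tree_at(get_bns, model, replace_fn=lambda _: _NoNorm())

net = strip_batchnorm(resnet18(num_classes=10))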

@paganpasta paganpasta changed the base branch from main to dev October 2, 2023 22:19
@patrick-kidger
Contributor

Btw, heads-up that Equinox v0.11.0 is now released! That shouldn't be a blocker any more for this PR.

@hlzl
Author

hlzl commented Oct 4, 2023

@patrick-kidger Sorry for the late reply. Regarding your two thoughts:

  • The key is actually not used at all in the forward pass, even during training, and could be made optional in the ResNet. Will change that. The difference in training behaviour remains, though.
  • Yes, it seems like the batch norm layer is definitely the cause of this behaviour. Switching to group norm solved the issue. Interestingly, the validation accuracy is still quite different from the PyTorch implementation, but training itself behaves as expected when using group norm.

BTW, my simple test for comparing the two implementations was to overfit on CIFAR10 during training.
This generally works out of the box with most ResNet implementations, but somehow I was not able to achieve it for the Equinox implementation with batch norm.
Do you have any idea why this could be (or whether it is due to the different batch norm variant)?

@paganpasta paganpasta merged commit 0bdd2f2 into paganpasta:dev Oct 6, 2023
0 of 2 checks passed
@hlzl hlzl deleted the fix_batch_norm branch October 6, 2023 08:51