Fix batch norm #73
Conversation
As a side note, not all models have been fixed. I don't use the following models and thus don't know a suitable fix. Might be interesting for someone later:
|
Seems like we can fix the Sequential case. This, however, doesn't seem to work straightforwardly for more nested models such as the EfficientNet. Maybe it would be a better idea to set some kind of flag if a sequential layer has any child that requires a state? @paganpasta |
Yup! This is the expected fix.
If you have a child stateful layer, then the parent itself is necessarily also stateful -- it should subclass StatefulLayer. |
Maybe I'm missing something here, but if I have a nested sequential layer, how should the top sequential layer know that inside its child sequential layer there is a stateful layer? At least that seems to be what happened when I tried to change the EfficientNet implementation accordingly:
File "/home/user/eqxvision/eqx_train.py", line 12, in forward
pred_ys, state = batch_model(images, state, key=keys)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/eqxvision/models/classification/efficientnet.py", line 405, in __call__
x, state = self.features(x, state, key=keys[0])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_sequential.py", line 68, in __call__
x = layer(x, key=key)
^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_sequential.py", line 66, in __call__
x, state = layer(x, state=state, key=key)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_batch_norm.py", line 155, in __call__
first_time = state.get(self.first_time_index)
^^^^^^^^^
AttributeError: 'object' object has no attribute 'get'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. |
Ah, I take your point -- you'd like a way to programmatically propagate statefulness through, since the containing classes (whether they are a Sequential or something else) don't currently expose that.
I've just written patrick-kidger/equinox#505, which adds a StatefulLayer.is_stateful method. Sequential now implements it as:
def is_stateful(self):
    return any(isinstance(x, StatefulLayer) and x.is_stateful() for x in self.layers)
Thus nested sequentials should now automatically work. In addition, you should be able to have your own classes inherit from StatefulLayer and override is_stateful as needed.
Does that seem like it would work for you? If you can install Equinox from that branch and check that it meets your needs, then I'll include it in the upcoming release. |
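(As a sketch, not part of the original discussion: assuming an Equinox version that ships the is_stateful hook from equinox#505, a custom container module could propagate statefulness the same way. MyBlock is a hypothetical name and the layer sizes are arbitrary.)

import equinox as eqx
import jax

class MyBlock(eqx.nn.StatefulLayer):
    layers: tuple

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.layers = (
            eqx.nn.Linear(3, 3, key=k1),
            eqx.nn.BatchNorm(3, axis_name="batch"),
            eqx.nn.Linear(3, 3, key=k2),
        )

    def is_stateful(self):
        # Stateful iff any child layer is itself stateful.
        return any(
            isinstance(x, eqx.nn.StatefulLayer) and x.is_stateful()
            for x in self.layers
        )

    def __call__(self, x, state, *, key=None):
        for layer in self.layers:
            if isinstance(layer, eqx.nn.StatefulLayer) and layer.is_stateful():
                x, state = layer(x, state)
            else:
                x = layer(x)
        return x, state

# Usage sketch: block = MyBlock(jax.random.PRNGKey(0)); state = eqx.nn.State(block);
# then vmap over the batch with axis_name="batch", as in the training snippets below.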
That's exactly what I meant and your proposed solution sounds great! I haven't been able to complete my testing yet, as I ran into another error. It seems to be common in a lot of the models that we don't want to create the layer blocks anew every time we call them, so I thought about reusing them. |
Great, I'm glad it works. I've just merged that change. As for that error -- that's a bug in something I just wrote, whoops. I think patrick-kidger/equinox#508 should fix it. |
All models using BatchNorm should now be updated to the new state mechanic. |
@hlzl Hi, thanks for the PR and sorry for the late response. I think the tests are currently failing because the Equinox updates are not yet packaged into a new release. I'll test it locally, update and merge accordingly, soon. Currently tied down this week with a few deadlines. |
I'm planning on doing the next Equinox release in the next week, by the way. |
No worries, thank you @paganpasta. The models generally run, but it seems like they do not reproduce the same results as their PyTorch equivalent. E.g., trying to reproduce results on CIFAR10 obtained with a PyTorch ResNet18 using an equivalent eqxvision ResNet18 does not work. The gradients behave super strangely starting in the third epoch, but I'm not able to tell what causes the issue and leads to an exploding loss. In particular, because this setup should be identical to the PyTorch one, where this problem does not arise.
Examples to reproduce with the current versions:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from tqdm import tqdm
def accuracy(outputs, labels):
_, predicted = torch.max(outputs.data, 1)
correct = (predicted == labels).sum().item()
total = labels.size(0)
return correct / total
######################################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)
batch_size = 256
num_epochs = 100
transform = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
trainset = torchvision.datasets.CIFAR10(
root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
model = torchvision.models.resnet18(pretrained=False)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
for epoch in range(num_epochs):
running_loss = 0.0
running_accuracy = 0.0
for images, labels in tqdm(
trainloader, leave=False, desc=f"Epoch {epoch+1}", total=len(trainloader)
):
images = images.to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
running_accuracy += accuracy(outputs, labels) * 100
epoch_loss = running_loss / len(trainloader)
epoch_accuracy = running_accuracy / len(trainloader)
print(
f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}"
)
vs.
import equinox as eqx
import eqxvision
import jax
import jax.numpy as jnp
import optax
import torch
import torchvision
from tqdm import tqdm
def xavier_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
nin = weight.shape[1]
return jax.random.normal(key, weight.shape) * jnp.sqrt(2.0 / nin)
def init_weights(model, init_fn, key):
is_weight = lambda x: isinstance(x, (eqx.nn.Linear, eqx.nn.Conv))
get_weights = lambda m: [
x.weight
for x in jax.tree_util.tree_leaves(m, is_leaf=is_weight)
if is_weight(x)
]
weights = get_weights(model)
new_weights = [
init_fn(weight, subkey)
for weight, subkey in zip(weights, jax.random.split(key, len(weights)))
]
new_model = eqx.tree_at(get_weights, model, new_weights)
return new_model
@eqx.filter_jit
def loss(model, state, x, label):
keys = jax.random.split(jax.random.PRNGKey(5678), x.shape[0])
batch_model = jax.vmap(
model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
)
pred_ys, state = batch_model(x, state, key=keys)
return optax.softmax_cross_entropy_with_integer_labels(pred_ys, label).mean(), state
@eqx.filter_jit
def make_step(model, state, opt_state, x, label):
(val, state), grads = eqx.filter_value_and_grad(loss, has_aux=True)(
model, state, x, label
)
updates, opt_state = opt.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
return model, state, opt_state, val
@eqx.filter_jit
def inference(model, state, x):
keys = jax.random.split(jax.random.PRNGKey(5678), x.shape[0])
inference_model = eqx.Partial(eqx.tree_inference(model, value=True), state=state)
return jax.vmap(inference_model)(x, key=keys)
@eqx.filter_jit
def accuracy(outputs, labels):
predicted = jnp.argmax(outputs, 1)
correct = (predicted == labels).sum()
total = labels.size
return correct / total
######################################################
batch_size = 256
num_epochs = 100
transform = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
trainset = torchvision.datasets.CIFAR10(
root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
model = eqxvision.models.resnet18(num_classes=10)
model = init_weights(model, xavier_init, key=jax.random.PRNGKey(5678))
opt = optax.sgd(learning_rate=0.01, momentum=0.9)
state = eqx.nn.State(model)
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))
for epoch in range(0, num_epochs):
running_loss, running_accuracy = 0.0, 0.0
for images, labels in tqdm(
trainloader, leave=False, desc=f"Epoch {epoch+1}", total=len(trainloader)
):
model, state, opt_state, loss_val = make_step(
model,
state,
opt_state,
images.numpy(),
labels.numpy(),
)
out = inference(model, state, images.numpy())
running_loss += loss_val
running_accuracy += accuracy(out[0], labels.numpy()) * 100
epoch_loss = running_loss / len(trainloader)
epoch_accuracy = running_accuracy / len(trainloader)
print(
f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}"
)
Would be great if one of you could have a look here. |
Hmm, two quick thoughts:
FWIW PyTorch and Equinox use slightly different batch norm implementations (there are many variants). But once you've removed batch norm, you could try initialising them with the same weights, and use the same training batches, and see if you can get close to bit-for-bit reproducibility. |
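(A sketch, not from the thread, of the suggested debugging step: copy the parameters of a single PyTorch layer into its Equinox counterpart with eqx.tree_at and compare one forward pass. The layer sizes here are arbitrary; both libraries store Linear weights in (out, in) layout.)

import numpy as np
import torch
import jax
import jax.numpy as jnp
import equinox as eqx

torch_linear = torch.nn.Linear(4, 2)
eqx_linear = eqx.nn.Linear(4, 2, key=jax.random.PRNGKey(0))

# Copy the PyTorch parameters into the Equinox layer.
eqx_linear = eqx.tree_at(
    lambda m: (m.weight, m.bias),
    eqx_linear,
    (
        jnp.asarray(torch_linear.weight.detach().numpy()),
        jnp.asarray(torch_linear.bias.detach().numpy()),
    ),
)

# Compare a single forward pass on the same input.
x = np.random.randn(4).astype(np.float32)
out_torch = torch_linear(torch.from_numpy(x)).detach().numpy()
out_eqx = np.asarray(eqx_linear(jnp.asarray(x)))
print(np.abs(out_torch - out_eqx).max())  # should be ~1e-7 or smaller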
Btw, heads-up that Equinox v0.11.0 is now released! That shouldn't be a blocker any more for this PR. |
@patrick-kidger Sorry for the late reply. Regarding your two thoughts:
BTW, my simple test for comparing the two implementations was to overfit on CIFAR10 during training. |
Trying to revive PR #71 and solve issue #70. These changes should complete the refactor and use the new state mechanic. However, currently this causes the following error in the Sequential module:
Can be reproduced with:
@patrick-kidger Any idea how to solve this? Should we get rid of the if-else statement causing this error that checks for isinstance(layer, StatefulLayer) in _sequential.py?
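(For context, a paraphrased sketch, not the actual Equinox source, of the kind of dispatch inside _sequential.py that this question refers to: layers recognised as stateful get the State threaded through, everything else is called without it. With the is_stateful hook from equinox#505, a nested Sequential whose children need state is recognised correctly instead of falling into the stateless branch and passing a sentinel down to BatchNorm.)

import equinox as eqx
import jax

def sequential_call(layers, x, state, *, key=None):
    # Hypothetical illustration of the dispatch pattern discussed above.
    keys = [None] * len(layers) if key is None else list(jax.random.split(key, len(layers)))
    for layer, k in zip(layers, keys):
        if isinstance(layer, eqx.nn.StatefulLayer) and layer.is_stateful():
            x, state = layer(x, state=state, key=k)
        else:
            x = layer(x, key=k)
    return x, state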