Adversarial training with lightning #14782
-
Hello! I'm attempting to do some simple adversarial training with lightning but I'm running in some issues for the testing part. with torch.enable_grad():
adv_img = self.atk(imgs, labels) This works fine during training (when doing trainer.fit(model), but fails during testing (trainer.test(model)), with RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn I checked similar problems and the solutions was to enable gradients (which I did) or remove automatic optimization but I would like to retain the possibility of accumulating gradients, and since trainer.fit works fine I don't see why I would need to do manual optimization. I did some digging and the enabling of gradients seems to work: Same thing happens if I do trainer.validate(model). Also when using other attacks than PGD. Any idea why is that and how I can fix it? Reproducible in colab Or full script to reproduce: import pytorch_lightning as pl
from torch import nn
import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torchattacks
class adv_model(pl.LightningModule):
def __init__(self,
model,
attack=None,
loaders=None,
loss_fn=nn.CrossEntropyLoss(),
optim="AdamW",
clean=False,
lr=0.01
):
super().__init__()
self.model = model
self.loss_fn = loss_fn
self.loaders = loaders
self.atk = attack
self.clean = clean
self.lr = lr
if optim is None:
self.optim = torch.optim.AdamW
elif optim == "AdamW":
self.optim = torch.optim.AdamW
elif optim == "Adam":
self.optim = torch.optim.Adam
elif optim == "SGD":
self.optim = torch.optim.SGD
else:
raise ValueError(f"Optim should be in '[AdamW, Adam, SGD]', not {optim}")
def forward(self, x, clean=None):
return self.model(x)
def training_step(self, batch, batch_nb):
imgs, labels = batch
if not self.clean:
imgs = self.atk(imgs, labels)
logits = self.model(imgs)
loss = self.loss_fn(logits, labels)
self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
acc = (logits.argmax(dim=1)).eq(labels).sum().item() / len(imgs)
self.log("train_acc", acc, prog_bar=True, on_step=False, on_epoch=True)
return {"loss": loss, "acc": acc}
def validation_step(self, batch, batch_idx):
imgs, labels = batch
clean_logits = self.model(imgs)
clean_loss = self.loss_fn(clean_logits, labels)
clean_acc = (clean_logits.argmax(dim=1)).eq(labels).sum().item() / len(imgs)
self.log("clean_val_loss", clean_loss, prog_bar=True)
self.log('clean_val_acc', clean_acc, prog_bar=True)
if self.clean:
return clean_loss, clean_acc
# computing adversarial accuracy and loss
with torch.enable_grad():
adv_img = self.atk(imgs, labels)
adv_logits = self.model(adv_img)
adv_loss = self.loss_fn(adv_logits, labels)
self.log("adv_val_loss", adv_loss, prog_bar=True)
adv_acc = (adv_logits.argmax(dim=1)).eq(labels).sum().item() / len(imgs)
self.log('adv_val_acc', adv_acc, prog_bar=True)
return clean_loss, clean_acc, adv_loss, adv_acc
def test_step(self, batch, batch_idx):
imgs, labels = batch
clean_logits = self.model(imgs)
clean_loss = self.loss_fn(clean_logits, labels)
clean_acc = (clean_logits.argmax(dim=1)).eq(labels).sum().item() / len(imgs)
self.log("clean_test_loss", clean_loss, prog_bar=True)
self.log('clean_test_acc', clean_acc, prog_bar=True)
if self.clean:
return clean_loss, clean_acc
# computing adversarial accuracy and loss
with torch.enable_grad():
adv_img = self.atk(imgs, labels)
adv_logits = self.model(adv_img)
adv_loss = self.loss_fn(adv_logits, labels)
self.log("adv_test_loss", adv_loss, prog_bar=True)
adv_acc = (adv_logits.argmax(dim=1)).eq(labels).sum().item() / len(imgs)
self.log('adv_test_acc', adv_acc, prog_bar=True)
return clean_loss, clean_acc, adv_loss, adv_acc
def configure_optimizers(self):
optim = self.optim
if issubclass(optim, torch.optim.SGD):
if self.lr is not None:
return optim(self.model.parameters(), lr=self.lr, momentum=0.9, weight_decay=1e-4)
else:
return optim(self.model.parameters(), momentum=0.9, weight_decay=1e-4)
elif issubclass(optim, (torch.optim.Adam, torch.optim.AdamW)):
if self.lr is not None:
return optim(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
else:
return optim(self.model.parameters(), weight_decay=1e-4)
else:
return self.optim
def train_dataloader(self):
return self.loaders[0]
def val_dataloader(self):
return self.loaders[1]
def test_dataloader(self):
return self.loaders[2]
trainer = pl.Trainer(accelerator="gpu",
max_epochs=3,
val_check_interval=1.0,
)
base_model = torch.nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
nn.Linear(256, 256), nn.ReLU(),
nn.Linear(256, 10))
train_set = MNIST(root="./",
transform=T.ToTensor(),
download=True,
train=True
)
test_set = MNIST(root="./",
transform=T.ToTensor(),
download=True,
train=False
)
train_loader = DataLoader(train_set, batch_size=100, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=1000, shuffle=False, num_workers=2)
val_loader = DataLoader(test_set, batch_size=1000, shuffle=False, num_workers=2)
loaders = (train_loader, val_loader, test_loader)
atk = torchattacks.PGD(model=base_model.cuda(), steps=10)
model = adv_model(base_model,
loaders=loaders,
attack=atk,
clean=False,
optim="Adam")
trainer.fit(model)
trainer.test(model) |
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 8 replies
-
Answering to myself: For now the solution I have is to change the function @contextmanager
def _evaluation_context(accelerator: Accelerator) -> Generator:
# inference mode is not supported with gloo backend (#9431),
# and HPU & TPU accelerators.
context_manager_class = (
torch.inference_mode
if not (dist.is_initialized() and dist.get_backend() == "gloo")
and not isinstance(accelerator, HPUAccelerator)
and not isinstance(accelerator, TPUAccelerator)
else torch.no_grad
)
with context_manager_class():
yield to always use torch.no_grad (l2794 in trainer.py). If there is a simple alternative to use in the test_step or some parameters to force the use of no_grad instead of inference_mode I'm all ears. |
Beta Was this translation helpful? Give feedback.
-
Could you please let us know if there have been any updates on this issue? @sergedurand |
Beta Was this translation helpful? Give feedback.
Answering to myself:
After more digging, it seems that it is the use of torch.inference_mode that is the cause of the issue.
Using torch.no_grad is not enough to get out of inference_mode.
In fact getting out of inference_mode with e.g with torch.inference_mode(mode=False) or a decorator is not enough, I then have a problem
with Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
For now the solution I have is to change the function