Commit

add gradient accumulation
mjun0812 committed Nov 15, 2024
1 parent d8fba11 commit e1b554f
Showing 4 changed files with 30 additions and 20 deletions.
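In short, the change divides each iteration's loss by gradient_accumulation_steps, calls backward() on every batch, but only unscales, clips, steps the optimizer and scheduler, and zeroes the gradients once per accumulation window. A minimal standalone sketch of that pattern (a generic AMP loop with illustrative names, not this repository's Trainer class):

```python
import torch
from torch.cuda.amp import GradScaler


def train_one_epoch(model, optimizer, dataloader, accumulation_steps=4, max_norm=10.0):
    """Illustrative only: gradient accumulation with AMP and gradient clipping."""
    scaler = GradScaler()
    optimizer.zero_grad(set_to_none=True)
    for i, (inputs, targets) in enumerate(dataloader):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = model(inputs, targets)
        # Scale the loss down so the accumulated gradients match one large-batch step.
        scaler.scale(loss / accumulation_steps).backward()
        # Update weights only at the end of each accumulation window.
        if (i + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)  # unscale before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
```

Each optimizer step then behaves roughly as if the batch were accumulation_steps times larger, at the cost of proportionally fewer parameter updates per epoch.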
1 change: 1 addition & 0 deletions config/__base__/config.yaml
@@ -37,6 +37,7 @@ compile_backend: "inductor"

 use_clip_grad: True
 clip_grad_norm: 10
+gradient_accumulation_steps: 1
 
 adjust_lr: False

1 change: 1 addition & 0 deletions src/config/config.py
@@ -118,6 +118,7 @@ class ExperimentConfig(BaseConfig):

     use_clip_grad: bool = True
     clip_grad_norm: float = 10
+    gradient_accumulation_steps: int = 1
 
     adjust_lr: bool = False

43 changes: 25 additions & 18 deletions src/trainer.py
@@ -48,6 +48,7 @@ def __init__(
         use_amp: bool = False,
         amp_init_scale: int = 2**16,
         amp_dtype: str = "fp16",
+        gradient_accumulation_steps: int = 1,
     ) -> None:
         self.epochs = epochs
         self.device = device
@@ -77,6 +78,10 @@ def __init__(
         else:
             self.scaler = GradScaler(init_scale=amp_init_scale, enabled=self.use_amp)
 
+        self.gradient_accumulation_steps = gradient_accumulation_steps
+        if self.gradient_accumulation_steps > 1:
+            logger.info(f"Gradient Accumulation Steps: {self.gradient_accumulation_steps}")
+
         for phase in ["train", "val"]:
             total_iters = len(self.dataloaders[phase]) * self.epochs
             logger.info(f"Total {phase} Iterations per GPU: {total_iters}")
@@ -85,12 +90,12 @@
     def do_one_epoch(
         self,
         phase: PhaseStr,
-        epoch: int,
+        current_epoch: int,
         model: BaseModel,
     ) -> EpochResult:
         hist_epoch_loss = HistoryEpochLoss()
         pbar = self._setup_progress_bar(phase)
-        self._set_model_phase(model, phase, epoch)
+        self._set_model_phase(model, phase, current_epoch)
         self.optimizer.zero_grad(set_to_none=True)
 
         for i, data in pbar:
@@ -104,13 +109,13 @@
             ):
                 output: ModelOutput = model(data)
                 if phase == "train":
-                    self.backward(output, model, i, epoch)
+                    self.backward(output, model, i, current_epoch)
 
             self.update_metrics_and_losses(phase, data, output, hist_epoch_loss)
             if is_main_process():
-                self._update_pbar(pbar, epoch, output["losses"])
+                self._update_pbar(pbar, current_epoch, output["losses"])
 
-        return self.after_epoch(phase, epoch, hist_epoch_loss)
+        return self.after_epoch(phase, current_epoch, hist_epoch_loss)
 
     def _setup_progress_bar(self, phase: PhaseStr):
         progress_bar = enumerate(self.dataloaders[phase])
@@ -136,19 +141,21 @@ def prepare_input(self, data: dict, phase: PhaseStr) -> dict:
             data = self.batched_transforms[phase](data)
         return data
 
-    def backward(self, output: ModelOutput, model: BaseModel, i: int, epoch: int):
-        self.scaler.scale(output["losses"]["total_loss"]).backward()
-        if self.use_clip_grad:
-            self.scaler.unscale_(self.optimizer)
-            torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_grad)
-        self.scaler.step(self.optimizer)
-        self.scaler.update()
-        if self.iter_lr_scheduler:
-            self.iter_lr_scheduler.step(
-                epoch=i + epoch * len(self.dataloaders["train"]),
-                metric=output["losses"]["total_loss"].item(),
-            )
-        self.optimizer.zero_grad(set_to_none=True)
+    def backward(self, output: ModelOutput, model: BaseModel, iter_idx: int, epoch: int):
+        total_loss = output["losses"]["total_loss"] / self.gradient_accumulation_steps
+        self.scaler.scale(total_loss).backward()
+        if (iter_idx + 1) % self.gradient_accumulation_steps == 0:
+            if self.use_clip_grad:
+                self.scaler.unscale_(self.optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_grad)
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+            if self.iter_lr_scheduler:
+                self.iter_lr_scheduler.step(
+                    epoch=iter_idx + epoch * len(self.dataloaders["train"]),
+                    metric=output["losses"]["total_loss"].item(),
+                )
+            self.optimizer.zero_grad(set_to_none=True)
 
     def update_metrics_and_losses(
         self, phase: PhaseStr, data: dict, output: ModelOutput, hist_epoch_loss: HistoryEpochLoss
5 changes: 3 additions & 2 deletions train.py
@@ -145,14 +145,15 @@ def do_train(cfg: ExperimentConfig, device: torch.device, output_dir: Path, logg
         use_amp=cfg.use_amp,
         amp_init_scale=cfg.amp_init_scale,
         amp_dtype=cfg.amp_dtype,
+        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
     )
 
     logger.info("Start Training")
    for epoch in range(start_epoch, cfg.epoch):
         logger.phase = "train"
         logger.log_metric("Epoch", epoch + 1, epoch + 1)
 
-        result = trainer.do_one_epoch(phase="train", epoch=epoch, model=model)
+        result = trainer.do_one_epoch(phase="train", current_epoch=epoch, model=model)
         logger.log_metrics(result.epoch_losses, epoch + 1, "train")
         logger.log_metric("Learning Rate", result.lr, epoch + 1, "train")
         logger.log_artifact(f"{output_dir}/train.log")
@@ -184,7 +185,7 @@ def do_train(cfg: ExperimentConfig, device: torch.device, output_dir: Path, logg
             ConfigManager.dump(cfg, output_dir / "config.yaml")
 
         if (epoch + 1) % cfg.val_interval == 0:
-            result = trainer.do_one_epoch(phase="val", epoch=epoch, model=model)
+            result = trainer.do_one_epoch(phase="val", current_epoch=epoch, model=model)
             logger.log_metrics(result.epoch_losses, epoch + 1, "val")
             logger.log_metrics(result.metrics, epoch + 1, "val")
             logger.log_metric("Learning Rate", result.lr, epoch + 1, "val")
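As a rule of thumb, raising gradient_accumulation_steps multiplies the effective batch size seen by each optimizer step. A quick illustration with hypothetical numbers (not taken from this repository's configs):

```python
# Hypothetical values, for illustration only.
per_gpu_batch_size = 8
gradient_accumulation_steps = 4
num_gpus = 2  # assuming DDP with one process per GPU

# Samples whose gradients contribute to each optimizer step.
effective_batch_size = per_gpu_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch_size)  # 64
```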
