diff --git a/CHANGELOG.md b/CHANGELOG.md index c0a7c23e59..339787911c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Graph Transformer processor for GraphCast/GenCast. - Utility to generate STL from Signed Distance Field. +- Improved Pangu training code ### Changed - Refactored CorrDiff training recipe for improved usability +- Refactored Pangu model for better extensibility and gradient checkpointing support. + Some of these changes are not backward compatible. ### Deprecated diff --git a/examples/weather/pangu_weather/conf/config.yaml b/examples/weather/pangu_weather/conf/config.yaml index f2b08d4818..5254d25fbc 100644 --- a/examples/weather/pangu_weather/conf/config.yaml +++ b/examples/weather/pangu_weather/conf/config.yaml @@ -24,30 +24,91 @@ hydra: run: dir: ./outputs/ -start_epoch: 0 -max_epoch: 100 +max_epoch: 101 + +pangu: + img_size: [721, 1440] + patch_size: [2, 4, 4] + embed_dim: 192 + num_heads: [6, 12, 12, 6] + window_size: [2, 6, 12] + number_constant_variables: 3 + number_surface_variables: 4 + number_atmosphere_variables: 5 + number_atmosphere_levels: 13 + number_up_sampled_blocks: 2 + number_down_sampled_blocks: 6 + checkpoint_flag: True train: data_dir: "/data/train/" stats_dir: "/data/stats/" + checkpoint_dir: "/data/checkpoints/" + mask_dir: "/data/constant_mask/" + num_samples_per_year: 1456 - batch_size: 1 - patch_size: [1, 1] num_workers: 8 + lr: 1e-3 + weight_decay: 0.05 + use_cosine_zenith: True + mask_dtype: "float32" + enable_amp: False + enable_graphs: False + + stages: + - name: "Learning Rate Warmup" + max_iterations: .inf + num_epochs: 1 + batch_size: 1 + num_rollout_steps: 1 + lr_scheduler_name: LinearLR + args: + start_factor: 0.001 + end_factor: 1.0 + total_iters: 1 + + - name: "Cosine Annealing LR" + max_iterations: .inf + num_epochs: 100 + batch_size: 1 + num_rollout_steps: 1 + lr_scheduler_name: CosineAnnealingLR + args: + T_max: 100 + eta_min: 0.0 + + - name: "LambdaLR Rollout: 2 Steps" + max_iterations: .inf + num_epochs: 20 + batch_size: 2 + num_rollout_steps: 2 + lr_scheduler_name: LambdaLR + args: + lr_lambda: ${lambda_lr:3e-7,${train.lr}} + + - name: "LambdaLR Rollout: 3 Steps" + max_iterations: .inf + num_epochs: 20 + batch_size: 2 + num_rollout_steps: 3 + lr_scheduler_name: LambdaLR + args: + lr_lambda: ${lambda_lr:3e-7,${train.lr}} + + - name: "LambdaLR Rollout: 4 Steps" + max_iterations: .inf + num_epochs: 20 + batch_size: 1 + num_rollout_steps: 4 + lr_scheduler_name: LambdaLR + args: + lr_lambda: ${lambda_lr:3e-7,${train.lr}} + val: data_dir: "/data/test/" stats_dir: "/data/stats/" - num_samples_per_year: 4 + num_samples_per_year: 32 batch_size: 1 - patch_size: [1, 1] num_workers: 8 - -pangu: - img_size: [721, 1440] - patch_size: [2, 4, 4] - embed_dim: 192 - num_heads: [6, 12, 12, 6] - window_size: [2, 6, 12] - -mask_dir: "/data/constant_mask" -mask_dtype: "float32" + num_rollout_steps: 6 + channels: [0, 1, 2, 3, 4, 7, 43] diff --git a/examples/weather/pangu_weather/conf/config_lite.yaml b/examples/weather/pangu_weather/conf/config_lite.yaml index 548fc239b0..2e864486d7 100644 --- a/examples/weather/pangu_weather/conf/config_lite.yaml +++ b/examples/weather/pangu_weather/conf/config_lite.yaml @@ -16,7 +16,7 @@ experiment_name: "Modulus-Launch-Dev" experiment_desc: "Modulus launch development" -run_desc: "Pangu lite ERA5 Training" +run_desc: "Pangu ERA5 Lite-Training" hydra: job: @@ -24,30 +24,64 @@ hydra: run: dir: ./outputs/ -start_epoch: 0 -max_epoch: 100 +max_epoch: 11 + +pangu: + img_size: [721, 1440] + patch_size: [2, 4, 4] + embed_dim: 96 + num_heads: [6, 12, 12, 6] + window_size: [2, 6, 12] + number_constant_variables: 3 + number_surface_variables: 4 + number_atmosphere_variables: 5 + number_atmosphere_levels: 13 + number_up_sampled_blocks: 2 + number_down_sampled_blocks: 6 + checkpoint_flag: True train: data_dir: "/data/train/" stats_dir: "/data/stats/" - num_samples_per_year: 1456 - batch_size: 1 - patch_size: [1, 1] + checkpoint_dir: "/data/checkpoints/" + mask_dir: "/data/constant_mask/" + + num_samples_per_year: 600 num_workers: 8 + lr: 1e-3 + weight_decay: 0.05 + use_cosine_zenith: True + mask_dtype: "float32" + enable_amp: False + enable_graphs: False + + stages: + - name: "Learning Rate Warmup" + max_iterations: .inf + num_epochs: 1 + batch_size: 1 + num_rollout_steps: 1 + lr_scheduler_name: LinearLR + args: + start_factor: 0.001 + end_factor: 1.0 + total_iters: 1 + + - name: "Cosine Annealing LR" + max_iterations: .inf + num_epochs: 10 + batch_size: 1 + num_rollout_steps: 1 + lr_scheduler_name: CosineAnnealingLR + args: + T_max: 10 + eta_min: 0.0 + val: data_dir: "/data/test/" stats_dir: "/data/stats/" - num_samples_per_year: 4 + num_samples_per_year: 1 batch_size: 1 - patch_size: [1, 1] num_workers: 8 - -pangu: - img_size: [721, 1440] - patch_size: [2, 8, 8] - embed_dim: 192 - num_heads: [6, 12, 12, 6] - window_size: [2, 6, 12] - -mask_dir: "/data/constant_mask" -mask_dtype: "float32" + num_rollout_steps: 1 + channels: [0, 1, 2, 3, 4, 7, 43] \ No newline at end of file diff --git a/examples/weather/pangu_weather/requirements.txt b/examples/weather/pangu_weather/requirements.txt deleted file mode 100644 index c51f257854..0000000000 --- a/examples/weather/pangu_weather/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -mlflow>=2.1.1 \ No newline at end of file diff --git a/examples/weather/pangu_weather/train_pangu_era5.py b/examples/weather/pangu_weather/train_pangu_era5.py index 119ba454d3..f3e6e2cae9 100644 --- a/examples/weather/pangu_weather/train_pangu_era5.py +++ b/examples/weather/pangu_weather/train_pangu_era5.py @@ -19,17 +19,25 @@ import hydra import numpy as np import matplotlib.pyplot as plt +import pandas as pd from torch.nn.parallel import DistributedDataParallel +from torch.cuda import amp +from torch.cuda.amp import GradScaler + from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm + from modulus.models.pangu import Pangu from modulus.datapipes.climate import ERA5HDF5Datapipe from modulus.distributed import DistributedManager from modulus.utils import StaticCaptureTraining, StaticCaptureEvaluateNoGrad -from modulus.launch.logging import LaunchLogger, PythonLogger -from modulus.launch.logging.mlflow import initialize_mlflow +from modulus.launch.logging import ( + RankZeroLoggingWrapper, + PythonLogger, +) from modulus.launch.utils import load_checkpoint, save_checkpoint try: @@ -40,118 +48,166 @@ + "See https://github.com/nvidia/apex for install details." ) +OmegaConf.register_new_resolver("lambda_lr", lambda x, y: (lambda epoch: x / y)) +torch._dynamo.config.optimize_ddp = False + -def loss_func(x, y): - return torch.nn.functional.l1_loss(x, y) +@torch.jit.script +def loss_func(x: torch.Tensor, y: torch.Tensor, weights: torch.Tensor): + return torch.mean( + weights * torch.nn.functional.smooth_l1_loss(x, y, reduction="none", beta=0.5) + ) / torch.mean(weights) @torch.no_grad() def validation_step( - eval_step, pangu_model, datapipe, surface_mask, channels=[0, 1], epoch=0 + eval_step, + pangu_model, + datapipe, + surface_mask, + weights, + channels, + epoch, ): - loss_epoch = 0 + + num_channels = len(channels) + num_steps = datapipe.num_steps + loss_epoch = torch.zeros((num_channels, num_steps), device="cpu") num_examples = 0 # Number of validation examples + # Dealing with DDP wrapper if hasattr(pangu_model, "module"): pangu_model = pangu_model.module + pangu_model.eval() - for i, data in enumerate(datapipe): - invar_surface = data[0]["invar"].detach()[:, :4, :, :] - invar_upper_air = ( - data[0]["invar"] - .detach()[:, 4:, :, :] - .reshape( - ( - data[0]["invar"].shape[0], - 5, - -1, - data[0]["invar"].shape[2], - data[0]["invar"].shape[3], - ) + + # Loop over datapipe + for di, data in enumerate(datapipe): + # Get input data + invar = data[0]["invar"].detach().to("cuda:0") + cos_zenith = data[0]["cos_zenith"].detach().to("cuda:0").squeeze(dim=2) + cos_zenith = torch.clamp(cos_zenith, min=0.0) - 1.0 / torch.pi + outvar = data[0]["outvar"].detach() + sm = surface_mask.repeat(invar.shape[0], 1, 1, 1) + + num_examples += outvar.shape[0] + + # If first batch then create buffer for outputs + if di == 0: + outpred = torch.zeros_like(outvar, device="cpu").pin_memory() + for t in range(outvar.shape[1]): + out, loss = eval_step( + pangu_model, + invar, + cos_zenith[:, t : t + 1], + sm, + outvar[:, t].to("cuda:0"), + weights, ) - ) - outvar_surface = data[0]["outvar"].cpu().detach()[:, :, :4, :, :] - outvar_upper_air = ( - data[0]["outvar"] - .cpu() - .detach()[:, :, 4:, :, :] - .reshape( - ( - data[0]["outvar"].shape[0], - data[0]["outvar"].shape[1], - 5, - -1, - data[0]["outvar"].shape[3], - data[0]["outvar"].shape[4], + invar = out.clone() + out = out.detach().cpu() + loss = loss.detach().cpu() + + # Normalize + out = out * datapipe.sd + datapipe.mu + loss = loss * datapipe.sd[0, :, 0, 0] + loss_epoch[:, t] += loss[channels] + + # If first batch then save out to buffer + if di == 0: + outpred[:, t].copy_(out, non_blocking=True) + + # If first batch plot images + if (di == 0) and (epoch % 10 == 0): + os.makedirs("./images", exist_ok=True) + num_plots = max(4, num_steps) + + outvar = outvar.cpu() * datapipe.sd + datapipe.mu + for i, ch in enumerate(channels): + fig, ax = plt.subplots( + nrows=3, ncols=num_plots, figsize=(4 * num_plots, 2 * 3 + 1) ) - ) - ) - predvar_surface = torch.zeros_like(outvar_surface) - predvar_upper_air = torch.zeros_like(outvar_upper_air) - for t in range(outvar_surface.shape[1]): - output_surface, output_upper_air = eval_step( - pangu_model, invar_surface, surface_mask, invar_upper_air - ) - invar_surface.copy_(output_surface) - invar_upper_air.copy_(output_upper_air) - predvar_surface[:, t] = output_surface.detach().cpu() - predvar_upper_air[:, t] = output_upper_air.detach().cpu() - - num_elements_surface = torch.prod(torch.Tensor(list(predvar_surface.shape[1:]))) - num_elements_upper_air = torch.prod( - torch.Tensor(list(predvar_upper_air.shape[1:])) - ) - loss_epoch += ( - torch.sum(torch.pow(predvar_surface - outvar_surface, 2)) - + torch.sum(torch.pow(predvar_upper_air - outvar_upper_air, 2)) - ) / (num_elements_surface + num_elements_upper_air) + for j in range(num_plots): + op = outpred[0, j, ch] + ov = outvar[0, j, ch] + vmin = ov.min().item() + vmax = ov.max().item() + pred = ax[0, j].imshow(op, vmin=vmin, vmax=vmax) + ax[0, j].set_title(f"Channel {ch} Step {j} Prediction") + plt.colorbar( + pred, ax=ax[0, j], shrink=0.75, orientation="horizontal" + ) + + truth = ax[1, j].imshow(ov, vmin=vmin, vmax=vmax) + ax[1, j].set_title(f"Channel {ch} Step {j} Truth") + plt.colorbar( + truth, ax=ax[1, j], shrink=0.75, orientation="horizontal" + ) + + diff = ax[2, j].imshow((op - ov) / ov.abs().mean()) + ax[2, j].set_title(f"Channel {ch} Step {j} Relative Error") + plt.colorbar( + diff, ax=ax[2, j], shrink=0.75, orientation="horizontal" + ) + + plt.tight_layout() + plt.savefig( + f"./images/diff_channel_{ch}_epoch_{epoch}.png", + dpi=600, + bbox_inches="tight", + ) + plt.clf() + + loss_epoch = torch.sqrt(loss_epoch / num_examples).numpy() + + # Save losses + csv_file_name = "validation_rmse_loss.csv" + try: + # See if there is an existing file. + df = pd.read_csv(csv_file_name, index_col=0) + except FileNotFoundError: + # Create a new dataframe otherwise. + df = pd.DataFrame(columns=["epoch", "channel_id", "step", "loss"]) - num_examples += predvar_surface.shape[0] + dd = [] + for i, ch in enumerate(channels): + for j in range(datapipe.num_steps): + dd.append([epoch, ch, j, loss_epoch[i, j]]) + + df = pd.concat([df, pd.DataFrame(dd, columns=df.columns)], ignore_index=True) + df.to_csv(csv_file_name) pangu_model.train() - return loss_epoch / num_examples -@hydra.main(version_base="1.2", config_path="conf", config_name="config") +@hydra.main(version_base="1.2", config_path="conf", config_name="config_internal") def main(cfg: DictConfig) -> None: DistributedManager.initialize() dist = DistributedManager() # Initialize loggers - initialize_mlflow( - experiment_name=cfg.experiment_name, - experiment_desc=cfg.experiment_desc, - run_name="Pangu-trainng", - run_desc=cfg.experiment_desc, - user_name="Modulus User", - mode="offline", - ) - LaunchLogger.initialize(use_mlflow=True) # Modulus launch logger logger = PythonLogger("main") # General python logger + rank_zero_logger = RankZeroLoggingWrapper(logger, dist) + rank_zero_logger.file_logging() - number_channels_pangu = 4 + 5 * 13 - datapipe = ERA5HDF5Datapipe( - data_dir=cfg.train.data_dir, - stats_dir=cfg.train.stats_dir, - channels=[i for i in range(number_channels_pangu)], - num_samples_per_year=cfg.train.num_samples_per_year, - batch_size=cfg.train.batch_size, - patch_size=OmegaConf.to_object(cfg.train.patch_size), - num_workers=cfg.train.num_workers, - device=dist.device, - process_rank=dist.rank, - world_size=dist.world_size, + # print ranks and devices + logger.info(f"Rank: {dist.rank}, Device: {dist.device}") + + number_channels_pangu = ( + cfg.pangu.number_surface_variables + + cfg.pangu.number_atmosphere_variables * cfg.pangu.number_atmosphere_levels ) - logger.success(f"Loaded datapipe of size {len(datapipe)}") + img_size = OmegaConf.to_object(cfg.pangu.img_size) - mask_dir = cfg.mask_dir - if cfg.get("mask_dtype", "float32") == "float32": + mask_dir = cfg.train.mask_dir + if cfg.train.get("mask_dtype", "float32") == "float32": mask_dtype = np.float32 - elif cfg.get("mask_dtype", "float32") == "float16": + elif cfg.train.get("mask_dtype", "float32") == "float16": mask_dtype = np.float16 else: mask_dtype = np.float32 + land_mask = torch.from_numpy( np.load(os.path.join(mask_dir, "land_mask.npy")).astype(mask_dtype) ) @@ -161,35 +217,38 @@ def main(cfg: DictConfig) -> None: topography = torch.from_numpy( np.load(os.path.join(mask_dir, "topography.npy")).astype(mask_dtype) ) - surface_mask = torch.stack([land_mask, soil_type, topography], dim=0).to( - dist.device + topography = (topography - topography.mean()) / topography.std() + surface_mask = ( + torch.stack([land_mask, soil_type, topography], dim=0) + .to(dist.device) + .unsqueeze(0) ) - logger.success(f"Loaded suface constant mask from {mask_dir}") - - if dist.rank == 0: - logger.file_logging() - validation_datapipe = ERA5HDF5Datapipe( - data_dir=cfg.val.data_dir, - stats_dir=cfg.val.stats_dir, - channels=[i for i in range(number_channels_pangu)], - num_steps=1, - num_samples_per_year=cfg.val.num_samples_per_year, - batch_size=cfg.val.batch_size, - patch_size=OmegaConf.to_object(cfg.val.patch_size), - device=dist.device, - num_workers=cfg.val.num_workers, - shuffle=False, - ) - logger.success(f"Loaded validaton datapipe of size {len(validation_datapipe)}") + logger.success(f"Rank {dist.rank}: Loaded suface constant mask from {mask_dir}") pangu_model = Pangu( - img_size=OmegaConf.to_object(cfg.pangu.img_size), + img_size=img_size, patch_size=OmegaConf.to_object(cfg.pangu.patch_size), embed_dim=cfg.pangu.embed_dim, num_heads=OmegaConf.to_object(cfg.pangu.num_heads), window_size=OmegaConf.to_object(cfg.pangu.window_size), + number_constant_variables=cfg.pangu.number_constant_variables + + int(cfg.train.use_cosine_zenith), + number_surface_variables=cfg.pangu.number_surface_variables, + number_atmosphere_levels=cfg.pangu.number_atmosphere_levels, + number_atmosphere_variables=cfg.pangu.number_atmosphere_variables, + number_up_sampled_blocks=cfg.pangu.number_up_sampled_blocks, + number_down_sampled_blocks=cfg.pangu.number_down_sampled_blocks, + checkpoint_flag=cfg.pangu.checkpoint_flag, ).to(dist.device) + # pangu_model.compile() + + weights = ( + torch.abs(torch.cos(torch.linspace(90, -90, img_size[0]) * torch.pi / 180.0)) + .unsqueeze(1) + .repeat(1, img_size[1]) + .to(dist.device) + ) # Distributed learning if dist.world_size > 1: ddps = torch.cuda.Stream() @@ -203,122 +262,197 @@ def main(cfg: DictConfig) -> None: ) torch.cuda.current_stream().wait_stream(ddps) + # pangu_model = torch.compile(pangu_model, mode = "max-autotune") # Initialize optimizer and scheduler optimizer = optimizers.FusedAdam( - pangu_model.parameters(), betas=(0.9, 0.999), lr=0.0005, weight_decay=0.000003 + pangu_model.parameters(), + betas=(0.9, 0.999), + lr=cfg.train.lr, + weight_decay=cfg.train.weight_decay, ) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100) - # Attempt to load latest checkpoint if one exists - loaded_epoch = load_checkpoint( - "./checkpoints", - models=pangu_model, - optimizer=optimizer, - scheduler=scheduler, - device=dist.device, - ) + # Load Validation Datapipe + if dist.rank == 0: + validation_datapipe = ERA5HDF5Datapipe( + data_dir=cfg.val.data_dir, + stats_dir=cfg.val.stats_dir, + channels=[i for i in range(number_channels_pangu)], + num_steps=cfg.val.num_rollout_steps, + num_samples_per_year=cfg.val.num_samples_per_year, + batch_size=cfg.val.batch_size, + device="cpu", + num_workers=cfg.val.num_workers, + shuffle=False, + use_cos_zenith=cfg.train.use_cosine_zenith, + cos_zenith_args={ + "dt": 6.0, + "start_year": 1980, + "latlon_bounds": ((90, -90), (0, 360)), + }, + latlon_resolution=img_size, + ) + logger.success( + f"Rank {dist.rank}: Loaded validaton datapipe of size {len(validation_datapipe)}" + ) @StaticCaptureEvaluateNoGrad(model=pangu_model, logger=logger, use_graphs=False) - def eval_step_forward(my_model, invar_surface, surface_mask, invar_upper_air): - invar = my_model.prepare_input(invar_surface, surface_mask, invar_upper_air) - return my_model(invar) + def eval_step_forward(my_model, invar, cos_zenith, surface_mask, outvar, weights): + # Multi-step prediction + invar = torch.concat([surface_mask, cos_zenith, invar], dim=1) + outpred = my_model(invar) + loss = torch.sum( + weights * (outpred - outvar) ** 2, dim=(0, -2, -1) + ) / torch.sum(weights) + return outpred, loss @StaticCaptureTraining( - model=pangu_model, optim=optimizer, logger=logger, use_graphs=False + model=pangu_model, + optim=optimizer, + logger=logger, + use_graphs=cfg.train.enable_graphs, + use_amp=cfg.train.enable_amp, + gradient_clip_norm=cfg.train.get("gradient_clip_norm", None), ) - def train_step_forward( - my_model, - invar_surface, - surface_mask, - invar_upper_air, - outvar_surface, - outvar_upper_air, - ): + def train_step_forward(my_model, invar, cos_zenith, surface_mask, outvar, weights): # Multi-step prediction loss = 0 - # Multi-step not supported - for t in range(outvar_surface.shape[1]): - invar = my_model.prepare_input(invar_surface, surface_mask, invar_upper_air) - outpred_surface, outpred_upper_air = my_model(invar) - invar_surface = outpred_surface - invar_upper_air = outpred_upper_air - loss += loss_func(outpred_surface, outvar_surface[:, t]) * 0.25 + loss_func( - outpred_upper_air, outvar_upper_air[:, t] - ) + batch_size = outvar.shape[0] + for b in range(batch_size): + invar_ = invar[b : b + 1] + cos_zenith_ = cos_zenith[b : b + 1] + for t in range(outvar.shape[1]): + invar_ = torch.concat( + [surface_mask, cos_zenith_[:, t : t + 1], invar_], dim=1 + ) + outpred = my_model(invar_) + loss += loss_func(outpred, outvar[b : b + 1, t], weights) / batch_size + invar_ = outpred + return loss # Main training loop - max_epoch = cfg.max_epoch - for epoch in range(max(1, loaded_epoch + 1), max_epoch + 1): - # Wrap epoch in launch logger for console / WandB logs - with LaunchLogger( - "train", epoch=epoch, num_mini_batch=len(datapipe), epoch_alert_freq=10 - ) as log: - # === Training step === - for j, data in enumerate(datapipe): - invar_surface = data[0]["invar"][:, :4, :, :] - invar_upper_air = data[0]["invar"][:, 4:, :, :].reshape( - ( - data[0]["invar"].shape[0], - 5, - -1, - data[0]["invar"].shape[2], - data[0]["invar"].shape[3], - ) + + # Attempt to load latest checkpoint if one exists + loaded_epoch = load_checkpoint( + cfg.train.checkpoint_dir, + models=pangu_model, + optimizer=optimizer, + scheduler=None, + scaler=None, + device=dist.device, + ) + + rank_zero_logger.info("Rank {dist.rank}: Training started.") + global_epoch = 0 + for stage in cfg.train.stages: + if loaded_epoch > global_epoch: + if loaded_epoch >= global_epoch + stage.num_epochs: + # Skip stage + global_epoch += stage.num_epochs + continue + else: + num_epochs = stage.num_epochs - (loaded_epoch - global_epoch) + global_epoch = loaded_epoch + else: + num_epochs = stage.num_epochs + + rank_zero_logger.info( + f"Rank {dist.rank}: Starting stage {stage.name} at epoch # {loaded_epoch}." + ) + + # Load datapipe for this stage + train_datapipe = ERA5HDF5Datapipe( + data_dir=cfg.train.data_dir, + stats_dir=cfg.train.stats_dir, + channels=[i for i in range(number_channels_pangu)], + num_samples_per_year=cfg.train.num_samples_per_year, + use_cos_zenith=cfg.train.use_cosine_zenith, + cos_zenith_args={ + "dt": 6.0, + "start_year": 1980, + "latlon_bounds": ((90, -90), (0, 360)), + }, + num_steps=stage.num_rollout_steps, + latlon_resolution=img_size, + batch_size=stage.batch_size, + num_workers=cfg.train.num_workers, + device=dist.device, + process_rank=dist.rank, + world_size=dist.world_size, + ) + logger.success( + f"Rank {dist.rank}: Loaded datapipe of size {len(train_datapipe)}" + ) + + # Initialize scheduler + SchedulerClass = getattr(torch.optim.lr_scheduler, stage.lr_scheduler_name) + scheduler = SchedulerClass(optimizer, **stage.args) + + # Set scheduler to current step + scheduler.step(stage.num_epochs - num_epochs) + + # Get current step for checking if max iterations is reached + current_step = len(train_datapipe) * (stage.num_epochs - num_epochs) + + for epoch in range(num_epochs): + logger.info(f"Rank {dist.rank}: Starting Epoch {global_epoch}.") + loss_agg = 0.0 + for j, data in tqdm(enumerate(train_datapipe), disable=(dist.rank != 0)): + if current_step > stage.max_iterations: + break + + invar = data[0]["invar"] + outvar = data[0]["outvar"] + cos_zenith = data[0]["cos_zenith"].squeeze(dim=2) + cos_zenith = torch.clamp(cos_zenith, min=0.0) - 1.0 / torch.pi + loss_agg += train_step_forward( + pangu_model, invar, cos_zenith, surface_mask, outvar, weights ) - outvar_surface = data[0]["outvar"][:, :, :4, :, :] - outvar_upper_air = data[0]["outvar"][:, :, 4:, :, :].reshape( - ( - data[0]["outvar"].shape[0], - data[0]["outvar"].shape[1], - 5, - -1, - data[0]["outvar"].shape[3], - data[0]["outvar"].shape[4], + + current_step += 1 + + if ( + current_step % int(len(train_datapipe) // 5) == 0 + ) and dist.rank == 0: + tqdm.write( + f"Epoch: {global_epoch} \t iteration: {current_step} " + + f"\t loss: {loss_agg / int(len(train_datapipe) // 5)}" ) - ) - loss = train_step_forward( - pangu_model, - invar_surface, - surface_mask, - invar_upper_air, - outvar_surface, - outvar_upper_air, + loss_agg = 0.0 + + # Step scheduler + scheduler.step() + + # Perform validation + if dist.rank == 0: + del invar, cos_zenith, outvar + torch.cuda.empty_cache() + + # Use Modulus Launch checkpoint + save_checkpoint( + cfg.train.checkpoint_dir, + models=pangu_model, + optimizer=optimizer, + scheduler=scheduler, + scaler=None, + epoch=global_epoch + 1, ) - log.log_minibatch({"loss": loss.detach()}) - log.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]}) - - if dist.rank == 0: - # Wrap validation in launch logger for console / WandB logs - with LaunchLogger("valid", epoch=epoch) as log: - # === Validation step === - error = validation_step( + validation_step( eval_step_forward, pangu_model, validation_datapipe, surface_mask, - epoch=epoch, + weights, + cfg.val.channels, + global_epoch + 1, ) - log.log_epoch({"Validation error": error}) - - if dist.world_size > 1: + global_epoch += 1 + torch.cuda.empty_cache() torch.distributed.barrier() - scheduler.step() - - if (epoch % 5 == 0 or epoch == 1) and dist.rank == 0: - # Use Modulus Launch checkpoint - save_checkpoint( - "./checkpoints", - models=pangu_model, - optimizer=optimizer, - scheduler=scheduler, - epoch=epoch, - ) - if dist.rank == 0: - logger.info("Finished training!") + logger.info("Rank {dist.rank}: Finished training!") if __name__ == "__main__": diff --git a/modulus/models/layers/transformer_layers.py b/modulus/models/layers/transformer_layers.py index ab94b02c68..19f8aa5bc6 100644 --- a/modulus/models/layers/transformer_layers.py +++ b/modulus/models/layers/transformer_layers.py @@ -20,8 +20,9 @@ from timm.layers import to_2tuple from timm.models.swin_transformer import SwinTransformerStage from torch import nn +from torch.utils.checkpoint import checkpoint -from ..utils import ( +from modulus.models.utils import ( PatchEmbed2D, PatchRecovery2D, crop2d, @@ -32,6 +33,7 @@ window_partition, window_reverse, ) + from .attention_layers import EarthAttention2D, EarthAttention3D from .drop import DropPath from .mlp_layers import Mlp @@ -369,6 +371,7 @@ def __init__( attn_drop=0.0, drop_path=0.0, norm_layer=nn.LayerNorm, + checkpoint_flag: bool = False, ): super().__init__() self.dim = dim @@ -397,9 +400,14 @@ def __init__( ] ) + self.checkpoint_flag = checkpoint_flag + def forward(self, x): for blk in self.blocks: - x = blk(x) + if self.checkpoint_flag: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) return x @@ -445,6 +453,7 @@ def __init__( attn_drop=0.0, drop_path=0.0, norm_layer=nn.LayerNorm, + checkpoint_flag: bool = False, ): super().__init__() self.in_chans = in_chans @@ -519,16 +528,24 @@ def __init__( ] ) + self.checkpoint_flag = checkpoint_flag + def forward(self, x): x = self.patchembed2d(x) B, C, Lat, Lon = x.shape x = x.reshape(B, C, -1).transpose(1, 2) for blk in self.blocks: - x = blk(x) + if self.checkpoint_flag: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) skip = x.reshape(B, Lat, Lon, C) x = self.downsample(x) for blk in self.blocks_middle: - x = blk(x) + if self.checkpoint_flag: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) return x, skip @@ -574,6 +591,7 @@ def __init__( attn_drop=0.0, drop_path=0.0, norm_layer=nn.LayerNorm, + checkpoint_flag: bool = False, ): super().__init__() self.out_chans = out_chans @@ -645,13 +663,21 @@ def __init__( self.patchrecovery2d = PatchRecovery2D(img_size, patch_size, 2 * dim, out_chans) + self.checkpoint_flag = checkpoint_flag + def forward(self, x, skip): B, Lat, Lon, C = skip.shape for blk in self.blocks_middle: - x = blk(x) + if self.checkpoint_flag: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) x = self.upsample(x) for blk in self.blocks: - x = blk(x) + if self.checkpoint_flag: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) output = torch.concat([x, skip.reshape(B, -1, C)], dim=-1) output = output.transpose(1, 2).reshape(B, -1, Lat, Lon) output = self.patchrecovery2d(output) diff --git a/modulus/models/pangu/__init__.py b/modulus/models/pangu/__init__.py index 503f0c258a..151c30dd26 100644 --- a/modulus/models/pangu/__init__.py +++ b/modulus/models/pangu/__init__.py @@ -15,3 +15,4 @@ # limitations under the License. from .pangu import Pangu +from .pangu_processor import PanguProcessor diff --git a/modulus/models/pangu/pangu.py b/modulus/models/pangu/pangu.py index a14d4f018f..b15d734128 100644 --- a/modulus/models/pangu/pangu.py +++ b/modulus/models/pangu/pangu.py @@ -20,10 +20,10 @@ import numpy as np import torch -from ..layers import DownSample3D, FuserLayer, UpSample3D -from ..meta import ModelMetaData -from ..module import Module -from ..utils import ( +from modulus.models.meta import ModelMetaData +from modulus.models.module import Module +from modulus.models.pangu.pangu_processor import PanguProcessor +from modulus.models.utils import ( PatchEmbed2D, PatchEmbed3D, PatchRecovery2D, @@ -53,12 +53,34 @@ class Pangu(Module): Pangu A PyTorch impl of: `Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast` - https://arxiv.org/abs/2211.02556 - Args: - img_size (tuple[int]): Image size [Lat, Lon]. - patch_size (tuple[int]): Patch token size [Lat, Lon]. - embed_dim (int): Patch embedding dimension. Default: 192 - num_heads (tuple[int]): Number of attention heads in different layers. - window_size (tuple[int]): Window size. + Parameters + img_size: tuple[int] + Image size [lat, lon] + patch_size: tuple[int] + Patch-embedding shape + embed_dim: int + Embedding dimension size, be default 192. + num_heads: tuple[int] + Number of attention heads to use for each Fuser Layer. + window_size: tuple[int] + Window size in 3D attention window mechanism. + number_constant_variables: int + The number of constant variables (do not change in time). + number_surface_variables: int + The number of surface variables (not including constant variables). + By default 4 + number_atmosphere_variables: int + The number of atmosphere variables per atmosphere level. + By default 5 + number_atmosphere_levels: int + The number of pressure levels in the atmosphere. + By default 13. + number_up_sampled_blocks: int + The number of upsampled blocks in the Earth-specific Transformer blocks. + number_down_sampled_blocks: int + The number of downsampled blocks in the Earth-specific Transformer blocks. + checkpoint_flag: int + Whether to use gradient checkpointing in training. """ def __init__( @@ -68,20 +90,34 @@ def __init__( embed_dim=192, num_heads=(6, 12, 12, 6), window_size=(2, 6, 12), + number_constant_variables=3, + number_surface_variables=4, + number_atmosphere_variables=5, + number_atmosphere_levels=13, + number_up_sampled_blocks=2, + number_down_sampled_blocks=6, + checkpoint_flag: bool = False, ): super().__init__(meta=MetaData()) - drop_path = np.linspace(0, 0.2, 8).tolist() + drop_path = np.linspace( + 0, 0.2, number_up_sampled_blocks + number_down_sampled_blocks + ).tolist() # In addition, three constant masks(the topography mask, land-sea mask and soil type mask) + self.number_constant_variables = number_constant_variables + self.number_surface_variables = number_surface_variables + self.number_air_variables = number_atmosphere_variables + self.number_air_levels = number_atmosphere_levels self.patchembed2d = PatchEmbed2D( img_size=img_size, patch_size=patch_size[1:], - in_chans=4 + 3, # add + in_chans=self.number_surface_variables + + self.number_constant_variables, # add embed_dim=embed_dim, ) self.patchembed3d = PatchEmbed3D( - img_size=(13, img_size[0], img_size[1]), + img_size=(number_atmosphere_levels, img_size[0], img_size[1]), patch_size=patch_size, - in_chans=5, + in_chans=number_atmosphere_variables, embed_dim=embed_dim, ) patched_inp_shape = ( @@ -90,80 +126,44 @@ def __init__( math.ceil(img_size[1] / patch_size[2]), ) - self.layer1 = FuserLayer( - dim=embed_dim, - input_resolution=patched_inp_shape, - depth=2, - num_heads=num_heads[0], - window_size=window_size, - drop_path=drop_path[:2], + self.processor = PanguProcessor( + embed_dim, + patched_inp_shape, + num_heads, + window_size, + drop_path, + number_up_sampled_blocks, + checkpoint_flag, ) - patched_inp_shape_downsample = ( - 8, - math.ceil(patched_inp_shape[1] / 2), - math.ceil(patched_inp_shape[2] / 2), - ) - self.downsample = DownSample3D( - in_dim=embed_dim, - input_resolution=patched_inp_shape, - output_resolution=patched_inp_shape_downsample, - ) - self.layer2 = FuserLayer( - dim=embed_dim * 2, - input_resolution=patched_inp_shape_downsample, - depth=6, - num_heads=num_heads[1], - window_size=window_size, - drop_path=drop_path[2:], - ) - self.layer3 = FuserLayer( - dim=embed_dim * 2, - input_resolution=patched_inp_shape_downsample, - depth=6, - num_heads=num_heads[2], - window_size=window_size, - drop_path=drop_path[2:], - ) - self.upsample = UpSample3D( - embed_dim * 2, embed_dim, patched_inp_shape_downsample, patched_inp_shape - ) - self.layer4 = FuserLayer( - dim=embed_dim, - input_resolution=patched_inp_shape, - depth=2, - num_heads=num_heads[3], - window_size=window_size, - drop_path=drop_path[:2], - ) # The outputs of the 2nd encoder layer and the 7th decoder layer are concatenated along the channel dimension. self.patchrecovery2d = PatchRecovery2D( - img_size, patch_size[1:], 2 * embed_dim, 4 + img_size, patch_size[1:], 2 * embed_dim, self.number_surface_variables ) self.patchrecovery3d = PatchRecovery3D( - (13, img_size[0], img_size[1]), patch_size, 2 * embed_dim, 5 + (number_atmosphere_levels, img_size[0], img_size[1]), + patch_size, + 2 * embed_dim, + number_atmosphere_variables, ) - def prepare_input(self, surface, surface_mask, upper_air): - """Prepares the input to the model in the required shape. - Args: - surface (torch.Tensor): 2D n_lat=721, n_lon=1440, chans=4. - surface_mask (torch.Tensor): 2D n_lat=721, n_lon=1440, chans=3. - upper_air (torch.Tensor): 3D n_pl=13, n_lat=721, n_lon=1440, chans=5. - """ - upper_air = upper_air.reshape( - upper_air.shape[0], -1, upper_air.shape[3], upper_air.shape[4] - ) - surface_mask = surface_mask.unsqueeze(0).repeat(surface.shape[0], 1, 1, 1) - return torch.concat([surface, surface_mask, upper_air], dim=1) - def forward(self, x): """ Args: x (torch.Tensor): [batch, 4+3+5*13, lat, lon] """ - surface = x[:, :7, :, :] - upper_air = x[:, 7:, :, :].reshape(x.shape[0], 5, 13, x.shape[2], x.shape[3]) + surface = x[ + :, : self.number_constant_variables + self.number_surface_variables, :, : + ] + upper_air = x[ + :, self.number_constant_variables + self.number_surface_variables :, :, : + ].reshape( + x.shape[0], + self.number_air_variables, + self.number_air_levels, + x.shape[2], + x.shape[3], + ) surface = self.patchembed2d(surface) upper_air = self.patchembed3d(upper_air) @@ -171,21 +171,14 @@ def forward(self, x): B, C, Pl, Lat, Lon = x.shape x = x.reshape(B, C, -1).transpose(1, 2) - x = self.layer1(x) - - skip = x - - x = self.downsample(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.upsample(x) - x = self.layer4(x) + output = self.processor(x) - output = torch.concat([x, skip], dim=-1) output = output.transpose(1, 2).reshape(B, -1, Pl, Lat, Lon) output_surface = output[:, :, 0, :, :] output_upper_air = output[:, :, 1:, :, :] output_surface = self.patchrecovery2d(output_surface) output_upper_air = self.patchrecovery3d(output_upper_air) - return output_surface, output_upper_air + s = output_upper_air.shape + output_upper_air = output_upper_air.reshape(s[0], s[1] * s[2], *s[3:]) + return torch.concat([output_surface, output_upper_air], dim=1) diff --git a/modulus/models/pangu/pangu_processor.py b/modulus/models/pangu/pangu_processor.py new file mode 100644 index 0000000000..e641a5e56c --- /dev/null +++ b/modulus/models/pangu/pangu_processor.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch + +from ..layers import DownSample3D, FuserLayer, UpSample3D +from ..module import Module + + +class PanguProcessor(Module): + """ + Processor sub-component for the Pangu DLNWP model. This model contains the + layers corresponding to both the encoder and decoder portions of the 3D + Earth-Specific Transformer from the Pangu paper (see link below). + + Parameters + ---------- + embed_dim: int + Embedded dimension of the transformer layers. + patched_inp_shape: tuple[int] + Tuple containing the shape of the patched embedding inputs. + num_heads: tuple[int] + The number of attention heads for the contained transformers. + Expected to have 4 entries, corresponding to the 4 Fuser Layers. + window_size: tuple[int] + Window size in the Earth-Specific transformer. + drop_path: list + Stochastic depth rate + number_upsampled_blocks: int + The number of upsampling (and downsampling) blocks to use. + checkpoint_flag: bool + Whether to use gradient checkpointing during training. + """ + + def __init__( + self, + embed_dim: int, + patched_inp_shape: tuple[int], + num_heads: tuple[int], + window_size: tuple[int], + drop_path: list, + number_upsampled_blocks: int, + checkpoint_flag: bool, + ): + super().__init__() + + self.layer1 = FuserLayer( + dim=embed_dim, + input_resolution=patched_inp_shape, + depth=number_upsampled_blocks, + num_heads=num_heads[0], + window_size=window_size, + drop_path=drop_path[:number_upsampled_blocks], + checkpoint_flag=checkpoint_flag, + ) + + patched_inp_shape_downsample = ( + 8, + math.ceil(patched_inp_shape[1] / 2), + math.ceil(patched_inp_shape[2] / 2), + ) + + self.layers = torch.nn.Sequential( + DownSample3D( + in_dim=embed_dim, + input_resolution=patched_inp_shape, + output_resolution=patched_inp_shape_downsample, + ), + FuserLayer( + dim=embed_dim * 2, + input_resolution=patched_inp_shape_downsample, + depth=6, + num_heads=num_heads[1], + window_size=window_size, + drop_path=drop_path[number_upsampled_blocks:], + checkpoint_flag=checkpoint_flag, + ), + FuserLayer( + dim=embed_dim * 2, + input_resolution=patched_inp_shape_downsample, + depth=6, + num_heads=num_heads[2], + window_size=window_size, + drop_path=drop_path[number_upsampled_blocks:], + checkpoint_flag=checkpoint_flag, + ), + UpSample3D( + embed_dim * 2, + embed_dim, + patched_inp_shape_downsample, + patched_inp_shape, + ), + FuserLayer( + dim=embed_dim, + input_resolution=patched_inp_shape, + depth=2, + num_heads=num_heads[3], + window_size=window_size, + drop_path=drop_path[:number_upsampled_blocks], + checkpoint_flag=checkpoint_flag, + ), + ) + + self.checkpoint_flag = checkpoint_flag + + def forward(self, x: torch.Tensor) -> torch.Tensor: + "Forward model pass." + x = self.layer1(x) + + skip = x + + for layer in self.layers: + x = layer(x) + + return torch.concat([x, skip], dim=-1) diff --git a/modulus/models/utils/patch_embed.py b/modulus/models/utils/patch_embed.py index b73002c899..6c92c04ae1 100644 --- a/modulus/models/utils/patch_embed.py +++ b/modulus/models/utils/patch_embed.py @@ -132,7 +132,7 @@ def forward(self, x: torch.Tensor): B, C, L, H, W = x.shape x = self.pad(x) x = self.proj(x) - if self.norm: + if self.norm is not None: x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3) return x diff --git a/modulus/models/utils/shift_window_mask.py b/modulus/models/utils/shift_window_mask.py index c67c19cc34..fedade0d6f 100644 --- a/modulus/models/utils/shift_window_mask.py +++ b/modulus/models/utils/shift_window_mask.py @@ -17,7 +17,7 @@ import torch -def window_partition(x: torch.Tensor, window_size, ndim=3): +def window_partition(x: torch.Tensor, window_size: tuple[int], ndim: int = 3): """ Args: x: (B, Pl, Lat, Lon, C) or (B, Lat, Lon, C) diff --git a/test/models/data/pangu_output.pth b/test/models/data/pangu_output.pth index f8d1e58dd4..8059afc8ef 100644 Binary files a/test/models/data/pangu_output.pth and b/test/models/data/pangu_output.pth differ diff --git a/test/models/test_pangu.py b/test/models/test_pangu.py index 3443faf6aa..9cdb1f294c 100644 --- a/test/models/test_pangu.py +++ b/test/models/test_pangu.py @@ -39,9 +39,9 @@ def test_pangu_forward(device): bsize = 2 invar_surface = torch.randn(bsize, 4, 32, 32).to(device) - invar_surface_mask = torch.randn(3, 32, 32).to(device) - invar_upper_air = torch.randn(bsize, 5, 13, 32, 32).to(device) - invar = model.prepare_input(invar_surface, invar_surface_mask, invar_upper_air) + invar_surface_mask = torch.randn(bsize, 3, 32, 32).to(device) + invar_upper_air = torch.randn(bsize, 5 * 13, 32, 32).to(device) + invar = torch.concat((invar_surface, invar_surface_mask, invar_upper_air), dim=1) # Check output size with torch.no_grad(): assert common.validate_forward_accuracy(model, (invar,), atol=5e-3) @@ -77,13 +77,17 @@ def test_pangu_constructor(device): bsize, 4, kw_args["img_size"][0], kw_args["img_size"][1] ).to(device) invar_surface_mask = torch.randn( - 3, kw_args["img_size"][0], kw_args["img_size"][1] + bsize, 3, kw_args["img_size"][0], kw_args["img_size"][1] ).to(device) invar_upper_air = torch.randn( - bsize, 5, 13, kw_args["img_size"][0], kw_args["img_size"][1] + bsize, 5 * 13, kw_args["img_size"][0], kw_args["img_size"][1] ).to(device) - invar = model.prepare_input(invar_surface, invar_surface_mask, invar_upper_air) - outvar_surface, outvar_upper_air = model(invar) + invar = torch.concat( + (invar_surface, invar_surface_mask, invar_upper_air), dim=1 + ) + outvar = model(invar) + outvar_surface = outvar[:, :4] + outvar_upper_air = outvar[:, 4:] assert outvar_surface.shape == ( bsize, 4, @@ -92,8 +96,7 @@ def test_pangu_constructor(device): ) assert outvar_upper_air.shape == ( bsize, - 5, - 13, + 5 * 13, kw_args["img_size"][0], kw_args["img_size"][1], ) @@ -116,23 +119,25 @@ def setup_model(): bsize = random.randint(1, 5) invar_surface = torch.randn(bsize, 4, 32, 32).to(device) - invar_surface_mask = torch.randn(3, 32, 32).to(device) - invar_upper_air = torch.randn(bsize, 5, 13, 32, 32).to(device) - invar = model.prepare_input(invar_surface, invar_surface_mask, invar_upper_air) + invar_surface_mask = torch.randn(bsize, 3, 32, 32).to(device) + invar_upper_air = torch.randn(bsize, 5 * 13, 32, 32).to(device) + invar = torch.concat( + (invar_surface, invar_surface_mask, invar_upper_air), dim=1 + ) return model, invar # Ideally always check graphs first model, invar = setup_model() assert common.validate_cuda_graphs(model, (invar,)) # Check JIT - # model, invar = setup_model() - # assert common.validate_jit(model, (invar,)) + model, invar = setup_model() + assert common.validate_jit(model, (invar,)) # Check AMP - # model, invar = setup_model() - # assert common.validate_amp(model, (invar,)) + model, invar = setup_model() + assert common.validate_amp(model, (invar,)) # Check Combo - # model, invar = setup_model() - # assert common.validate_combo_optims(model, (invar,)) + model, invar = setup_model() + assert common.validate_combo_optims(model, (invar,)) @common.check_ort_version() @@ -150,8 +155,8 @@ def test_pangu_deploy(device): bsize = random.randint(1, 5) invar_surface = torch.randn(bsize, 4, 32, 32).to(device) - invar_surface_mask = torch.randn(3, 32, 32).to(device) - invar_upper_air = torch.randn(bsize, 5, 13, 32, 32).to(device) - invar = model.prepare_input(invar_surface, invar_surface_mask, invar_upper_air) + invar_surface_mask = torch.randn(bsize, 3, 32, 32).to(device) + invar_upper_air = torch.randn(bsize, 5 * 13, 32, 32).to(device) + invar = torch.concat((invar_surface, invar_surface_mask, invar_upper_air), dim=1) assert common.validate_onnx_export(model, (invar,)) assert common.validate_onnx_runtime(model, (invar,))