
Training loss does not converge after reloading checkpoint to continue training #2

YitianShi opened this issue Mar 20, 2024 · 1 comment

@YitianShi

Hi, I'm trying to reload my pre-trained network to continue training in the CIFAR-10 experiments (cifar.py), but the loss does not converge after reloading (it still converges if the model is initialized without reloading). I suspect the problem is that __base_optimizer is stored as part of the optimizer state, so when I run optimizer.load_state_dict(ckpt['optimizer_state_dict']) the live base optimizer is replaced by the copy stored in optimizer_state_dict.
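To illustrate the suspected mechanism, here is a minimal, self-contained sketch with a plain torch.optim.SGD (not the repository's BayesianOptimizer): an object stashed under a string key in optimizer.state is serialized into the state dict and silently overwrites the live object on load_state_dict.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Stash an arbitrary object under a string key, analogous to
# self.state["__base_optimizer"] = base_optimizer in the repository.
opt.state["__base_optimizer"] = {"id": "live"}

ckpt = opt.state_dict()  # the stashed object ends up inside the state dict

# Simulate a restart: fresh model, fresh optimizer, fresh stashed object.
model2 = nn.Linear(4, 2)
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
opt2.state["__base_optimizer"] = {"id": "fresh"}

opt2.load_state_dict(ckpt)
print(opt2.state["__base_optimizer"])  # {'id': 'live'} -- the fresh object was replaced

In the repository's case the stashed object is the base optimizer itself, so after reloading from a checkpoint it is a deserialized copy that no longer references the live model parameters, which could explain the non-converging loss.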

I have now solved the problem by making base_optimizer a class member of each BNN optimizer, i.e. self.__base_optimizer = base_optimizer instead of self.state["__base_optimizer"] = base_optimizer. With this change I load the state_dict of the optimizer and of its base optimizer separately, and the training loss converges again after reloading.

Following is my code snippet for reloading:

import glob
import os

import torch

# wilson_scheduler is assumed to be imported from the repository's training utilities


def load_model(model_idx, model, scaler, optimizer, out_path, config, log):
    ckpt = None
    start_epoch = 0

    # Load checkpoint and scaler if available
    if config.get("use_checkpoint", None):
        try:
            ckpt_paths = glob.glob(out_path + f"{config['model']}_chkpt_{model_idx}_*.pth")
            ckpt_paths.sort(key=os.path.getmtime)
            ckpt = torch.load(ckpt_paths[-1]) 
            model.load_state_dict(ckpt['model_state_dict'])
            start_epoch = ckpt["epoch"] + 1
            scaler.load_state_dict(ckpt["scaler_state_dict"])
            log.info(f"Loaded checkpoint for model {model_idx} at epoch {start_epoch}")
        except Exception as e:
            log.info(f"Failed to load checkpoint for model {model_idx}: {e}")

    optimizer.init_grad_scaler(scaler)

    # Load optimizer state if available
    # Base optimizer state is loaded separately if available
    if ckpt is not None: 
        try:
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            if ckpt.get("base_optimizer") is not None:
                optimizer.get_base_optimizer().load_state_dict(ckpt["base_optimizer"])
                log.info(f"Loaded base optimizer state for model {model_idx}")
        except Exception as e:
            log.info(f"Failed to load optimizer state for model {model_idx}: {e}")

    # Load scheduler state if available
    if config["lr_schedule"]:
        scheduler = wilson_scheduler(optimizer.get_base_optimizer(), config["epochs"], config["lr"], None)
        if ckpt is not None:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            log.info(f"Loaded scheduler state for model {model_idx}")
    else:
        scheduler = None
    
    return start_epoch, model, optimizer, scaler, scheduler

and how I save the model during training:

def save_model(model, optimizer, scheduler, scaler, out_path, config, model_idx, epoch):
    state_dict = {
        'epoch': epoch,
        'model_idx': model_idx,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None
    }
    # Save the base optimizer state separately so it can be restored explicitly on reload
    if hasattr(optimizer, "get_base_optimizer"):
        state_dict['base_optimizer'] = optimizer.get_base_optimizer().state_dict()
    torch.save(state_dict, out_path + f"{config['model']}_chkpt_{model_idx}_{epoch}.pth")
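
For context, here is a rough sketch of how these two helpers are wired into the per-model training loop; the train_one_epoch call and config keys are placeholders rather than the actual code from cifar.py:

# Hypothetical wiring of the helpers above; train_one_epoch is a placeholder.
start_epoch, model, optimizer, scaler, scheduler = load_model(
    model_idx, model, scaler, optimizer, out_path, config, log)

for epoch in range(start_epoch, config["epochs"]):
    train_one_epoch(model, optimizer, scaler, epoch)  # placeholder training step
    if scheduler is not None:
        scheduler.step()
    save_model(model, optimizer, scheduler, scaler, out_path, config, model_idx, epoch)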

An example of the optimizer change:

class MAPOptimizer(BayesianOptimizer):
    '''
        Maximum A Posteriori

        This simply optimizes a point estimate of the parameters with the given base_optimizer.
    '''

    def __init__(self, params, base_optimizer):
        super().__init__(params, {})
        # self.state["__base_optimizer"] = base_optimizer
        self.__base_optimizer = base_optimizer
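
Since load_model and save_model call optimizer.get_base_optimizer(), the accessor also needs to return the new attribute; roughly like this (just a sketch, and note that the double-underscore name is mangled per class, so an accessor inherited from BayesianOptimizer would not see a subclass's __base_optimizer):

class MAPOptimizer(BayesianOptimizer):
    def __init__(self, params, base_optimizer):
        super().__init__(params, {})
        self.__base_optimizer = base_optimizer

    def get_base_optimizer(self):
        # self.__base_optimizer is mangled to _MAPOptimizer__base_optimizer, so an
        # accessor inherited from BayesianOptimizer would not find it; either define
        # the accessor in each subclass or use a single underscore (_base_optimizer).
        return self.__base_optimizer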

Since I'm still looking into the other optimizers, it would be a great help if you could point out any potential problems with this approach. Thank you very much!

@Feuermagier
Owner

Hi,

Thanks for raising that issue! I think your change is good and makes sense, and it shouldn't introduce any problems down the line. If I find time, I will fix it across the entire repository. If you want, you can also open a pull request with the change.

Thanks for your interest in the repository!
Florian
