diff --git a/Cell_BLAST/directi.py b/Cell_BLAST/directi.py
index a9a0aa8..b069982 100644
--- a/Cell_BLAST/directi.py
+++ b/Cell_BLAST/directi.py
@@ -87,10 +87,6 @@ def __init__(
 
         if path is None:
             path = tempfile.mkdtemp()
-        else:
-            os.makedirs(path, exist_ok=True)
-        utils.logger.info("Using model path: %s", path)
-
         random_seed = (
             config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed
         )
@@ -162,6 +158,9 @@ def fit(
         tolerance: float = 0.0,
         progress_bar: bool = False,
     ):
+        os.makedirs(self.path, exist_ok=True)
+        utils.logger.info("Using model path: %s", self.path)
+
         val_size = int(len(dataset) * val_split)
         train_size = len(dataset) - val_size
         train_dataset, val_dataset = torch.utils.data.random_split(
@@ -423,6 +422,7 @@ def save(
             Name of the weights file
         """
         if path is None:
+            os.makedirs(self.path, exist_ok=True)
             torch.save(self.get_config(), os.path.join(self.path, config))
             torch.save(self.state_dict(), os.path.join(self.path, weights))
         else:
@@ -1049,11 +1049,6 @@ def align_DIRECTi(
     )
     DIRECTi.ensure_reproducibility(random_seed)
-    if path is None:
-        path = tempfile.mkdtemp()
-    else:
-        os.makedirs(path, exist_ok=True)
-
     if rmbatch_module_kwargs is None:
         rmbatch_module_kwargs = {}
     if rmbatch_module_kwargs is None:
         rmbatch_module_kwargs = {}
@@ -1097,6 +1092,7 @@ def align_DIRECTi(
     _config["prob_module"]["fine_tune"] = True
     _config["prob_module"]["deviation_reg"] = deviation_reg
     _config["learning_rate"] = learning_rate
+    _config["path"] = path
     aligned_model = DIRECTi.load_config(_config)
     if reuse_weights:
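
Note on the net effect of the hunks above: creation of a user-supplied model directory moves out of DIRECTi.__init__ and into the operations that actually write to disk (fit, and save when saving to the default path), while a missing path still falls back to tempfile.mkdtemp() in the constructor; align_DIRECTi now passes its path argument through to the new model via _config["path"] instead of creating directories itself. Below is a minimal sketch of this deferred directory-creation pattern, using an illustrative LazyDirModel class that is not part of Cell_BLAST, only an assumption made to demonstrate the behavior in isolation:

import os
import tempfile
from typing import Optional


class LazyDirModel:
    """Illustrative stand-in for DIRECTi's path handling, not the real class."""

    def __init__(self, path: Optional[str] = None) -> None:
        # As in the patched __init__: a missing path still falls back to a
        # temporary directory (tempfile.mkdtemp() creates it eagerly), but a
        # user-supplied path is only recorded here, not created.
        self.path = tempfile.mkdtemp() if path is None else path

    def fit(self) -> None:
        # Mirrors the makedirs/logging moved into DIRECTi.fit: the directory
        # is created on first use, tolerating the case where it already exists.
        os.makedirs(self.path, exist_ok=True)
        print("Using model path: %s" % self.path)

    def save(self) -> None:
        # Mirrors the makedirs added to DIRECTi.save for the default path.
        os.makedirs(self.path, exist_ok=True)


parent = tempfile.mkdtemp()
model = LazyDirModel(os.path.join(parent, "my_model"))  # nothing created yet
assert not os.path.exists(model.path)
model.fit()                                             # directory appears here
assert os.path.isdir(model.path)

One consequence of this design, under the assumptions above: constructing a model with an explicit path no longer leaves an empty directory behind if the model is never fitted or saved, and the "Using model path" log line is only emitted once training actually begins.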