Skip to content

Commit

Permalink
Bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff1995 committed Feb 7, 2023
1 parent 6d4948b commit 680c803
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions Cell_BLAST/directi.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,6 @@ def __init__(

if path is None:
path = tempfile.mkdtemp()
else:
os.makedirs(path, exist_ok=True)
utils.logger.info("Using model path: %s", path)

random_seed = (
config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed
)
Expand Down Expand Up @@ -162,6 +158,9 @@ def fit(
tolerance: float = 0.0,
progress_bar: bool = False,
):
os.makedirs(self.path, exist_ok=True)
utils.logger.info("Using model path: %s", self.path)

val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(
Expand Down Expand Up @@ -423,6 +422,7 @@ def save(
Name of the weights file
"""
if path is None:
os.makedirs(self.path, exist_ok=True)
torch.save(self.get_config(), os.path.join(self.path, config))
torch.save(self.state_dict(), os.path.join(self.path, weights))
else:
Expand Down Expand Up @@ -1049,11 +1049,6 @@ def align_DIRECTi(
)
DIRECTi.ensure_reproducibility(random_seed)

if path is None:
path = tempfile.mkdtemp()
else:
os.makedirs(path, exist_ok=True)

if rmbatch_module_kwargs is None:
rmbatch_module_kwargs = {}
if rmbatch_module_kwargs is None:
Expand Down Expand Up @@ -1097,6 +1092,7 @@ def align_DIRECTi(
_config["prob_module"]["fine_tune"] = True
_config["prob_module"]["deviation_reg"] = deviation_reg
_config["learning_rate"] = learning_rate
_config["path"] = path

aligned_model = DIRECTi.load_config(_config)
if reuse_weights:
Expand Down

0 comments on commit 680c803

Please sign in to comment.