Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smarter Overwrite #1191

Merged
merged 16 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions src/aspire/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
import aspire.volume
from aspire.nufft import anufft, nufft
from aspire.numeric import fft, xp
from aspire.utils import FourierRingCorrelation, anorm, crop_pad_2d, grid_2d
from aspire.utils import (
FourierRingCorrelation,
anorm,
crop_pad_2d,
grid_2d,
rename_with_timestamp,
)
from aspire.volume import SymmetryGroup

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -484,10 +490,25 @@ def filter(self, filter):
def rotate(self):
raise NotImplementedError

def save(self, mrcs_filepath, overwrite=False):
def save(self, mrcs_filepath, overwrite=None):
"""
Save Image to disk as mrcs file

:param filename: Filepath where Image will be saved.
:param overwrite: Options to control overwrite behavior (default is None):
- True: Overwrites the existing file if it exists.
- False: Raises an error if the file exists.
- None: Renames the old file by appending a time/date stamp.
"""
if self.stack_ndim > 1:
raise NotImplementedError("`save` is currently limited to 1D image stacks.")

if overwrite is None and os.path.exists(mrcs_filepath):
# If the file exists, append a timestamp to the old file and rename it
_ = rename_with_timestamp(mrcs_filepath)
elif overwrite is None:
overwrite = False

with mrcfile.new(mrcs_filepath, overwrite=overwrite) as mrc:
# original input format (the image index first)
mrc.set_data(self._data.astype(np.float32))
Expand Down
25 changes: 22 additions & 3 deletions src/aspire/source/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
PowerFilter,
)
from aspire.storage import MrcStats, StarFile
from aspire.utils import Rotation, grid_2d, support_mask, trange
from aspire.utils import Rotation, grid_2d, rename_with_timestamp, support_mask, trange
from aspire.volume import IdentitySymmetryGroup, SymmetryGroup

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -976,7 +976,7 @@ def save(
starfile_filepath,
batch_size=512,
save_mode=None,
overwrite=False,
overwrite=None,
):
"""
Save the output metadata to STAR file and/or images to MRCS file.
Expand All @@ -988,10 +988,29 @@ def save(
while `batch_size>=1` implies stack MRC extension `.mrcs`.
:param save_mode: Whether to save all images in a `single` or multiple files in batch size.
Default is multiple, supply `'single'` for single mode.
:param overwrite: Option to overwrite the output MRC files.
:param overwrite: Options to control overwrite behavior (default is None):
- True: Overwrites the existing file if it exists.
- False: Raises an error if the file exists.
- None: Renames the old file by appending a time/date stamp.
:return: A dictionary containing "starfile"--the path to the saved starfile-- and "mrcs", a
list of the saved particle stack MRC filenames.
"""
if overwrite is None and os.path.exists(starfile_filepath):
# If the file exists, append the timestamp to the old file and rename it
renamed_filepath = rename_with_timestamp(starfile_filepath, move=False)

# Retrieve original ImageSource and save with new starfile name.
from aspire.source import RelionSource

src = RelionSource(starfile_filepath)
src.save(renamed_filepath, overwrite=False)

# Allow overwriting old files.
overwrite = True

elif overwrite is None:
overwrite = False

logger.info("save metadata into STAR file")
filename_indices = self.save_metadata(
starfile_filepath,
Expand Down
15 changes: 11 additions & 4 deletions src/aspire/source/micrograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aspire.source import Simulation
from aspire.source.image import _ImageAccessor
from aspire.storage import StarFile
from aspire.utils import Random, grid_2d
from aspire.utils import Random, grid_2d, rename_with_timestamp
from aspire.volume import Volume

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,7 +44,7 @@ def __len__(self):
"""
return self.micrograph_count

def save(self, path, name_prefix="micrograph", overwrite=True):
def save(self, path, name_prefix="micrograph", overwrite=None):
"""
Save micrographs to `path`.

Expand All @@ -54,11 +54,18 @@ def save(self, path, name_prefix="micrograph", overwrite=True):

:param path: Directory to save data.
:param name_prefix: Optional, name prefix string for micrograph files.
:param overwrite: Optional, bool. Allow writing to existing directory,
and overwriting existing files.
:param overwrite: Options to control overwrite behavior (default is None):
- True: Overwrites the existing path if it exists.
- False: Raises an error if the path exists.
- None: Renames the old path by appending a time/date stamp.
:return: List of saved `.mrc` files.
"""

if overwrite is None and os.path.exists(path):
# If the directory exists, append a timestamp to existing directory.
_ = rename_with_timestamp(path)
overwrite = True

# Make dir if does not exist.
Path(path).mkdir(parents=True, exist_ok=overwrite)

Expand Down
1 change: 1 addition & 0 deletions src/aspire/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
inverse_r,
J_conjugate,
powerset,
rename_with_timestamp,
sha256sum,
support_mask,
fuzzy_mask,
Expand Down
27 changes: 27 additions & 0 deletions src/aspire/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import hashlib
import importlib.resources
import logging
import os
import sys
from datetime import datetime
from itertools import chain, combinations

import numpy as np
Expand Down Expand Up @@ -48,6 +50,31 @@ def importlib_path(package, resource):
return p


def rename_with_timestamp(filepath, move=True):
"""
Rename a file by appending a timestamp to the end of the filename.

:param filepath: Filepath to rename.
:param move: Option to rename the file on disk.

:return: filepath with timestamp appended.
"""
base, ext = os.path.splitext(filepath)
timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
renamed_filepath = f"{base}_{timestamp}{ext}"
logger.info(f"Renaming {filepath} as {renamed_filepath}.")

# Rename the existing file by appending the timestamp.
if move:
garrettwrong marked this conversation as resolved.
Show resolved Hide resolved
try:
os.rename(filepath, renamed_filepath)
except FileNotFoundError:
logger.warning(f"File '{filepath}' not found, could not rename.")
return None

return renamed_filepath


def abs2(x):
"""
Compute complex modulus squared.
Expand Down
19 changes: 14 additions & 5 deletions src/aspire/volume/volume.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import warnings

import mrcfile
Expand All @@ -16,6 +17,7 @@
grid_2d,
grid_3d,
mat_to_vec,
rename_with_timestamp,
vec_to_mat,
)
from aspire.volume import IdentitySymmetryGroup, SymmetryGroup
Expand Down Expand Up @@ -635,20 +637,27 @@ def rotate(self, rot_matrices, zero_nyquist=True):
def denoise(self):
raise NotImplementedError

def save(self, filename, overwrite=False):
def save(self, filename, overwrite=None):
"""
Save volume to disk as mrc file

:param filename: Filepath where volume will be saved

:param overwrite: Option to overwrite file when set to True.
Defaults to overwrite=False.
:param filename: Filepath where volume will be saved.
:param overwrite: Options to control overwrite behavior (default is None):
- True: Overwrites the existing file if it exists.
- False: Raises an error if the file exists.
- None: Renames the old file by appending a time/date stamp.
"""
if self.stack_ndim > 1:
raise NotImplementedError(
"`save` is currently limited to 1D Volume stacks."
)

if overwrite is None and os.path.exists(filename):
# If the file exists, append a timestamp to the old file and rename it
_ = rename_with_timestamp(filename)
elif overwrite is None:
overwrite = False

with mrcfile.new(filename, overwrite=overwrite) as mrc:
mrc.set_data(self._data.astype(np.float32))
# Note assigning voxel_size must come after `set_data`
Expand Down
65 changes: 65 additions & 0 deletions tests/test_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import os.path
import tempfile
from datetime import datetime
from unittest import mock

import mrcfile
import numpy as np
Expand Down Expand Up @@ -353,6 +355,69 @@ def test_asnumpy_readonly():
vw[0, 0, 0] = 123


def test_save_overwrite(caplog):
"""
Test that the overwrite flag behaves as expected.
- overwrite=True: Overwrites the existing file.
- overwrite=False: Raises an error if the file exists.
- overwrite=None: Renames the existing file and saves the new one.
"""
im1 = Image(np.ones((1, 8, 8), dtype=np.float32))
im2 = Image(2 * np.ones((1, 8, 8), dtype=np.float32))
im3 = Image(3 * np.ones((1, 8, 8), dtype=np.float32))

# Create a tmp dir for this test output
with tempfile.TemporaryDirectory() as tmpdir_name:
# tmp filename
mrc_path = os.path.join(tmpdir_name, "og.mrc")
base, ext = os.path.splitext(mrc_path)

# Create and save the first image
im1.save(mrc_path, overwrite=True)

# Case 1: overwrite=True (should overwrite the existing file)
im2.save(mrc_path, overwrite=True)

# Load and check if im2 has overwritten im1
im2_loaded = Image.load(mrc_path)
np.testing.assert_allclose(im2.asnumpy(), im2_loaded.asnumpy())

# Case 2: overwrite=False (should raise an overwrite error)
with pytest.raises(
ValueError,
match="File '.*' already exists; set overwrite=True to overwrite it",
):
im3.save(mrc_path, overwrite=False)

# Case 3: overwrite=None (should rename the existing file and save im3 with original filename)
# Mock datetime to return a fixed timestamp.
mock_datetime_value = datetime(2024, 10, 18, 12, 0, 0)
with mock.patch("aspire.utils.misc.datetime") as mock_datetime:
mock_datetime.now.return_value = mock_datetime_value
mock_datetime.strftime = datetime.strftime

with caplog.at_level(logging.INFO):
im3.save(mrc_path, overwrite=None)

# Check that the existing file was renamed and logged
assert f"Renaming {mrc_path}" in caplog.text

# Construct the expected renamed filename using the mock timestamp
mock_timestamp = mock_datetime_value.strftime("%y%m%d_%H%M%S")
renamed_file = f"{base}_{mock_timestamp}{ext}"

# Assert that the renamed file exists
assert os.path.exists(renamed_file), "Renamed file not found"

# Load and check that im3 was saved to the original path
im3_loaded = Image.load(mrc_path)
np.testing.assert_allclose(im3.asnumpy(), im3_loaded.asnumpy())

# Also check that the renamed file still contains im2's data
im2_loaded_renamed = Image.load(renamed_file)
np.testing.assert_allclose(im2.asnumpy(), im2_loaded_renamed.asnumpy())


def test_corrupt_mrc_load(caplog):
"""
Test that corrupt mrc files are logged as expected.
Expand Down
77 changes: 77 additions & 0 deletions tests/test_micrograph_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
import os
import tempfile
from datetime import datetime
from unittest import mock

import numpy as np
import pytest
Expand Down Expand Up @@ -303,6 +305,81 @@ def test_sim_save():
)


def test_save_overwrite(caplog):
"""
Tests MicrographSimulation.save functionality.

Specifically tests interoperability with CentersCoordinateSource
"""

v = AsymmetricVolume(L=16, C=1, dtype=np.float64).generate()
ctfs = [
RadialCTFFilter(
pixel_size=4, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0
)
]

mg_sim = MicrographSimulation(
volume=v,
particles_per_micrograph=3,
interparticle_distance=v.resolution,
micrograph_count=2,
micrograph_size=512,
ctf_filters=ctfs,
)

mg_sim_new = MicrographSimulation(
volume=v,
particles_per_micrograph=4,
interparticle_distance=v.resolution,
micrograph_count=3,
micrograph_size=512,
ctf_filters=ctfs,
)

with tempfile.TemporaryDirectory() as tmp_output_dir:
path = os.path.join(tmp_output_dir, "test")

# Write MRC and STAR files
save_paths_1 = mg_sim.save(path, overwrite=True)

# Case 1: overwrite=True (should overwrite the existing file)
save_paths_2 = mg_sim.save(path, overwrite=True)
np.testing.assert_array_equal(save_paths_1, save_paths_2)

# Case2: overwrite=False (should raise error)
with pytest.raises(FileExistsError):
_ = mg_sim.save(path, overwrite=False)

# Case 3: overwrite=None (should rename the existing directory)
mock_datetime_value = datetime(2024, 10, 18, 12, 0, 0)
with mock.patch("aspire.utils.misc.datetime") as mock_datetime:
mock_datetime.now.return_value = mock_datetime_value
mock_datetime.strftime = datetime.strftime

with caplog.at_level(logging.INFO):
_ = mg_sim_new.save(path, overwrite=None)

# Check that the existing directory was renamed and logged
assert f"Renaming {path}" in caplog.text
assert os.path.exists(path), "Directory not found"

# Construct the expected renamed directory using the mock timestamp
mock_timestamp = mock_datetime_value.strftime("%y%m%d_%H%M%S")
renamed_dir = f"{path}_{mock_timestamp}"

# Assert that the renamed file exists
assert os.path.exists(renamed_dir), "Renamed directory not found"

# Load renamed directory and check images against orignal sim.
mg_src = DiskMicrographSource(renamed_dir)
np.testing.assert_allclose(mg_src.asnumpy(), mg_sim.asnumpy())

# Load new directory and check images against orignal sim.
mg_src_new = DiskMicrographSource(path)
np.testing.assert_allclose(mg_src_new.asnumpy(), mg_sim_new.asnumpy())


def test_bad_amplitudes(vol_fixture):
"""
Test incorrect `particle_amplitudes` argument raises.
Expand Down
Loading
Loading