Skip to content

Commit

Permalink
Replaced decorator skipUnlessModule (#510)
Browse files Browse the repository at this point in the history
* Replaced decorator `skipUnlessModule`
  • Loading branch information
david-zwicker authored Jan 3, 2024
1 parent 52e6337 commit 4a10dd6
Show file tree
Hide file tree
Showing 14 changed files with 42 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pde/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
plotting
spectral
typing
.. codeauthor:: David Zwicker <[email protected]>
"""
2 changes: 1 addition & 1 deletion pde/tools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
check_package_version
packages_from_requirements
environment
.. codeauthor:: David Zwicker <[email protected]>
"""

Expand Down
9 changes: 8 additions & 1 deletion pde/tools/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
ensure_directory_exists
preserve_scalars
decorator_arguments
skipUnlessModule
import_class
classproperty
hybridmethod
Expand All @@ -29,6 +28,7 @@
import json
import os
import unittest
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Sequence, TypeVar

Expand Down Expand Up @@ -135,6 +135,13 @@ def skipUnlessModule(module_names: str | Sequence[str]) -> Callable[[TFunc], TFu
Returns:
A function, so this can be used as a decorator
"""
# deprecated since 2024-01-03
warnings.warn(
"`skipUnlessModule` is deprecated. Use "
'`@pytest.mark.skipif(not module_available("module"))` instead.',
DeprecationWarning,
)

if isinstance(module_names, str):
module_names = [module_names]

Expand Down
4 changes: 2 additions & 2 deletions tests/fields/test_field_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from fixtures.fields import iter_grids
from pde import FieldCollection, ScalarField, Tensor2Field, UnitGrid, VectorField
from pde.fields.base import FieldBase
from pde.tools.misc import skipUnlessModule
from pde.tools.misc import module_available


@pytest.mark.parametrize("grid", iter_grids())
Expand Down Expand Up @@ -224,7 +224,7 @@ def test_from_scalar_expressions():
np.testing.assert_allclose(fc[1].data, 1)


@skipUnlessModule("napari")
@pytest.mark.skipif(not module_available("napari"), reason="requires `napari` module")
@pytest.mark.interactive
def test_interactive_collection_plotting(rng):
"""test the interactive plotting"""
Expand Down
4 changes: 2 additions & 2 deletions tests/fields/test_generic_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
SphericalSymGrid,
UnitGrid,
)
from pde.tools.misc import skipUnlessModule
from pde.tools.misc import module_available


@pytest.mark.parametrize("field_class", [ScalarField, VectorField, Tensor2Field])
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_complex_fields(field_class, rng):
assert field_copy.dtype == np.dtype("complex")


@skipUnlessModule("h5py")
@pytest.mark.skipif(not module_available("h5py"), reason="requires `h5py` module")
def test_hdf_input_output(tmp_path, rng):
"""test writing and reading files"""
grid = UnitGrid([4, 4])
Expand Down
5 changes: 2 additions & 3 deletions tests/fields/test_scalar_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pde.grids import CartesianGrid, PolarSymGrid, UnitGrid, boundaries
from pde.grids._mesh import GridMesh
from pde.tools import mpi
from pde.tools.misc import module_available, skipUnlessModule
from pde.tools.misc import module_available


def test_interpolation_singular():
Expand Down Expand Up @@ -240,7 +240,6 @@ def f(x, y):
np.testing.assert_allclose(sf.data, [[0.25, 0.75]])


@skipUnlessModule("matplotlib")
def test_from_image(tmp_path, rng):
from matplotlib.pyplot import imsave

Expand Down Expand Up @@ -410,7 +409,7 @@ def test_plotting_2d(rng):
field._update_plot(ref)


@skipUnlessModule("napari")
@pytest.mark.skipif(not module_available("napari"), reason="requires `napari` module")
@pytest.mark.interactive
def test_interactive_plotting(rng):
"""test the interactive plotting"""
Expand Down
4 changes: 2 additions & 2 deletions tests/fields/test_vectorial_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pde import CartesianGrid, ScalarField, Tensor2Field, UnitGrid, VectorField
from pde.fields.base import FieldBase
from pde.tools.misc import module_available, skipUnlessModule
from pde.tools.misc import module_available


def test_vectors_basic():
Expand Down Expand Up @@ -238,7 +238,7 @@ def test_vector_plotting_2d(transpose, rng):
field.get_vector_data(transpose=transpose, max_points=7)


@skipUnlessModule("napari")
@pytest.mark.skipif(not module_available("napari"), reason="requires `napari` module")
@pytest.mark.interactive
def test_interactive_vector_plotting(rng):
"""test the interactive plotting"""
Expand Down
4 changes: 2 additions & 2 deletions tests/grids/operators/test_cartesian_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from pde.grids.operators import cartesian as ops
from pde.grids.operators.common import make_laplace_from_matrix
from pde.tools.misc import skipUnlessModule
from pde.tools.misc import module_available

π = np.pi

Expand Down Expand Up @@ -85,7 +85,7 @@ def test_laplace_1d(periodic, rng):
np.testing.assert_allclose(l1.data, l2.data)


@skipUnlessModule("rocket_fft")
@pytest.mark.skipif(not module_available("rocket_fft"), reason="requires `rocket_fft`")
@pytest.mark.parametrize("ndim", [1, 2])
@pytest.mark.parametrize("dtype", [float, complex])
def test_laplace_spectral(ndim, dtype, rng):
Expand Down
2 changes: 0 additions & 2 deletions tests/grids/test_generic_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
discretize_interval,
registered_operators,
)
from pde.tools.misc import skipUnlessModule


def iter_grids():
Expand Down Expand Up @@ -118,7 +117,6 @@ def test_integration_serial(grid, rng):
assert res == pytest.approx(grid.integrate(arr, axes=range(grid.num_axes)))


@skipUnlessModule("matplotlib")
def test_grid_plotting():
"""test plotting of grids"""
grids.UnitGrid([4]).plot()
Expand Down
14 changes: 7 additions & 7 deletions tests/storage/test_file_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

import pde
from pde import DiffusionPDE, FileStorage, ScalarField, UnitGrid
from pde.tools.misc import skipUnlessModule
from pde.tools.misc import module_available


@skipUnlessModule("h5py")
@pytest.mark.skipif(not module_available("napari"), reason="requires `napari` module")
@pytest.mark.parametrize("collection", [True, False])
def test_storage_persistence(collection, tmp_path):
"""test writing to persistent trackers"""
Expand Down Expand Up @@ -76,7 +76,7 @@ def assert_storage_content(storage, expect):
assert info.items() <= reader.info.items()


@skipUnlessModule("h5py")
@pytest.mark.skipif(not module_available("napari"), reason="requires `napari` module")
@pytest.mark.parametrize("compression", [True, False])
def test_simulation_persistence(compression, tmp_path, rng):
"""test whether a tracker can accurately store information about simulation"""
Expand All @@ -101,7 +101,7 @@ def test_simulation_persistence(compression, tmp_path, rng):
assert grid == grid_res


@skipUnlessModule("h5py")
@pytest.mark.skipif(not module_available("napari"), reason="requires `napari` module")
@pytest.mark.parametrize("compression", [True, False])
def test_storage_fixed_size(compression, tmp_path):
"""test setting fixed size of FileStorage objects"""
Expand Down Expand Up @@ -130,7 +130,7 @@ def test_storage_fixed_size(compression, tmp_path):
np.testing.assert_allclose(storage.times, [0, 1])


@skipUnlessModule("h5py")
@pytest.mark.skipif(not module_available("napari"), reason="requires `napari` module")
def test_appending(tmp_path):
"""test the appending data"""
path = tmp_path / "test_appending.hdf5"
Expand All @@ -152,7 +152,7 @@ def test_appending(tmp_path):
assert len(storage2) == 2


@skipUnlessModule("h5py")
@pytest.mark.skipif(not module_available("napari"), reason="requires `napari` module")
def test_keep_opened(tmp_path):
"""test the keep opened option"""
path = tmp_path / "test_keep_opened.hdf5"
Expand Down Expand Up @@ -180,7 +180,7 @@ def test_keep_opened(tmp_path):
np.testing.assert_allclose(storage2.times, np.arange(3))


@skipUnlessModule("h5py")
@pytest.mark.skipif(not module_available("napari"), reason="requires `napari` module")
@pytest.mark.parametrize("dtype", [bool, float, complex])
def test_write_types(dtype, tmp_path, rng):
"""test whether complex data can be written"""
Expand Down
7 changes: 5 additions & 2 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import pytest

from pde.tools.misc import module_available, skipUnlessModule
from pde.tools.misc import module_available
from pde.visualization.movies import Movie

PACKAGE_PATH = Path(__file__).resolve().parents[1]
Expand Down Expand Up @@ -67,7 +67,10 @@ def test_example_scripts(path):

@pytest.mark.slow
@pytest.mark.no_cover
@skipUnlessModule(["h5py", "jupyter", "notebook", "nbconvert"])
@pytest.mark.skipif(not module_available("h5py"), reason="requires `h5py`")
@pytest.mark.skipif(not module_available("jupyter"), reason="requires `jupyter`")
@pytest.mark.skipif(not module_available("notebook"), reason="requires `notebook`")
@pytest.mark.skipif(not module_available("nbconvert"), reason="requires `nbconvert`")
@pytest.mark.parametrize("path", NOTEBOOKS)
def test_jupyter_notebooks(path, tmp_path):
"""run the jupyter notebooks"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pde.tools import misc, mpi, numba


@misc.skipUnlessModule("h5py")
@pytest.mark.skipif(not misc.module_available("h5py"), reason="requires `h5py` module")
def test_writing_to_storage(tmp_path, rng):
"""test whether data is written to storage"""
state = ScalarField.random_uniform(UnitGrid([3]), rng=rng)
Expand Down
8 changes: 7 additions & 1 deletion tests/tools/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def value(cls): # @NoSelf
assert Test.value == 2


@misc.skipUnlessModule("h5py")
@pytest.mark.skipif(not misc.module_available("h5py"), reason="requires `h5py` module")
def test_hdf_write_attributes(tmp_path):
"""test hdf_write_attributes function"""
import h5py
Expand Down Expand Up @@ -110,3 +110,9 @@ def test_hdf_write_attributes(tmp_path):
misc.hdf_write_attributes(
hdf_file, {"a": object()}, raise_serialization_error=True
)


@misc.skipUnlessModule("undefined_module_name")
def test_skipUnlessModule():
"""test skipUnlessModule decorator"""
raise RuntimeError # test should never run
7 changes: 2 additions & 5 deletions tests/visualization/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from pde.fields import FieldCollection, ScalarField
from pde.grids import UnitGrid
from pde.storage import get_memory_storage
from pde.tools.misc import skipUnlessModule
from pde.tools.misc import module_available
from pde.visualization import plotting


@skipUnlessModule("matplotlib")
def test_scalar_field_plot(tmp_path, rng):
"""test ScalarFieldPlot class"""
path = tmp_path / "test_scalar_field_plot.png"
Expand All @@ -28,7 +27,6 @@ def test_scalar_field_plot(tmp_path, rng):
assert path.stat().st_size > 0


@skipUnlessModule("matplotlib")
def test_scalar_plot(tmp_path, rng):
"""test Simple simulation"""
path = tmp_path / "test_scalar_plot.png"
Expand All @@ -48,7 +46,6 @@ def test_scalar_plot(tmp_path, rng):
assert path.stat().st_size > 0


@skipUnlessModule("matplotlib")
def test_collection_plot(tmp_path):
"""test Simple simulation"""
# create some data
Expand Down Expand Up @@ -107,7 +104,7 @@ def test_kymograph_collection(tmp_path):
assert path.stat().st_size > 0


@skipUnlessModule("napari")
@pytest.mark.skipif(not module_available("napari"), reason="requires `napari` module")
@pytest.mark.interactive
def test_interactive_plotting(rng):
"""test plot_interactive"""
Expand Down

0 comments on commit 4a10dd6

Please sign in to comment.