diff --git a/docs/source/conf.py b/docs/source/conf.py
index a17e4567..cbe98263 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -15,9 +15,9 @@
import sys
sys.path.insert(0, ".")
-sys.path.insert(0, os.path.abspath("../.."))
-sys.path.insert(0, os.path.abspath("../../scripts"))
-sys.path.insert(0, os.path.abspath("../sphinx_ext/"))
+sys.path.insert(0, os.path.abspath("../..")) # noqa: PTH100
+sys.path.insert(0, os.path.abspath("../../scripts")) # noqa: PTH100
+sys.path.insert(0, os.path.abspath("../sphinx_ext/")) # noqa: PTH100
from datetime import date
@@ -26,7 +26,7 @@
project = "py-pde"
module_name = "pde"
author = "Zwicker Group"
-copyright = f"{date.today().year}, {author}" # @ReservedAssignment
+copyright = f"{date.today().year}, {author}" # @ReservedAssignment # noqa: A001
html_logo = "_images/logo_small.png"
# Determine the version from the actual package
diff --git a/docs/source/manual/contributing.rst b/docs/source/manual/contributing.rst
index 28a6ad73..c1f182cd 100644
--- a/docs/source/manual/contributing.rst
+++ b/docs/source/manual/contributing.rst
@@ -55,8 +55,9 @@ where ``vector_component`` is either 0 or 1.
Coding style
""""""""""""
-The coding style is enforced using `isort `_
-and `black `_. Moreover, we use `Google Style docstrings
+The coding style is enforced using `ruff `_, based on the
+styles suggested by `isort `_ and
+`black `_. Moreover, we use `Google Style docstrings
`_,
which might be best `learned by example
`_.
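
Since the Google docstring style is, as noted, best learned by example, here is a minimal, made-up function showing the layout (name and parameters are purely illustrative):

    def downsample(values: list[float], factor: int = 2) -> list[float]:
        """Keep every `factor`-th entry of a sequence.

        Args:
            values (list of float):
                The data that should be downsampled.
            factor (int):
                Keep one entry out of every `factor` entries.

        Returns:
            list of float: The downsampled data.

        Raises:
            ValueError: If `factor` is not a positive integer.
        """
        if factor < 1:
            raise ValueError("`factor` must be a positive integer")
        return values[::factor]
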
diff --git a/docs/source/run_autodoc.py b/docs/source/run_autodoc.py
index ae7fc760..c810af80 100755
--- a/docs/source/run_autodoc.py
+++ b/docs/source/run_autodoc.py
@@ -4,6 +4,7 @@
import logging
import os
import subprocess as sp
+from pathlib import Path
logging.basicConfig(level=logging.INFO)
@@ -31,21 +32,21 @@ def replace_in_file(infile, replacements, outfile=None):
if outfile is None:
outfile = infile
- with open(infile) as fp:
+ with Path(infile).open() as fp:
content = fp.read()
for key, value in replacements.items():
content = content.replace(key, value)
- with open(outfile, "w") as fp:
+ with Path(outfile).open("w") as fp:
fp.write(content)
def main():
# remove old files
- for path in glob.glob(f"{OUTPUT_PATH}/*.rst"):
+ for path in Path(OUTPUT_PATH).glob("*.rst"):
logging.info("Remove file `%s`", path)
- os.remove(path)
+ path.unlink()
# run sphinx-apidoc
sp.check_call(
@@ -65,7 +66,7 @@ def main():
)
# replace unwanted information
- for path in glob.glob(f"{OUTPUT_PATH}/*.rst"):
+ for path in Path(OUTPUT_PATH).glob("*.rst"):
logging.info("Patch file `%s`", path)
replace_in_file(path, REPLACEMENTS)
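
The `run_autodoc.py` changes replace `open`, `glob.glob`, and `os.remove` with their `pathlib` counterparts (ruff's flake8-use-pathlib rules, cf. the `PTH100` noqa in `conf.py` above). A minimal sketch of the equivalent calls, using a hypothetical file:

    from pathlib import Path

    path = Path("example.txt")  # hypothetical file used only for illustration

    with open(path, "w") as fp:  # plain built-in open(), flagged by the PTH rules
        fp.write("hello\n")

    with path.open() as fp:  # pathlib equivalent used in the patch
        content = fp.read()

    # glob.glob(f"{folder}/*.rst") plus os.remove(...) become Path.glob plus Path.unlink
    for rst_file in Path(".").glob("*.rst"):
        print("would remove", rst_file)  # rst_file.unlink() would delete it
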
diff --git a/examples/pde_brusselator_class.py b/examples/pde_brusselator_class.py
index 47ffce9a..c0345500 100644
--- a/examples/pde_brusselator_class.py
+++ b/examples/pde_brusselator_class.py
@@ -27,11 +27,11 @@
class BrusselatorPDE(PDEBase):
"""Brusselator with diffusive mobility."""
- def __init__(self, a=1, b=3, diffusivity=[1, 0.1], bc="auto_periodic_neumann"):
+ def __init__(self, a=1, b=3, diffusivity=None, bc="auto_periodic_neumann"):
super().__init__()
self.a = a
self.b = b
- self.diffusivity = diffusivity # spatial mobility
+ self.diffusivity = [1, 0.1] if diffusivity is None else diffusivity
self.bc = bc # boundary condition
def get_initial_state(self, grid):
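
The constructor change above avoids a mutable default argument (ruff's bugbear rule B006): a list used as a default value is created once and then shared by every instance. A minimal sketch with a made-up class:

    class Bad:
        def __init__(self, diffusivity=[1, 0.1]):  # noqa: B006 -- shared default list
            self.diffusivity = diffusivity

    class Good:
        def __init__(self, diffusivity=None):
            # a fresh list is created for every instance when no value is given
            self.diffusivity = [1, 0.1] if diffusivity is None else diffusivity

    a, b = Bad(), Bad()
    a.diffusivity[0] = 99
    print(b.diffusivity)  # [99, 0.1] -- the one default list was mutated for both

    c, d = Good(), Good()
    c.diffusivity[0] = 99
    print(d.diffusivity)  # [1, 0.1] -- each instance owns its own list
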
diff --git a/pde/__init__.py b/pde/__init__.py
index 1f0621bb..bfe5c750 100644
--- a/pde/__init__.py
+++ b/pde/__init__.py
@@ -4,7 +4,7 @@
# determine the package version
try:
# try reading version of the automatically generated module
- from ._version import __version__ # type: ignore
+ from ._version import __version__
except ImportError:
# determine version automatically from CVS information
from importlib.metadata import PackageNotFoundError, version
@@ -20,7 +20,8 @@
from .tools.config import Config, environment
config = Config() # initialize the default configuration
-del Config # clean name space
+
+import contextlib
# import all other modules that should occupy the main name space
from .fields import * # @UnusedWildImport
@@ -32,7 +33,7 @@
from .trackers import * # @UnusedWildImport
from .visualization import * # @UnusedWildImport
-try:
+with contextlib.suppress(ImportError):
from .tools.modelrunner import *
-except ImportError:
- pass # modelrunner extensions are simply not loaded
+
+del contextlib, Config # clean name space
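
The optional-import blocks in this patch swap `try`/`except ImportError: pass` for `contextlib.suppress` (ruff's SIM105). A minimal sketch, using a hypothetical optional dependency:

    import contextlib

    # equivalent to: try: import extra_package / except ImportError: pass
    with contextlib.suppress(ImportError):
        import extra_package  # hypothetical optional dependency  # noqa: F401
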
diff --git a/pde/fields/datafield_base.py b/pde/fields/datafield_base.py
index e48c0778..961689df 100644
--- a/pde/fields/datafield_base.py
+++ b/pde/fields/datafield_base.py
@@ -1131,7 +1131,7 @@ def get_vector_data(self, transpose: bool = False, **kwargs) -> dict[str, Any]:
Returns:
dict: Information useful for plotting an vector field
"""
- raise NotImplementedError()
+ raise NotImplementedError
def _plot_line(
self,
diff --git a/pde/fields/scalar.py b/pde/fields/scalar.py
index 7f8708d6..215582b1 100644
--- a/pde/fields/scalar.py
+++ b/pde/fields/scalar.py
@@ -372,7 +372,7 @@ def slice(
raise ValueError(
f"The axes {ax} is not contained in "
f"{self.grid} with axes {self.grid.axes}"
- )
+ ) from None
ax_remove.append(i)
# check the position
diff --git a/pde/fields/tensorial.py b/pde/fields/tensorial.py
index 62e7aa79..9118c892 100644
--- a/pde/fields/tensorial.py
+++ b/pde/fields/tensorial.py
@@ -124,8 +124,8 @@ def _get_axes_index(self, key: tuple[int | str, int | str]) -> tuple[int, int]:
try:
if len(key) != 2:
raise IndexError("Index must be given as two integers")
- except TypeError:
- raise IndexError("Index must be given as two values")
+ except TypeError as err:
+ raise IndexError("Index must be given as two values") from err
return tuple(self.grid.get_axis_index(k) for k in key) # type: ignore
def __getitem__(self, key: tuple[int | str, int | str]) -> ScalarField:
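
Several hunks in this patch add explicit exception chaining (ruff's B904): `raise ... from err` keeps the original exception attached as the cause, while `raise ... from None` deliberately hides it. A minimal sketch with made-up messages:

    def parse_index(key):
        try:
            return int(key)
        except ValueError as err:
            # chained: the ValueError is reported as the cause of the IndexError
            raise IndexError("Index must be given as an integer") from err

    def parse_index_quiet(key):
        try:
            return int(key)
        except ValueError:
            # suppressed: only the IndexError appears in the traceback
            raise IndexError("Index must be given as an integer") from None

    try:
        parse_index("x")
    except IndexError as exc:
        print(repr(exc.__cause__))  # the original ValueError is preserved as the cause
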
diff --git a/pde/fields/vectorial.py b/pde/fields/vectorial.py
index 272ae2bf..e6c4c16f 100644
--- a/pde/fields/vectorial.py
+++ b/pde/fields/vectorial.py
@@ -317,8 +317,8 @@ def check_rank(arr: nb.types.Type | nb.types.Optional) -> None:
@register_jitable
def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> np.ndarray:
"""Calculate outer product between fields `a` and `b`"""
- for i in range(0, dim):
- for j in range(0, dim):
+ for i in range(dim):
+ for j in range(dim):
out[i, j, :] = a[i] * b[j]
return out
diff --git a/pde/grids/base.py b/pde/grids/base.py
index f366bef0..f17a33fd 100644
--- a/pde/grids/base.py
+++ b/pde/grids/base.py
@@ -1021,7 +1021,7 @@ def contains_point(
the grid
"""
cell_coords = self.transform(points, source=coords, target="cell", full=full)
- return np.all((0 <= cell_coords) & (cell_coords <= self.shape), axis=-1) # type: ignore
+ return np.all((cell_coords >= 0) & (cell_coords <= self.shape), axis=-1) # type: ignore
def iter_mirror_points(
self, point: np.ndarray, with_self: bool = False, only_periodic: bool = True
@@ -1222,6 +1222,7 @@ def register_operator(factor_func_arg: OperatorFactory):
else:
# method is used directly
register_operator(factory_func)
+ return None
@hybridmethod # type: ignore
@property
diff --git a/pde/grids/boundaries/axes.py b/pde/grids/boundaries/axes.py
index cd4e743d..a323d0e3 100644
--- a/pde/grids/boundaries/axes.py
+++ b/pde/grids/boundaries/axes.py
@@ -264,7 +264,7 @@ def __setitem__(self, index, data) -> None:
else:
# handle all other cases, in particular integer indices
- return super().__setitem__(index, data)
+ super().__setitem__(index, data)
def get_mathematical_representation(self, field_name: str = "C") -> str:
"""Return mathematical representation of the boundary condition."""
@@ -303,7 +303,7 @@ def set_ghost_cells(
for i, j in itertools.product([0, -1], [0, -1]):
d[..., i, j] = (d[..., nxt[i], j] + d[..., i, nxt[j]]) / 2
- elif self.grid.num_axes >= 3:
+ elif self.grid.num_axes == 3:
# iterate all edges
for i, j in itertools.product([0, -1], [0, -1]):
d[..., :, i, j] = (+d[..., :, nxt[i], j] + d[..., :, i, nxt[j]]) / 2
@@ -316,8 +316,9 @@ def set_ghost_cells(
+ d[..., i, nxt[j], k]
+ d[..., i, j, nxt[k]]
) / 3
- else:
- logging.getLogger(self.__class__.__name__).warning(
+
+ elif self.grid.num_axes > 3:
+ raise NotImplementedError(
f"Can't interpolate corners for grid with {self.grid.num_axes} axes"
)
diff --git a/pde/grids/boundaries/axis.py b/pde/grids/boundaries/axis.py
index 5d928209..61805c45 100644
--- a/pde/grids/boundaries/axis.py
+++ b/pde/grids/boundaries/axis.py
@@ -287,7 +287,7 @@ def from_data(cls, grid: GridBase, axis: int, data, rank: int = 0) -> BoundaryPa
# if len is not supported, the format must be wrong
raise BCDataError(
f"Unsupported boundary format: `{data}`. " + cls.get_help()
- )
+ ) from None
else:
if data_len == 2:
# assume that data is given for each boundary
diff --git a/pde/grids/boundaries/local.py b/pde/grids/boundaries/local.py
index 161d6f54..a96c7719 100644
--- a/pde/grids/boundaries/local.py
+++ b/pde/grids/boundaries/local.py
@@ -426,7 +426,7 @@ def from_str(
except KeyError:
raise BCDataError(
f"Boundary condition `{condition}` not defined. " + cls.get_help()
- )
+ ) from None
# create the actual class
return boundary_class(grid=grid, axis=axis, upper=upper, rank=rank, **kwargs)
@@ -459,7 +459,7 @@ def from_dict(
data = data.copy() # need to make a copy since we modify it below
# parse all possible variants that could be given
- if "type" in data.keys():
+ if "type" in data:
# type is given (optionally with a value)
b_type = data.pop("type")
return cls.from_str(grid, axis, upper, condition=b_type, rank=rank, **data)
@@ -1224,14 +1224,13 @@ def __init__(
except Exception as err:
if self._is_func:
raise BCDataError(
- f"Could not evaluate BC function. Expected signature "
- f"{signature}.\nEncountered error: {err}"
- )
+ f"Could not evaluate BC function. Expected signature {signature}."
+ ) from err
else:
raise BCDataError(
f"Could not evaluate BC expression `{expression}` with signature "
- f"{signature}.\nEncountered error: {err}"
- )
+ f"{signature}."
+ ) from err
@property
def _test_values(self) -> tuple[float, ...]:
@@ -1275,7 +1274,7 @@ def value_func(*args):
except nb.NumbaError:
# if compilation fails, we simply fall back to pure-python mode
- self._logger.warning(f"Cannot compile BC {self}")
+ self._logger.warning("Cannot compile BC %s", self)
@register_jitable
def value_func(*args):
@@ -1360,7 +1359,7 @@ def _get_function_from_expression(self, do_jit: bool) -> Callable:
except nb.NumbaError:
# if compilation fails, we simply fall back to pure-python mode
- self._logger.warning(f"Cannot compile BC {self._func_expression}")
+ self._logger.warning("Cannot compile BC %s", self._func_expression)
# calculate the expected value to test this later (and fail early)
expected = func(*self._test_values)
@@ -1391,12 +1390,14 @@ def value_func(grid_value, dx, x, y, z, t):
else:
# cheap way to signal a problem
- raise ValueError
+ raise ValueError from None
# compile the actual functio and check the result
result_compiled = value_func(*self._test_values)
if not np.allclose(result_compiled, expected):
- raise RuntimeError("Compiled function does not give same value")
+ raise RuntimeError(
+ "Compiled function does not give same value"
+ ) from None
return value_func # type: ignore
@@ -2974,4 +2975,4 @@ def registered_boundary_condition_names() -> dict[str, type[BCBase]]:
Returns:
dict: a dictionary with the names of the boundary conditions that can be used
"""
- return {cls_name: cls for cls_name, cls in BCBase._conditions.items()}
+ return dict(BCBase._conditions.items())
diff --git a/pde/grids/cartesian.py b/pde/grids/cartesian.py
index ca6cc600..187561de 100644
--- a/pde/grids/cartesian.py
+++ b/pde/grids/cartesian.py
@@ -319,7 +319,7 @@ def _get_axis(axis):
try:
axis = self.axes.index(axis)
except ValueError:
- raise ValueError(f"Axis `{axis}` not defined")
+ raise ValueError(f"Axis `{axis}` not defined") from None
return axis
if extract == "auto":
diff --git a/pde/grids/operators/cartesian.py b/pde/grids/operators/cartesian.py
index e5a4907f..7032f962 100644
--- a/pde/grids/operators/cartesian.py
+++ b/pde/grids/operators/cartesian.py
@@ -986,7 +986,8 @@ def _make_divergence_scipy_nd(
def divergence(arr: np.ndarray, out: np.ndarray) -> None:
"""Apply divergence operator to array `arr`"""
- assert arr.shape[0] == len(data_shape) and arr.shape[1:] == data_shape
+ assert arr.shape[0] == len(data_shape)
+ assert arr.shape[1:] == data_shape
# need to initialize with zeros since data is added later
if out is None:
diff --git a/pde/grids/operators/polar_sym.py b/pde/grids/operators/polar_sym.py
index 1a25ee15..078c2120 100644
--- a/pde/grids/operators/polar_sym.py
+++ b/pde/grids/operators/polar_sym.py
@@ -300,12 +300,11 @@ def _get_laplace_matrix(bcs: Boundaries) -> tuple[np.ndarray, np.ndarray]:
if r_min == 0:
matrix[i, i + 1] = 2 * scale
continue # the special case of the inner boundary is handled
- else:
- const, entries = bcs[0].get_sparse_matrix_data((-1,))
- factor = scale - scale_i
- vector[i] += const * factor
- for k, v in entries.items():
- matrix[i, k] += v * factor
+ const, entries = bcs[0].get_sparse_matrix_data((-1,))
+ factor = scale - scale_i
+ vector[i] += const * factor
+ for k, v in entries.items():
+ matrix[i, k] += v * factor
else:
matrix[i, i - 1] = scale - scale_i
diff --git a/pde/grids/operators/spherical_sym.py b/pde/grids/operators/spherical_sym.py
index fb0b9f80..f5972b47 100644
--- a/pde/grids/operators/spherical_sym.py
+++ b/pde/grids/operators/spherical_sym.py
@@ -57,7 +57,8 @@ def make_laplace(grid: SphericalSymGrid, *, conservative: bool = True) -> Operat
# create a conservative spherical laplace operator
rl = rs - dr / 2 # inner radii of spherical shells
rh = rs + dr / 2 # outer radii
- assert np.isclose(rl[0], r_min) and np.isclose(rh[-1], r_max)
+ assert np.isclose(rl[0], r_min)
+ assert np.isclose(rh[-1], r_max)
volumes = (rh**3 - rl**3) / 3 # volume of the spherical shells
factor_l = rl**2 / (dr * volumes)
factor_h = rh**2 / (dr * volumes)
@@ -496,7 +497,8 @@ def make_tensor_double_divergence(
rl = rs - dr / 2 # inner radii of spherical shells
rh = rs + dr / 2 # outer radii
r_min, r_max = grid.axes_bounds[0]
- assert np.isclose(rl[0], r_min) and np.isclose(rh[-1], r_max)
+ assert np.isclose(rl[0], r_min)
+ assert np.isclose(rh[-1], r_max)
volumes = (rh**3 - rl**3) / 3 # volume of the spherical shells
factor_l = rl / volumes
factor_h = rh / volumes
diff --git a/pde/pdes/base.py b/pde/pdes/base.py
index 03f20c91..4ae79ad2 100644
--- a/pde/pdes/base.py
+++ b/pde/pdes/base.py
@@ -447,7 +447,6 @@ def noise_realization(state_data: np.ndarray, t: float) -> np.ndarray:
@jit
def noise_realization(state_data: np.ndarray, t: float) -> None:
"""Helper function returning a noise realization."""
- return None
return noise_realization # type: ignore
diff --git a/pde/pdes/laplace.py b/pde/pdes/laplace.py
index e8d58837..8931f6f4 100644
--- a/pde/pdes/laplace.py
+++ b/pde/pdes/laplace.py
@@ -65,14 +65,14 @@ def solve_poisson_equation(
result = ScalarField(rhs.grid, label=label)
try:
solver(rhs.data, result.data)
- except RuntimeError:
+ except RuntimeError as err:
magnitude = rhs.magnitude
if magnitude > 1e-10:
raise RuntimeError(
"Could not solve the Poisson problem. One possible reason for this is "
"that only periodic or Neumann conditions are applied although the "
f"magnitude of the field is {magnitude} and thus non-zero."
- )
+ ) from err
else:
raise # another error occurred
diff --git a/pde/pdes/pde.py b/pde/pdes/pde.py
index 9345509e..19e00dbf 100644
--- a/pde/pdes/pde.py
+++ b/pde/pdes/pde.py
@@ -224,16 +224,16 @@ def __init__(
else:
raise ValueError(f'Cannot parse boundary condition "{key_str}"')
if key in self.bcs:
- self._logger.warning(f"Two boundary conditions for key {key}")
+ self._logger.warning("Two boundary conditions for key %s", key)
self.bcs[key] = value
# save information for easy inspection
self.diagnostics["pde"] = {
"variables": list(self.variables),
- "constants": list(sorted(self.consts)),
+ "constants": sorted(self.consts),
"explicit_time_dependence": explicit_time_dependence,
"complex_valued_rhs": complex_valued,
- "operators": list(sorted(set().union(*self._operators.values()))),
+ "operators": sorted(set().union(*self._operators.values())),
}
self._cache: dict[str, dict[str, Any]] = {}
@@ -309,7 +309,7 @@ def _compile_rhs_single(
expr._sympy_expr = expr._sympy_expr.replace(
# only modify the relevant operator
lambda expr: isinstance(expr.func, UndefinedFunction)
- and expr.name == func
+ and expr.name == func # noqa: B023
# and do not modify it when the bc_args have already been set
and not (
isinstance(expr.args[-1], Symbol)
@@ -335,7 +335,7 @@ def _compile_rhs_single(
else:
# expression only depends on the actual variables
- extra_args = tuple() # @UnusedVariable
+ extra_args = () # @UnusedVariable
# check whether all variables are accounted for
extra_vars = set(expr.vars) - set(signature)
@@ -462,7 +462,7 @@ def _prepare_cache(
# check whether there are boundary conditions that have not been used
bcs_left = set(self.bcs.keys()) - self.diagnostics["pde"]["bcs_used"] - {"*:*"}
if bcs_left:
- self._logger.warning("Unused BCs: %s", list(sorted(bcs_left)))
+ self._logger.warning("Unused BCs: %s", sorted(bcs_left))
# add extra information for field collection
if isinstance(state, FieldCollection):
diff --git a/pde/solvers/base.py b/pde/solvers/base.py
index c8a67233..a74c9126 100644
--- a/pde/solvers/base.py
+++ b/pde/solvers/base.py
@@ -30,7 +30,7 @@ class ConvergenceError(RuntimeError):
"""Indicates that an implicit step did not converge."""
-class SolverBase(metaclass=ABCMeta):
+class SolverBase:
"""Base class for PDE solvers."""
dt_default: float = 1e-3
@@ -73,7 +73,7 @@ def __init_subclass__(cls, **kwargs): # @NoSelf
cls._subclasses[cls.__name__] = cls
if hasattr(cls, "name") and cls.name:
if cls.name in cls._subclasses:
- logging.warning(f"Solver with name {cls.name} is already registered")
+ logging.warning("Solver with name %s is already registered", cls.name)
cls._subclasses[cls.name] = cls
@classmethod
@@ -108,14 +108,14 @@ def from_name(cls, name: str, pde: PDEBase, **kwargs) -> SolverBase:
raise ValueError(
f"Unknown solver method '{name}'. Registered solvers are "
+ ", ".join(solvers)
- )
+ ) from None
return solver_class(pde, **kwargs)
@classproperty
def registered_solvers(cls) -> list[str]: # @NoSelf
"""list of str: the names of the registered solvers"""
- return list(sorted(cls._subclasses.keys()))
+ return sorted(cls._subclasses.keys())
@property
def _compiled(self) -> bool:
@@ -250,7 +250,7 @@ def _make_pde_rhs(
time. The function returns the deterministic evolution rate and (if
applicable) a realization of the associated noise.
"""
- if getattr(self.pde, "is_sde"):
+ if getattr(self.pde, "is_sde", False):
raise RuntimeError(
f"Cannot create a deterministic stepper for a stochastic equation"
)
@@ -379,8 +379,9 @@ def make_stepper(
dt = self.dt_default
self._logger.warning(
"Explicit stepper with a fixed time step did not receive any "
- f"initial value for `dt`. Using dt={dt}, but specifying a value or "
- "enabling adaptive stepping is advisable."
+ "initial value for `dt`. Using dt=%g, but specifying a value or "
+ "enabling adaptive stepping is advisable.",
+ dt,
)
dt_float = float(dt) # explicit casting to help type checking
@@ -526,7 +527,7 @@ def _make_single_step_error_estimate(
An example for the state from which the grid and other information can
be extracted
"""
- if getattr(self.pde, "is_sde"):
+ if getattr(self.pde, "is_sde", False):
raise RuntimeError("Cannot use adaptive stepper with stochastic equation")
single_step = self._make_single_step_variable_dt(state)
@@ -638,7 +639,7 @@ def adaptive_stepper(
)
adaptive_stepper = jit(sig_adaptive)(adaptive_stepper)
- self._logger.info(f"Initialized adaptive stepper")
+ self._logger.info("Initialized adaptive stepper")
return adaptive_stepper
def make_stepper(
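
The two `getattr(self.pde, "is_sde", False)` changes add a fallback so that a PDE object without the attribute is treated as deterministic instead of raising `AttributeError`. A minimal sketch with a made-up object:

    class MinimalPDE:  # hypothetical stand-in without an `is_sde` attribute
        pass

    pde = MinimalPDE()
    print(getattr(pde, "is_sde", False))  # False -- missing attribute uses the default
    try:
        getattr(pde, "is_sde")
    except AttributeError as err:
        print("without a default this raises:", err)
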
diff --git a/pde/solvers/controller.py b/pde/solvers/controller.py
index bf53281c..972f9127 100644
--- a/pde/solvers/controller.py
+++ b/pde/solvers/controller.py
@@ -103,13 +103,13 @@ def t_range(self, value: TRangeType):
# determine time range
try:
self._t_range: tuple[float, float] = (0, float(value)) # type: ignore
- except TypeError: # assume a single number was given
+ except TypeError as err: # assume a single number was given
if len(value) == 2: # type: ignore
self._t_range = tuple(value) # type: ignore
else:
raise ValueError(
"t_range must be set to a single number or a tuple of two numbers"
- )
+ ) from err
def _get_stop_handler(self) -> Callable[[Exception, float], tuple[int, str]]:
"""Return function that handles messaging."""
@@ -196,7 +196,7 @@ def _run_main_process(self, state: TState, dt: float | None = None) -> None:
# evolve the system from t_start to t_end
t = t_start
- self._logger.debug(f"Start simulation at t={t}")
+ self._logger.debug("Start simulation at t=%g", t)
try:
while t < t_end:
# determine next time point with an action
@@ -265,8 +265,10 @@ def _run_main_process(self, state: TState, dt: float | None = None) -> None:
self._logger.log(msg_level, msg)
if profiler["tracker"] > max(profiler["solver"], 1):
self._logger.warning(
- f"Spent more time on handling trackers ({profiler['tracker']:.3g}) "
- f"than on the actual simulation ({profiler['solver']:.3g})"
+ "Spent more time on handling trackers (%.3g) than on the actual "
+ "simulation (%.3g)",
+ profiler["tracker"],
+ profiler["solver"],
)
def _run_client_process(self, state: TState, dt: float | None = None) -> None:
@@ -352,7 +354,7 @@ def _run_parallel(self, state: TState, dt: float | None = None) -> TState | None
self._run_main_process(state, dt)
except Exception as err:
print(err) # simply print the exception to show some info
- self._logger.error(f"Error in main node", exc_info=err)
+ self._logger.error("Error in main node", exc_info=err)
time.sleep(0.5) # give some time for info to propagate
MPI.COMM_WORLD.Abort() # abort all other nodes
raise
@@ -365,7 +367,7 @@ def _run_parallel(self, state: TState, dt: float | None = None) -> TState | None
self._run_client_process(state, dt)
except Exception as err:
print(err) # simply print the exception to show some info
- self._logger.error(f"Error in node {mpi.rank}", exc_info=err)
+ self._logger.error("Error in node %d", mpi.rank, exc_info=err)
time.sleep(0.5) # give some time for info to propagate
MPI.COMM_WORLD.Abort() # abort all other (and main) nodes
raise
@@ -394,7 +396,7 @@ def run(self, initial_state: TState, dt: float | None = None) -> TState | None:
from ..tools import mpi
# copy the initial state to not modify the supplied one
- if getattr(self.solver, "pde") and self.solver.pde.complex_valued:
+ if hasattr(self.solver, "pde") and self.solver.pde.complex_valued:
self._logger.info("Convert state to complex numbers")
state: TState = initial_state.copy(dtype=complex)
else:
diff --git a/pde/solvers/explicit.py b/pde/solvers/explicit.py
index cd141f41..4aceee8e 100644
--- a/pde/solvers/explicit.py
+++ b/pde/solvers/explicit.py
@@ -266,7 +266,7 @@ def adaptive_stepper(
)
adaptive_stepper = jit(sig_adaptive)(adaptive_stepper)
- self._logger.info(f"Init adaptive Euler stepper")
+ self._logger.info("Init adaptive Euler stepper")
return adaptive_stepper
def _make_single_step_error_estimate_rkf(
diff --git a/pde/solvers/explicit_mpi.py b/pde/solvers/explicit_mpi.py
index a5b511b5..2b63d881 100644
--- a/pde/solvers/explicit_mpi.py
+++ b/pde/solvers/explicit_mpi.py
@@ -155,8 +155,9 @@ def make_stepper(
if not self.adaptive:
self._logger.warning(
"Explicit stepper with a fixed time step did not receive any "
- f"initial value for `dt`. Using dt={dt}, but specifying a value or "
- "enabling adaptive stepping is advisable."
+ "initial value for `dt`. Using dt=%g, but specifying a value or "
+ "enabling adaptive stepping is advisable.",
+ dt,
)
self.info["dt"] = dt
diff --git a/pde/solvers/scipy.py b/pde/solvers/scipy.py
index 579845f1..4eada32c 100644
--- a/pde/solvers/scipy.py
+++ b/pde/solvers/scipy.py
@@ -99,7 +99,7 @@ def stepper(state: FieldBase, t_start: float, t_end: float) -> float:
return sol.t[0] # type: ignore
if dt:
- self._logger.info(f"Init {self.__class__.__name__} stepper with dt=%g", dt)
+ self._logger.info("Init %s stepper with dt=%g", self.__class__.__name__, dt)
else:
- self._logger.info(f"Init {self.__class__.__name__} stepper")
+ self._logger.info("Init %s stepper", self.__class__.__name__)
return stepper
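
The logging calls throughout this patch move from f-strings to %-style placeholders (ruff's flake8-logging-format rules), so the message is only interpolated when the record is actually emitted. A minimal sketch with a hypothetical logger:

    import logging

    logger = logging.getLogger("example")  # hypothetical logger name
    dt = 1e-3

    logger.info(f"Init stepper with dt={dt}")   # eager: string built even if the record is filtered
    logger.info("Init stepper with dt=%g", dt)  # lazy: interpolated only when emitted
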
diff --git a/pde/storage/__init__.py b/pde/storage/__init__.py
index ef2c57d8..322f3697 100644
--- a/pde/storage/__init__.py
+++ b/pde/storage/__init__.py
@@ -12,11 +12,11 @@
.. codeauthor:: David Zwicker
"""
+import contextlib
+
from .file import FileStorage
from .memory import MemoryStorage, get_memory_storage
from .movie import MovieStorage
-try:
+with contextlib.suppress(ImportError):
from .modelrunner import ModelrunnerStorage
-except ImportError:
- ... # ModelrunnerStorage is only available when py-modelrunner is available
diff --git a/pde/storage/base.py b/pde/storage/base.py
index 4bcd4e83..6521b9d5 100644
--- a/pde/storage/base.py
+++ b/pde/storage/base.py
@@ -176,8 +176,8 @@ def grid(self) -> GridBase | None:
self._grid = attrs["fields"][0]["grid"]
else:
self._logger.warning(
- "`grid` attribute was not stored. Available attributes: "
- + ", ".join(sorted(attrs.keys()))
+ "`grid` attribute was not stored. Available attributes: %s",
+ ", ".join(sorted(attrs.keys())),
)
else:
@@ -219,8 +219,8 @@ def _init_field(self) -> None:
f"{local_shape} could not be interpreted automatically"
)
self._logger.warning(
- "`field` attribute was not stored. We guessed that the data is of "
- f"type {self._field.__class__.__name__}."
+ "`field` attribute was not stored. Assume data is of type %s.",
+ self._field.__class__.__name__,
)
def _get_field(self, t_index: int) -> FieldBase:
diff --git a/pde/storage/file.py b/pde/storage/file.py
index 05f0b6ca..d4af8137 100644
--- a/pde/storage/file.py
+++ b/pde/storage/file.py
@@ -74,7 +74,7 @@ def __init__(
self._data_length: int = None # type: ignore
self._max_length: int | None = max_length
- if not self.check_mpi or mpi.is_main:
+ if not self.check_mpi or mpi.is_main: # noqa: SIM102
# we are on the main process and can thus open the file directly
if self.filename.is_file() and self.filename.stat().st_size > 0:
try:
@@ -82,7 +82,7 @@ def __init__(
except (OSError, KeyError):
self.close()
self._logger.warning(
- f"File `{filename}` could not be opened for reading"
+ "File `%s` could not be opened for reading", filename
)
def __del__(self):
@@ -103,7 +103,7 @@ def _file_state(self) -> str:
def close(self) -> None:
"""Close the currently opened file."""
if self._file is not None:
- self._logger.info(f"Close file `{self.filename}`")
+ self._logger.info("Close file `%s`", self.filename)
self._file.close()
self._file = None
self._data_length = None # type: ignore
@@ -117,7 +117,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
def _create_hdf_dataset(
self,
name: str,
- shape: tuple[int, ...] = tuple(),
+ shape: tuple[int, ...] = (),
dtype: DTypeLike = np.double,
):
"""Create a hdf5 dataset with the given name and data_shape.
@@ -176,7 +176,7 @@ def _open(
# close file to open it again for reading or appending
if self._file:
self._file.close()
- self._logger.info(f"Open file `{self.filename}` for reading")
+ self._logger.info("Open file `%s` for reading", self.filename)
self._file = h5py.File(self.filename, mode="r")
self._times = self._file["times"]
self._data = self._file["data"]
@@ -200,7 +200,7 @@ def _open(
self.close()
# open file for reading or appending
- self._logger.info(f"Open file `{self.filename}` for appending")
+ self._logger.info("Open file `%s` for appending", self.filename)
self._file = h5py.File(self.filename, mode="a")
if "times" in self._file and "data" in self._file:
@@ -234,7 +234,7 @@ def _open(
self.close()
else:
ensure_directory_exists(self.filename.parent)
- self._logger.info(f"Open file `{self.filename}` for writing")
+ self._logger.info("Open file `%s` for writing", self.filename)
self._file = h5py.File(self.filename, "w")
self._times = self._create_hdf_dataset("times")
self._data = self._create_hdf_dataset(
@@ -337,7 +337,7 @@ def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None:
super().start_writing(field, info=info)
# initialize the file for writing with the correct mode
- self._logger.debug(f"Start writing with mode `{self.write_mode}`")
+ self._logger.debug("Start writing with mode '%s'", self.write_mode)
if self.write_mode == "truncate_once":
self._open("writing", info)
self.write_mode = "append" # do not truncate for next writing
diff --git a/pde/storage/modelrunner.py b/pde/storage/modelrunner.py
index 28cb240a..d79fe7a7 100644
--- a/pde/storage/modelrunner.py
+++ b/pde/storage/modelrunner.py
@@ -130,7 +130,7 @@ def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None:
super().start_writing(field, info=info)
# initialize the file for writing with the correct mode
- self._logger.debug(f"Start writing with mode `{self.write_mode}`")
+ self._logger.debug("Start writing with mode '%s'", self.write_mode)
if self.write_mode == "truncate_once":
self.write_mode = "append" # do not truncate for next writing
elif self.write_mode == "readonly":
diff --git a/pde/storage/movie.py b/pde/storage/movie.py
index a32c8e7c..8c80f5ae 100644
--- a/pde/storage/movie.py
+++ b/pde/storage/movie.py
@@ -144,7 +144,7 @@ def __del__(self):
def close(self) -> None:
"""Close the currently opened file."""
if self._ffmpeg is not None:
- self._logger.info(f"Close movie file `{self.filename}`")
+ self._logger.info("Close movie file '%s'", self.filename)
if self._state == "writing":
self._ffmpeg.stdin.close()
self._ffmpeg.wait()
@@ -191,7 +191,7 @@ def _read_metadata(self) -> None:
# sanity checks on the video
nb_streams = info["format"]["nb_streams"]
if nb_streams != 1:
- self._logger.warning(f"Only using first of {nb_streams} streams")
+ self._logger.warning("Only using first of %d streams", nb_streams)
tags = info["format"].get("tags", {})
# read comment field, which can be either lower case or upper case
@@ -202,7 +202,7 @@ def _read_metadata(self) -> None:
version = metadata.pop("version", 1)
if version != 1:
- self._logger.warning(f"Unknown metadata version `{version}`")
+ self._logger.warning("Unknown metadata version `%d`", version)
self.vmin = metadata.pop("vmin", 0)
self.vmax = metadata.pop("vmax", 1)
self.info.update(metadata)
@@ -217,8 +217,8 @@ def _read_metadata(self) -> None:
try:
fps = Fraction(stream.get("avg_frame_rate", None))
duration = parse_duration(stream.get("tags", {}).get("DURATION"))
- except TypeError:
- raise RuntimeError("Frame count could not be read from video")
+ except TypeError as err:
+ raise RuntimeError("Frame count could not be read from video") from err
else:
self.info["num_frames"] = int(duration.total_seconds() * float(fps))
self.info["width"] = stream["width"]
@@ -234,12 +234,13 @@ def _read_metadata(self) -> None:
try:
self._format = FFmpeg.formats[video_format]
except KeyError:
- self._logger.warning(f"Unknown pixel format `{video_format}`")
+ self._logger.warning("Unknown pixel format `%s`", video_format)
else:
if self._format.pix_fmt_file != stream.get("pix_fmt"):
self._logger.info(
- "Pixel format differs from requested one: "
- f"{self._format.pix_fmt_file} != {stream.get('pix_fmt')}"
+ "Pixel format differs from requested one: %s != %s",
+ self._format.pix_fmt_file,
+ stream.get("pix_fmt"),
)
def _init_normalization(self, field: FieldBase) -> None:
@@ -332,7 +333,7 @@ def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None:
self._init_normalization(field)
# set input
- self._logger.debug(f"Start ffmpeg process for `{self.filename}`")
+ self._logger.debug("Start ffmpeg process for `%s`", self.filename)
input_args = {
"format": "rawvideo",
"s": f"{width}x{height}",
@@ -358,7 +359,7 @@ def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None:
self._ffmpeg = f_output.run_async(pipe_stdin=True) # start process
if self.write_times:
- self._times_file = open(self._filename_times, "w")
+ self._times_file = self._filename_times.open("w") # noqa: SIM115
self.info["num_frames"] = 0
self._warned_normalization = False
@@ -390,9 +391,9 @@ def _append_data(self, data: np.ndarray, time: float) -> None:
t_start = 0
dt = self.info.get("dt", 1)
time_expect = t_start + dt * self.info["num_frames"]
- if not np.isclose(time, time_expect):
- if not self.info.get("time_mismatch", False):
- self._logger.warning(f"Time mismatch: {time} != {time_expect}")
+ if not np.isclose(time, time_expect): # discrepancy in time # noqa: SIM102
+ if not self.info.get("time_mismatch", False): # not yet warned
+ self._logger.warning("Time mismatch: %g != %g", time, time_expect)
self.info["time_mismatch"] = True
# make sure there are two spatial dimensions
@@ -421,8 +422,9 @@ def _append_data(self, data: np.ndarray, time: float) -> None:
if not self._warned_normalization:
if np.any(data[i, ...] < norm.vmin) or np.any(data[i, ...] > norm.vmax):
self._logger.warning(
- f"Data outside range specified by `vmin={norm.vmin}` and "
- f"`vmax={norm.vmax}`"
+ "Data outside range specified by `vmin=%g` and `vmax=%g`",
+ norm.vmin,
+ norm.vmax,
)
self._warned_normalization = True # only warn once
data_norm = norm(data[i, ...])
@@ -434,7 +436,7 @@ def _append_data(self, data: np.ndarray, time: float) -> None:
def end_writing(self) -> None:
"""Finalize the storage after writing."""
- if not self._state == "writing":
+ if self._state != "writing":
self._logger.warning("Writing was already terminated")
return # writing mode was already ended
self._logger.debug("End writing")
@@ -459,8 +461,9 @@ def times(self):
times = np.loadtxt(self._filename_times)
except OSError:
self._logger.warning(
- f"Could not read time stamps from file `{self._filename_times}`. "
- "Return equidistant times instead."
+ "Could not read time stamps from file `%s`. "
+ "Return equidistant times instead.",
+ self._filename_times,
)
if times is None:
@@ -630,7 +633,7 @@ def add_to_state(state):
if not (self.write_times or isinstance(interrupts, ConstantInterrupts)):
self._logger.warning(
- f"Use `write_times=True` to write times for complex interrupts"
+ "Use `write_times=True` to write times for complex interrupts"
)
# store data for common case of constant intervals
self.info["dt"] = getattr(interrupts, "dt", 1)
diff --git a/pde/tools/cache.py b/pde/tools/cache.py
index ba0c23c6..744a69f4 100644
--- a/pde/tools/cache.py
+++ b/pde/tools/cache.py
@@ -244,7 +244,7 @@ def make_serializer(method: SerializerMethod) -> Callable:
return lambda s: yaml.dump(s).encode("utf-8")
- raise ValueError("Unknown serialization method `%s`" % method)
+ raise ValueError(f"Unknown serialization method `{method}`")
def make_unserializer(method: SerializerMethod) -> Callable:
@@ -285,7 +285,7 @@ def make_unserializer(method: SerializerMethod) -> Callable:
return yaml.unsafe_load
- raise ValueError("Unknown serialization method `%s`" % method)
+ raise ValueError(f"Unknown serialization method `{method}`")
class DictFiniteCapacity(collections.OrderedDict):
diff --git a/pde/tools/config.py b/pde/tools/config.py
index d4fbacae..0b6e3c91 100644
--- a/pde/tools/config.py
+++ b/pde/tools/config.py
@@ -103,10 +103,10 @@ def __setitem__(self, key: str, value):
elif self.mode == "update":
try:
self[key] # test whether the key already exist (including magic keys)
- except KeyError:
+ except KeyError as err:
raise KeyError(
f"{key} is not present and config is not in `insert` mode"
- )
+ ) from err
self.data[key] = value
elif self.mode == "locked":
@@ -128,7 +128,7 @@ def to_dict(self) -> dict[str, Any]:
Returns:
dict: A representation of the configuration in a normal :class:`dict`.
"""
- return {k: v for k, v in self.items()}
+ return dict(self.items())
def __repr__(self) -> str:
"""Represent the configuration as a string."""
@@ -179,10 +179,8 @@ def parse_version_str(ver_str: str) -> list[int]:
"""Helper function converting a version string into a list of integers."""
result = []
for token in ver_str.split(".")[:3]:
- try:
+ with contextlib.suppress(ValueError):
result.append(int(token))
- except ValueError:
- pass
return result
@@ -215,7 +213,7 @@ def packages_from_requirements(requirements_file: Path | str) -> list[str]:
"""
result = []
try:
- with open(requirements_file) as fp:
+ with Path(requirements_file).open() as fp:
for line in fp:
line_s = line.strip()
if line_s.startswith("#"):
diff --git a/pde/tools/cuboid.py b/pde/tools/cuboid.py
index 933b30e8..4eaa2558 100644
--- a/pde/tools/cuboid.py
+++ b/pde/tools/cuboid.py
@@ -120,9 +120,7 @@ def copy(self) -> Cuboid:
return self.__class__(self.pos, self.size)
def __repr__(self):
- return "{cls}(pos={pos}, size={size})".format(
- cls=self.__class__.__name__, pos=self.pos, size=self.size
- )
+ return f"{self.__class__.__name__}(pos={self.pos}, size={self.size})"
def __add__(self, other: Cuboid) -> Cuboid:
"""The sum of two cuboids is the minimal cuboid enclosing both."""
diff --git a/pde/tools/docstrings.py b/pde/tools/docstrings.py
index 888a5fc1..eba7716b 100644
--- a/pde/tools/docstrings.py
+++ b/pde/tools/docstrings.py
@@ -14,6 +14,7 @@
import re
import textwrap
+from functools import partial
from typing import TypeVar
DOCSTRING_REPLACEMENTS = {
@@ -164,17 +165,16 @@ def fill_in_docstring(f: TFunc) -> TFunc:
width=80, expand_tabs=True, replace_whitespace=True, drop_whitespace=True
)
- for name, value in DOCSTRING_REPLACEMENTS.items():
-
- def repl(matchobj) -> str:
- """Helper function replacing token in docstring."""
- tw.initial_indent = tw.subsequent_indent = matchobj.group(1)
- return tw.fill(textwrap.dedent(value))
+ def repl(matchobj, value: str) -> str:
+ """Helper function replacing token in docstring."""
+ tw.initial_indent = tw.subsequent_indent = matchobj.group(1)
+ return tw.fill(textwrap.dedent(value))
+ for name, value in DOCSTRING_REPLACEMENTS.items():
token = "{" + name + "}"
f.__doc__ = re.sub(
f"^([ \t]*){token}",
- repl,
+ partial(repl, value=value),
f.__doc__, # type: ignore
flags=re.MULTILINE,
)
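
Hoisting `repl` out of the loop and binding `value` via `functools.partial` avoids the late-binding closure pitfall (ruff's B023, also silenced with a `noqa` in `pde.py` above): closures created inside a loop all see the loop variable's final value. A minimal sketch:

    from functools import partial

    late = [lambda: name for name in ("a", "b", "c")]
    print([f() for f in late])  # ['c', 'c', 'c'] -- every closure sees the final value

    bound = [partial(lambda value: value, name) for name in ("a", "b", "c")]
    print([f() for f in bound])  # ['a', 'b', 'c'] -- each value is bound at creation time
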
diff --git a/pde/tools/expressions.py b/pde/tools/expressions.py
index c8a3c26d..8319999f 100644
--- a/pde/tools/expressions.py
+++ b/pde/tools/expressions.py
@@ -260,14 +260,13 @@ def __init__(
for name, value in self.consts.items():
if isinstance(value, FieldBase):
self._logger.warning(
- f"Constant `{name}` is a field, but expressions usually require "
- f"numerical arrays. Did you mean to use `{name}.data`?"
+ "Constant `%(name)s` is a field, but expressions usually require "
+ "numerical arrays. Did you mean to use `%(name)s.data`?",
+ {"name": name},
)
def __repr__(self):
- return (
- f'{self.__class__.__name__}("{self.expression}", ' f"signature={self.vars})"
- )
+ return f'{self.__class__.__name__}("{self.expression}", signature={self.vars})'
def __eq__(self, other):
"""Compare this expression to another one."""
@@ -324,9 +323,9 @@ def _check_signature(self, signature: Sequence[str | list[str]] | None = None):
args = {str(s).split("[")[0] for s in self._free_symbols}
if signature is None:
# create signature from arguments
- signature = list(sorted(args))
+ signature = sorted(args)
- self._logger.debug(f"Expression arguments: {args}")
+ self._logger.debug("Expression arguments: %s", args)
# check whether variables are in signature
self.vars: Any = []
@@ -345,7 +344,7 @@ def _check_signature(self, signature: Sequence[str | list[str]] | None = None):
old = sympy.symbols(arg)
new = sympy.symbols(arg_name)
self._sympy_expr = self._sympy_expr.subs(old, new)
- self._logger.info(f'Renamed variable "{old}"->"{new}"')
+ self._logger.info('Renamed variable "%s"->"%s"', old, new)
found.add(arg)
break
@@ -514,7 +513,7 @@ def get_compiled(self, single_arg: bool = False) -> Callable[..., NumberOrArray]
class ScalarExpression(ExpressionBase):
"""Describes a mathematical expression of a scalar quantity."""
- shape: tuple[int, ...] = tuple()
+ shape: tuple[int, ...] = ()
@fill_in_docstring
def __init__(
@@ -678,9 +677,8 @@ def differentiate(self, var: str) -> ScalarExpression:
return ScalarExpression(
expression=0, signature=self.vars, allow_indexed=self.allow_indexed
)
- if self.allow_indexed:
- if self._var_indexed(var):
- raise NotImplementedError("Cannot differentiate with respect to vector")
+ if self.allow_indexed and self._var_indexed(var):
+ raise NotImplementedError("Cannot differentiate with respect to vector")
# turn variable into sympy object and treat an indexed variable separately
var_expr = self._prepare_expression(var)
@@ -707,11 +705,10 @@ def derivatives(self) -> TensorExpression:
expression = sympy.Array(np.zeros(dim), shape=(dim,))
return TensorExpression(expression=expression, signature=self.vars)
- if self.allow_indexed:
- if any(self._var_indexed(var) for var in self.vars):
- raise RuntimeError(
- "Cannot calculate gradient for expressions with indexed variables"
- )
+ if self.allow_indexed and any(self._var_indexed(var) for var in self.vars):
+ raise RuntimeError(
+ "Cannot calculate gradient for expressions with indexed variables"
+ )
grad = sympy.Array([self._sympy_expr.diff(sympy.Symbol(v)) for v in self.vars])
return TensorExpression(
@@ -1090,7 +1087,7 @@ def evaluate(
# check whether there are boundary conditions that have not been used
bcs_left = set(bcs.keys()) - bcs_used - {"*:*", "*"}
if bcs_left:
- logger.warning("Unused BCs: %s", list(sorted(bcs_left)))
+ logger.warning("Unused BCs: %s", sorted(bcs_left))
# obtain the function to calculate the right hand side
signature = tuple(fields_keys) + ("none", "bc_args")
@@ -1108,7 +1105,7 @@ def evaluate(
else:
# expression only depends on the actual variables
- extra_args = tuple() # @UnusedVariable
+ extra_args = () # @UnusedVariable
# check whether all variables are accounted for
extra_vars = set(expr.vars) - set(signature)
diff --git a/pde/tools/ffmpeg.py b/pde/tools/ffmpeg.py
index 1416cba0..0032e096 100644
--- a/pde/tools/ffmpeg.py
+++ b/pde/tools/ffmpeg.py
@@ -10,6 +10,8 @@
.. codeauthor:: David Zwicker
"""
+from __future__ import annotations
+
from dataclasses import dataclass
from typing import Optional, Union
@@ -48,7 +50,7 @@ def bytes_per_channel(self) -> int:
return self.bits_per_channel // 8
@property
- def max_value(self) -> Union[float, int]:
+ def max_value(self) -> float | int:
"""Maximal value stored in a color channel."""
if np.issubdtype(self.dtype, np.integer):
return 2**self.bits_per_channel - 1 # type: ignore
@@ -109,35 +111,11 @@ def data_from_frame(self, frame_data: np.ndarray):
bits_per_channel=16,
dtype=np.dtype(" Optional[str]:
+def find_format(channels: int, bits_per_channel: int = 8) -> str | None:
"""Find a defined FFmpegFormat that satisifies the requirements.
Args:
@@ -153,13 +131,14 @@ def find_format(channels: int, bits_per_channel: int = 8) -> Optional[str]:
"""
n_best, f_best = None, None
for n, f in formats.items(): # iterate through all defined formats
- if f.channels >= channels and f.bits_per_channel >= bits_per_channel:
- # this format satisfies the requirements
- if (
+ if (
+ f.channels >= channels
+ and f.bits_per_channel >= bits_per_channel # satisfies the requirements
+ and (
f_best is None
or f.bits_per_channel < f_best.bits_per_channel
or f.channels < f_best.channels
- ):
- # the current format is better than the previous one
- n_best, f_best = n, f
+ ) # the current format is better than the previous one
+ ):
+ n_best, f_best = n, f
return n_best
diff --git a/pde/tools/misc.py b/pde/tools/misc.py
index b19dfc6a..3dcf9e4b 100644
--- a/pde/tools/misc.py
+++ b/pde/tools/misc.py
@@ -60,11 +60,8 @@ def ensure_directory_exists(folder: str | Path):
Args:
folder (str): path of the new folder
"""
- folder = str(folder)
- if folder == "":
- return
try:
- os.makedirs(folder)
+ Path(folder).mkdir(parents=True)
except OSError as err:
if err.errno != errno.EEXIST:
raise
diff --git a/pde/tools/mpi.py b/pde/tools/mpi.py
index 7d3f36ed..d767ab3d 100644
--- a/pde/tools/mpi.py
+++ b/pde/tools/mpi.py
@@ -90,7 +90,7 @@ def __getattr__(self, name: str):
try:
return self._name_ids[name]
except KeyError:
- raise AttributeError
+ raise AttributeError(f"MPI operator `{name}` not registered") from None
Operator = _OperatorRegistry()
Operator.register("MAX", MPI.MAX)
diff --git a/pde/tools/output.py b/pde/tools/output.py
index af58b618..110c4c12 100644
--- a/pde/tools/output.py
+++ b/pde/tools/output.py
@@ -41,7 +41,7 @@ def get_progress_bar_class(fancy: bool = True):
progress_bar_class = tqdm.tqdm
else:
# use the fancier version of the progress bar in jupyter
- from tqdm.auto import tqdm as progress_bar_class # type: ignore
+ from tqdm.auto import tqdm as progress_bar_class
else:
# only import text progress bar
progress_bar_class = tqdm.tqdm
diff --git a/pde/tools/parameters.py b/pde/tools/parameters.py
index 8d6bd38a..a56f932c 100644
--- a/pde/tools/parameters.py
+++ b/pde/tools/parameters.py
@@ -134,11 +134,11 @@ def convert(self, value=None):
else:
try:
return self.cls(value)
- except ValueError:
+ except ValueError as err:
raise ValueError(
f"Could not convert {value!r} to {self.cls.__name__} for parameter "
f"'{self.name}'"
- )
+ ) from err
class DeprecatedParameter(Parameter):
@@ -228,9 +228,9 @@ def get_parameters(
"""
# collect the parameters from the class hierarchy
parameters: dict[str, Parameter] = {}
- for cls in reversed(cls.__mro__):
- if hasattr(cls, "parameters_default"):
- for p in cls.parameters_default:
+ for parent_cls in reversed(cls.__mro__):
+ if hasattr(parent_cls, "parameters_default"):
+ for p in parent_cls.parameters_default:
if isinstance(p, HideParameter):
if include_hidden:
parameters[p.name].hidden = True
@@ -504,19 +504,22 @@ def sphinx_display_parameters(app, what, name, obj, options, lines):
app.connect('autodoc-process-docstring', sphinx_display_parameters)
"""
- if what == "class" and issubclass(obj, Parameterized):
- if any(":param parameters:" in line for line in lines):
- # parse parameters
- parameters = obj.get_parameters(sort=False)
- if parameters:
- lines.append(".. admonition::")
- lines.append(f" Parameters of {obj.__name__}:")
- lines.append(" ")
- for p in parameters.values():
- lines.append(f" {p.name}")
- text = p.description.splitlines()
- text.append(f"(Default value: :code:`{p.default_value!r}`)")
- text = [" " + t for t in text]
- lines.extend(text)
- lines.append("")
+ if (
+ what == "class"
+ and issubclass(obj, Parameterized)
+ and any(":param parameters:" in line for line in lines)
+ ):
+ # parse parameters
+ parameters = obj.get_parameters(sort=False)
+ if parameters:
+ lines.append(".. admonition::")
+ lines.append(f" Parameters of {obj.__name__}:")
+ lines.append(" ")
+ for p in parameters.values():
+ lines.append(f" {p.name}")
+ text = p.description.splitlines()
+ text.append(f"(Default value: :code:`{p.default_value!r}`)")
+ text = [" " + t for t in text]
+ lines.extend(text)
lines.append("")
+ lines.append("")
diff --git a/pde/tools/plotting.py b/pde/tools/plotting.py
index 50bf6646..16f90f86 100644
--- a/pde/tools/plotting.py
+++ b/pde/tools/plotting.py
@@ -106,14 +106,10 @@ def get_size(self, renderer):
cbar = ax.figure.colorbar(axes_image, cax=cax, **kwargs)
# disable the offset that matplotlib sometimes shows
- try:
+ with contextlib.suppress(AttributeError):
cax.get_xaxis().get_major_formatter().set_useOffset(False)
- except AttributeError:
- pass # can happen for logarithmically formatted axes
- try:
+ with contextlib.suppress(AttributeError):
cax.get_yaxis().get_major_formatter().set_useOffset(False)
- except AttributeError:
- pass # can happen for logarithmically formatted axes
if label:
cbar.set_label(label)
@@ -297,7 +293,7 @@ def wrapper(
if ax is None:
# create new figure
backend = mpl.get_backend()
- if "backend_inline" in backend or "nbAgg" == backend:
+ if "backend_inline" in backend or backend == "nbAgg":
plt.close("all") # close left over figures
auto_show_figure = True # show this figure if action == 'auto'
fig, ax = plt.subplots()
@@ -492,7 +488,7 @@ def wrapper(
if fig is None:
# create new figure
backend = mpl.get_backend()
- if "backend_inline" in backend or "nbAgg" == backend:
+ if "backend_inline" in backend or backend == "nbAgg":
plt.close("all") # close left over figures
fig = plt.figure(constrained_layout=constrained_layout)
@@ -593,7 +589,7 @@ def __init__(self, title: str | None = None, show: bool = True):
self.initial_plot = True
self.fig = None
self._logger = logging.getLogger(__name__)
- self._logger.info(f"Initialize {self.__class__.__name__}")
+ self._logger.info("Initialize %s", self.__class__.__name__)
def __enter__(self):
# start the plotting process
@@ -729,10 +725,8 @@ def close(self):
"""Close the plot."""
super().close()
# close ipython output
- try:
+ with contextlib.suppress(Exception):
self._ipython_out.close()
- except Exception:
- pass
def get_plotting_context(
@@ -845,7 +839,7 @@ def napari_add_layers(
"""adds layers to a `napari `__ viewer
Args:
- viewer (:class:`napar i.viewer.Viewer`):
+ viewer (:class:`napari.viewer.Viewer`):
The napari application
layers_data (dict):
Data for all layers that will be added.
@@ -855,7 +849,7 @@ def napari_add_layers(
layer_type = layer_data.pop("type")
try:
add_layer = getattr(viewer, f"add_{layer_type}")
- except AttributeError:
- raise RuntimeError(f"Unknown layer type: {layer_type}")
+ except AttributeError as err:
+ raise RuntimeError(f"Unknown layer type: {layer_type}") from err
else:
add_layer(**layer_data)
diff --git a/pde/tools/spectral.py b/pde/tools/spectral.py
index 25fd10ff..1b076b8d 100644
--- a/pde/tools/spectral.py
+++ b/pde/tools/spectral.py
@@ -98,7 +98,6 @@ def noise_colored() -> np.ndarray:
arr *= scaling
# backwards transform
- arr = np_irfftn(arr, shape)
- return arr
+ return np_irfftn(arr, shape) # type: ignore
return noise_colored
diff --git a/pde/trackers/base.py b/pde/trackers/base.py
index 838b5d11..0609e096 100644
--- a/pde/trackers/base.py
+++ b/pde/trackers/base.py
@@ -71,9 +71,9 @@ def from_data(cls, data: TrackerDataType, **kwargs) -> TrackerBase:
elif isinstance(data, str):
try:
tracker_cls = cls._subclasses[data]
- except KeyError:
+ except KeyError as err:
trackers = sorted(cls._subclasses.keys())
- raise ValueError(f"Tracker `{data}` is not in {trackers}")
+ raise ValueError(f"Tracker `{data}` is not in {trackers}") from err
return tracker_cls(**kwargs)
else:
raise ValueError(f"Unsupported tracker format: `{data}`.")
diff --git a/pde/trackers/interactive.py b/pde/trackers/interactive.py
index 68b7bbaa..01128380 100644
--- a/pde/trackers/interactive.py
+++ b/pde/trackers/interactive.py
@@ -6,6 +6,7 @@
from __future__ import annotations
+import contextlib
import logging
import multiprocessing as mp
import platform
@@ -27,7 +28,7 @@ def napari_process(
t_initial: float | None = None,
viewer_args: dict[str, Any] | None = None,
):
- """:mod:`multiprocessing.Process` running `napari `__
+ """:mod:`multiprocessing.Process` running `napari120 `__
Args:
data_channel (:class:`multiprocessing.Queue`):
@@ -109,16 +110,16 @@ def update_listener():
update_data = data
# continue running until the queue is empty
else:
- logger.warning(f"Unexpected action: {action}")
+ logger.warning("Unexpected action: %s", action)
# update napari view when there is data
if update_data is not None:
- logger.debug(f"Update napari layer...")
+ logger.debug("Update napari layer...")
layer_data, t = update_data
if label is not None:
label.setText(f"Time: {t}")
- for name, layer_data in layer_data.items():
- viewer.layers[name].data = layer_data["data"]
+ for name, data in layer_data.items():
+ viewer.layers[name].data = data["data"]
yield
@@ -193,10 +194,8 @@ def update(self, state: FieldBase, t: float):
except queue.Full:
pass # could not write data
else:
- try:
+ with contextlib.suppress(queue.Empty):
self.data_channel.get(block=False)
- except queue.Empty:
- pass
def close(self, force: bool = True):
"""Closes the napari process.
@@ -208,10 +207,8 @@ def close(self, force: bool = True):
"""
if self.proc.is_alive() and force:
# signal to napari process that it should be closed
- try:
+ with contextlib.suppress(RuntimeError):
self.data_channel.put(("close", None))
- except RuntimeError:
- pass
self.data_channel.close()
self.data_channel.join_thread()
diff --git a/pde/trackers/trackers.py b/pde/trackers/trackers.py
index 16c5be9f..a96d478c 100644
--- a/pde/trackers/trackers.py
+++ b/pde/trackers/trackers.py
@@ -44,7 +44,7 @@
from .interrupts import InterruptData, RealtimeInterrupts
if TYPE_CHECKING:
- import pandas
+ import pandas # noqa: ICN001
class CallbackTracker(TrackerBase):
@@ -450,7 +450,7 @@ def initialize(self, state: FieldBase, info: InfoDict | None = None) -> float:
else:
self._update_method = "replot"
- self._logger.info(f'Update method: "{self._update_method}"')
+ self._logger.info('Update method: "%s"', self._update_method)
self._last_update = time.monotonic()
return super().initialize(state, info=info)
@@ -709,11 +709,13 @@ def to_file(self, filename: str, **kwargs):
\**kwargs:
Additional parameters may be supported for some formats
"""
- extension = os.path.splitext(filename)[1].lower()
+ from pathlib import Path
+
+ extension = Path(filename).suffix.lower()
if extension == ".pickle":
import pickle
- with open(filename, "wb") as fp:
+ with Path(filename).open("wb") as fp:
pickle.dump((self.times, self.data), fp, **kwargs)
elif extension == ".csv":
diff --git a/pde/visualization/movies.py b/pde/visualization/movies.py
index e4e24a22..e7638e88 100644
--- a/pde/visualization/movies.py
+++ b/pde/visualization/movies.py
@@ -261,13 +261,13 @@ def movie(
except AttributeError:
fig = ref[0].ax.figure
if show_time:
- title = fig.suptitle("Time %g" % t)
+ title = fig.suptitle(f"Time {t:g}")
else:
# update the data in the figure
field._update_plot(ref)
if show_time:
- title.set_text("Time %g" % t)
+ title.set_text(f"Time {t:g}")
# add the current matplotlib figure to the movie
movie.add_figure(fig)
diff --git a/pde/visualization/plotting.py b/pde/visualization/plotting.py
index 42f09ad1..8bb8aa65 100644
--- a/pde/visualization/plotting.py
+++ b/pde/visualization/plotting.py
@@ -844,7 +844,7 @@ def plot_interactive(
raise RuntimeError("Storage did not contain information about the grid")
# collect data from all time points
- timecourse: dict[str, list[np.ndarray]] = dict()
+ timecourse: dict[str, list[np.ndarray]] = {}
for field in storage:
layer_data = field._get_napari_data(**kwargs)
diff --git a/pyproject.toml b/pyproject.toml
index b10c6964..fa1b4ef8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,6 +64,28 @@ exclude = ["scripts/templates"]
[tool.ruff.format]
docstring-code-format = true
+[tool.ruff.lint]
+select = [
+ "UP", # pyupgrade
+ "I", # isort
+ "A", # flake8-builtins
+ "B", # flake8-bugbear
+ "C4", # flake8-comprehensions
+ "FA", # flake8-future-annotations
+ "ISC", # flake8-implicit-str-concat
+ "ICN", # flake8-import-conventions
+ "LOG", # flake8-logging
+ "G", # flake8-logging-format
+ "PIE", # flake8-pie
+ "PT", # flake8-pytest-style
+ "Q", # flake8-quotes
+ "RSE", # flake8-raise
+ "RET", # flake8-return
+ "SIM", # flake8-simplify
+ "PTH", # flake8-use-pathlib
+]
+ignore = ["B007", "B027", "B028", "SIM108", "ISC001", "PT006", "PT011", "RET504", "RET505", "RET506"]
+
[tool.ruff.lint.isort]
section-order = ["future", "standard-library", "third-party", "first-party", "self", "local-folder"]
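
`SIM108` is listed under `ignore` above; it is the flake8-simplify rule that suggests replacing a small `if`/`else` assignment with a ternary expression, which this configuration chooses not to enforce. A minimal sketch of the two equivalent forms:

    value = 2

    # what SIM108 would suggest
    sign = "positive" if value > 0 else "non-positive"

    # the plain if/else that the `ignore` entry keeps acceptable
    if value > 0:
        sign = "positive"
    else:
        sign = "non-positive"
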
diff --git a/scripts/_templates/pyproject.toml b/scripts/_templates/pyproject.toml
index 87588e6d..2160e064 100644
--- a/scripts/_templates/pyproject.toml
+++ b/scripts/_templates/pyproject.toml
@@ -55,12 +55,34 @@ namespaces = false
write_to = "pde/_version.py"
[tool.ruff]
-target-version = "py39"
+target-version = "py$MIN_PYTHON_VERSION_NODOT"
exclude = ["scripts/templates"]
[tool.ruff.format]
docstring-code-format = true
+[tool.ruff.lint]
+select = [
+ "UP", # pyupgrade
+ "I", # isort
+ "A", # flake8-builtins
+ "B", # flake8-bugbear
+ "C4", # flake8-comprehensions
+ "FA", # flake8-future-annotations
+ "ISC", # flake8-implicit-str-concat
+ "ICN", # flake8-import-conventions
+ "LOG", # flake8-logging
+ "G", # flake8-logging-format
+ "PIE", # flake8-pie
+ "PT", # flake8-pytest-style
+ "Q", # flake8-quotes
+ "RSE", # flake8-raise
+ "RET", # flake8-return
+ "SIM", # flake8-simplify
+ "PTH", # flake8-use-pathlib
+]
+ignore = ["B007", "B027", "B028", "SIM108", "ISC001", "PT006", "PT011", "RET504", "RET505", "RET506"]
+
[tool.ruff.lint.isort]
section-order = ["future", "standard-library", "third-party", "first-party", "self", "local-folder"]
diff --git a/scripts/create_requirements.py b/scripts/create_requirements.py
index 497a89a0..45fef6cc 100755
--- a/scripts/create_requirements.py
+++ b/scripts/create_requirements.py
@@ -200,7 +200,7 @@ def write_requirements_txt(
"""
print(f"Write `{path}`")
path.parent.mkdir(exist_ok=True, parents=True) # ensure path exists
- with open(path, "w") as fp:
+ with path.open("w") as fp:
if comment:
fp.write(f"# {comment}\n")
if ref_base:
@@ -221,7 +221,7 @@ def write_requirements_csv(
requirements (list): The requirements to be written
"""
print(f"Write `{path}`")
- with open(path, "w") as fp:
+ with path.open("w") as fp:
writer = csv.writer(fp)
if incl_version:
writer.writerow(["Package", "Minimal version", "Usage"])
@@ -245,7 +245,7 @@ def write_requirements_py(path: Path, requirements: list[Requirement]):
# read user-created content of file
content = []
- with open(path) as fp:
+ with path.open() as fp:
for line in fp:
if "GENERATED CODE" in line:
content.append(line)
@@ -260,7 +260,7 @@ def write_requirements_py(path: Path, requirements: list[Requirement]):
content.append("del check_package_version\n")
# write content back to file
- with open(path, "w") as fp:
+ with path.open("w") as fp:
fp.writelines(content)
@@ -309,7 +309,7 @@ def write_from_template(
content = template.substitute(substitutes)
# write content to file
- with open(path, "w") as fp:
+ with path.open("w") as fp:
if add_warning:
fp.writelines(SETUP_WARNING.format(template_name))
fp.writelines(content)
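All of these hunks apply the flake8-use-pathlib convention (PTH123, judging by the noqa comments elsewhere in this diff): call .open() on a Path that already exists instead of passing it to the open() builtin. Roughly:

    from pathlib import Path

    path = Path("requirements.txt")  # illustrative file name

    # instead of: with open(path, "w") as fp:
    with path.open("w") as fp:
        fp.write("numpy\n")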
diff --git a/scripts/format_code.sh b/scripts/format_code.sh
index d608f699..8608a544 100755
--- a/scripts/format_code.sh
+++ b/scripts/format_code.sh
@@ -7,7 +7,7 @@ find . -name '*.py' -exec pyupgrade --py39-plus {} +
popd > /dev/null
echo "Formating import statements..."
-ruff check --select I --fix --config=../pyproject.toml ..
+ruff check --fix --config=../pyproject.toml ..
echo "Formating docstrings..."
docformatter --in-place --black --recursive ..
diff --git a/scripts/performance_laplace.py b/scripts/performance_laplace.py
index 6b0759a2..68201176 100755
--- a/scripts/performance_laplace.py
+++ b/scripts/performance_laplace.py
@@ -121,7 +121,7 @@ def laplace(arr, out=None):
if out is None:
out = np.empty((dim_r, dim_z))
- for j in range(0, dim_z): # iterate axial points
+ for j in range(dim_z): # iterate axial points
jm = 0 if j == 0 else j - 1
jp = dim_z - 1 if j == dim_z - 1 else j + 1
diff --git a/scripts/run_tests.py b/scripts/run_tests.py
index c4f65a69..4e49ccfb 100755
--- a/scripts/run_tests.py
+++ b/scripts/run_tests.py
@@ -121,7 +121,7 @@ def run_unit_tests(
coverage: bool = False,
nojit: bool = False,
pattern: str = None,
- pytest_args: list[str] = [],
+ pytest_args: list[str] | None = None,
) -> int:
"""Run the unit tests.
@@ -140,6 +140,8 @@ def run_unit_tests(
Returns:
int: The return code indicating success or failure
"""
+ if pytest_args is None:
+ pytest_args = []
# modify current environment
env = os.environ.copy()
env["PYTHONPATH"] = str(PACKAGE_PATH) + ":" + env.get("PYTHONPATH", "")
@@ -177,8 +179,10 @@ def run_unit_tests(
if use_mpi:
try:
import numba_mpi # @UnusedImport
- except ImportError:
- raise RuntimeError("Moduled `numba_mpi` is required to test with MPI")
+ except ImportError as err:
+ raise RuntimeError(
+ "Moduled `numba_mpi` is required to test with MPI"
+ ) from err
args.append("--use_mpi") # only run tests requiring MPI multiprocessing
# run tests using multiple cores?
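The rewritten except block chains the original exception, following flake8-bugbear's raise-from convention (presumably B904), so the ImportError stays visible in the traceback of the RuntimeError. The pattern in isolation:

    def require_numba_mpi():
        try:
            import numba_mpi  # optional dependency needed for the MPI test run
        except ImportError as err:
            # `from err` keeps the original ImportError attached to the traceback
            raise RuntimeError("Module `numba_mpi` is required to test with MPI") from err
        return numba_mpi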
diff --git a/tests/conftest.py b/tests/conftest.py
index 775122f0..4f93a2b0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,8 +11,8 @@
from pde.tools.numba import random_seed
-@pytest.fixture(scope="function", autouse=True)
-def setup_and_teardown():
+@pytest.fixture(autouse=True)
+def _setup_and_teardown():
"""Helper function adjusting environment before and after tests."""
# ensure we use the Agg backend, so figures are not displayed
plt.switch_backend("agg")
@@ -26,7 +26,7 @@ def setup_and_teardown():
plt.close("all")
-@pytest.fixture(scope="function", autouse=False, name="rng")
+@pytest.fixture(autouse=False, name="rng")
def init_random_number_generators():
"""Get a random number generator and set the seed of the random number generator.
diff --git a/tests/fields/test_field_collections.py b/tests/fields/test_field_collections.py
index 32c72d84..4ceecfa7 100644
--- a/tests/fields/test_field_collections.py
+++ b/tests/fields/test_field_collections.py
@@ -131,7 +131,8 @@ def test_collections_append():
c1 = FieldCollection([sf], labels=["scalar"])
c2 = c1.append(vf)
- assert len(c2) == 2 and len(c1) == 1
+ assert len(c2) == 2
+ assert len(c1) == 1
data = np.r_[np.zeros(4), np.ones(8)]
np.testing.assert_allclose(c2.data.flat, data)
@@ -141,7 +142,8 @@ def test_collections_append():
assert list(c2.labels) == ["scalar", "vector"]
c3 = c1.append(c1, label="new")
- assert len(c3) == 2 and len(c1) == 1
+ assert len(c3) == 2
+ assert len(c1) == 1
np.testing.assert_allclose(c3.data.flat, np.zeros(8))
assert c1.data is not c3.data
@@ -150,7 +152,8 @@ def test_collections_append():
assert c3.label == "new"
c4 = c1.append(c1, vf)
- assert len(c4) == 3 and len(c1) == 1
+ assert len(c4) == 3
+ assert len(c1) == 1
data = np.r_[np.zeros(8), np.ones(8)]
np.testing.assert_allclose(c4.data.flat, data)
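Splitting `assert a and b` into two separate assertions (presumably flake8-pytest-style's composite-assertion rule, PT018) makes a failure report show exactly which condition broke; the same rewrite recurs throughout the test files below. In miniature:

    c1, c2 = [0.0], [0.0, 1.0]

    # instead of: assert len(c2) == 2 and len(c1) == 1
    assert len(c2) == 2
    assert len(c1) == 1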
diff --git a/tests/fields/test_generic_fields.py b/tests/fields/test_generic_fields.py
index f3858992..9c395f7a 100644
--- a/tests/fields/test_generic_fields.py
+++ b/tests/fields/test_generic_fields.py
@@ -510,7 +510,7 @@ def test_smoothing(rng):
np.testing.assert_allclose(out.data, expected)
out.data = 0 # reset data
- f1.smooth(sigma, out=out).data
+ print(f1.smooth(sigma, out=out).data)
np.testing.assert_allclose(out.data, expected)
# test one simple higher order smoothing
diff --git a/tests/fields/test_tensorial_fields.py b/tests/fields/test_tensorial_fields.py
index 238ae396..fdc75552 100644
--- a/tests/fields/test_tensorial_fields.py
+++ b/tests/fields/test_tensorial_fields.py
@@ -188,7 +188,8 @@ def test_complex_tensors(backend, rng):
numbers = rng.random(shape) + rng.random(shape) * 1j
t1 = Tensor2Field(grid, numbers[0])
t2 = Tensor2Field(grid, numbers[1])
- assert t1.is_complex and t2.is_complex
+ assert t1.is_complex
+ assert t2.is_complex
dot_op = t1.make_dot_operator(backend)
diff --git a/tests/fields/test_vectorial_fields.py b/tests/fields/test_vectorial_fields.py
index 8dfd0ac0..1bbfc532 100644
--- a/tests/fields/test_vectorial_fields.py
+++ b/tests/fields/test_vectorial_fields.py
@@ -260,7 +260,8 @@ def test_complex_vectors(rng):
numbers = rng.random(shape) + rng.random(shape) * 1j
v1 = VectorField(grid, numbers[0])
v2 = VectorField(grid, numbers[1])
- assert v1.is_complex and v2.is_complex
+ assert v1.is_complex
+ assert v2.is_complex
for backend in ["numpy", "numba"]:
dot_op = v1.make_dot_operator(backend)
diff --git a/tests/grids/boundaries/test_axes_boundaries.py b/tests/grids/boundaries/test_axes_boundaries.py
index 19ec8693..2a889488 100644
--- a/tests/grids/boundaries/test_axes_boundaries.py
+++ b/tests/grids/boundaries/test_axes_boundaries.py
@@ -121,7 +121,8 @@ def test_bc_values():
"""Test setting the values of boundary conditions."""
g = UnitGrid([5])
bc = g.get_boundary_conditions([{"value": 2}, {"derivative": 3}])
- assert bc[0].low.value == 2 and bc[0].high.value == 3
+ assert bc[0].low.value == 2
+ assert bc[0].high.value == 3
@pytest.mark.parametrize("dim", [1, 2, 3])
diff --git a/tests/grids/boundaries/test_axis_boundaries.py b/tests/grids/boundaries/test_axis_boundaries.py
index 4ead63da..2e18cefc 100644
--- a/tests/grids/boundaries/test_axis_boundaries.py
+++ b/tests/grids/boundaries/test_axis_boundaries.py
@@ -36,9 +36,11 @@ def test_boundary_pair():
data = {"low": {"value": 1}, "high": {"derivative": 2}}
bc1 = BoundaryPair.from_data(g, 0, data)
bc2 = BoundaryPair.from_data(g, 0, data)
- assert bc1 == bc2 and bc1 is not bc2
+ assert bc1 == bc2
+ assert bc1 is not bc2
bc2 = BoundaryPair.from_data(g, 1, data)
- assert bc1 != bc2 and bc1 is not bc2
+ assert bc1 != bc2
+ assert bc1 is not bc2
# miscellaneous methods
data = {"low": {"value": 0}, "high": {"derivative": 0}}
@@ -59,7 +61,8 @@ def test_get_axis_boundaries():
b = get_boundary_axis(g, 0, data)
assert str(b) == '"' + data + '"'
b1, b2 = b.get_mathematical_representation("field")
- assert "field" in b1 and "field" in b2
+ assert "field" in b1
+ assert "field" in b2
if "periodic" in data:
assert b.periodic
diff --git a/tests/grids/boundaries/test_local_boundaries.py b/tests/grids/boundaries/test_local_boundaries.py
index 315d838b..9cddd945 100644
--- a/tests/grids/boundaries/test_local_boundaries.py
+++ b/tests/grids/boundaries/test_local_boundaries.py
@@ -343,7 +343,7 @@ def func(adjacent_value, dx, x, y, t):
np.testing.assert_almost_equal(f_ref._data_full, f2._data_full)
-@pytest.mark.parametrize("value_expr, const_expr", [["1", "1"], ["x", "y**2"]])
+@pytest.mark.parametrize("value_expr, const_expr", [("1", "1"), ("x", "y**2")])
def test_expression_bc_setting_mixed(value_expr, const_expr, rng):
"""Test boundary conditions that use an expression."""
grid = CartesianGrid([[0, 1], [0, 1]], 4)
@@ -529,12 +529,12 @@ def test_expression_bc_specific_value(dim, compiled):
if compiled:
def set_bcs():
- bcs.make_ghost_cell_setter()(field._data_full)
+ bcs.make_ghost_cell_setter()(field._data_full) # noqa: B023
else:
def set_bcs():
- field.set_ghost_cells(bcs)
+ field.set_ghost_cells(bcs) # noqa: B023
if i < -n or i > n - 1:
        # check out-of-bounds errors
diff --git a/tests/grids/operators/test_cylindrical_operators.py b/tests/grids/operators/test_cylindrical_operators.py
index 28364671..7ad65b65 100644
--- a/tests/grids/operators/test_cylindrical_operators.py
+++ b/tests/grids/operators/test_cylindrical_operators.py
@@ -235,7 +235,7 @@ def test_examples_tensor_cyl():
np.testing.assert_allclose(res.data, expect.data, rtol=0.1, atol=0.1)
-@pytest.mark.parametrize("r_inner", (0, 1))
+@pytest.mark.parametrize("r_inner", [0, 1])
def test_laplace_matrix(r_inner, rng):
"""Test laplace operator implemented using matrix multiplication."""
grid = CylindricalSymGrid((r_inner, 2), (2.5, 4.3), 16)
@@ -253,7 +253,7 @@ def test_laplace_matrix(r_inner, rng):
np.testing.assert_allclose(res1.data, res2)
-@pytest.mark.parametrize("r_inner", (0, 1))
+@pytest.mark.parametrize("r_inner", [0, 1])
def test_poisson_solver_cylindrical(r_inner, rng):
"""Test the poisson solver on Cylindrical grids."""
grid = CylindricalSymGrid((r_inner, 2), (2.5, 4.3), 16)
diff --git a/tests/grids/operators/test_polar_operators.py b/tests/grids/operators/test_polar_operators.py
index 48ecf1a6..cda01c90 100644
--- a/tests/grids/operators/test_polar_operators.py
+++ b/tests/grids/operators/test_polar_operators.py
@@ -95,7 +95,7 @@ def test_grid_laplace_polar():
np.testing.assert_allclose(b_1d_2.data[i, i], b_2d.data[i, i], rtol=0.2, atol=0.2)
-@pytest.mark.parametrize("r_inner", (0, 2 * np.pi))
+@pytest.mark.parametrize("r_inner", [0, 2 * np.pi])
def test_gradient_squared_polar(r_inner):
"""Compare gradient squared operator."""
grid = PolarSymGrid((r_inner, 4 * np.pi), 32)
@@ -194,7 +194,7 @@ def test_examples_tensor_polar():
np.testing.assert_allclose(res.data, expect.data, rtol=0.1, atol=0.1)
-@pytest.mark.parametrize("r_inner", (0, 1))
+@pytest.mark.parametrize("r_inner", [0, 1])
def test_laplace_matrix(r_inner, rng):
"""Test laplace operator implemented using matrix multiplication."""
grid = PolarSymGrid((r_inner, 2), 16)
diff --git a/tests/grids/operators/test_spherical_operators.py b/tests/grids/operators/test_spherical_operators.py
index 512d3029..d3ebc625 100644
--- a/tests/grids/operators/test_spherical_operators.py
+++ b/tests/grids/operators/test_spherical_operators.py
@@ -115,7 +115,7 @@ def test_grid_laplace():
)
-@pytest.mark.parametrize("r_inner", (0, 1))
+@pytest.mark.parametrize("r_inner", [0, 1])
def test_gradient_squared(r_inner, rng):
"""Compare gradient squared operator."""
grid = SphericalSymGrid((r_inner, 5), 64)
@@ -288,7 +288,7 @@ def test_tensor_div_div(conservative):
np.testing.assert_allclose(res.data[2:-2], est.data[2:-2], rtol=0.02, atol=1)
-@pytest.mark.parametrize("r_inner", (0, 1))
+@pytest.mark.parametrize("r_inner", [0, 1])
def test_laplace_matrix(r_inner, rng):
"""Test laplace operator implemented using matrix multiplication."""
grid = SphericalSymGrid((r_inner, 2), 16)
diff --git a/tests/pdes/test_pde_class.py b/tests/pdes/test_pde_class.py
index cf5b050b..4796d26a 100644
--- a/tests/pdes/test_pde_class.py
+++ b/tests/pdes/test_pde_class.py
@@ -3,7 +3,8 @@
"""
import logging
-
+import os
+import numba as nb
import numpy as np
import pytest
import sympy
@@ -188,7 +189,6 @@ def test_pde_noise(backend, rng):
with pytest.raises(ValueError):
eq = PDE({"a": 0}, noise=[0.01, 2.0])
- eq.solve(ScalarField(grid), t_range=1, backend=backend, dt=1, tracker=None)
@pytest.mark.parametrize("backend", ["numpy", "numba"])
@@ -207,9 +207,12 @@ def test_pde_spatial_args(backend):
# test invalid spatial dependence
eq = PDE({"a": "x + y"})
- with pytest.raises(RuntimeError):
- rhs = eq.make_pde_rhs(field, backend=backend)
- rhs(field.data, 0.0)
+ if backend == "numpy":
+ with pytest.raises(RuntimeError):
+ eq.evolution_rate(field)
+ elif backend == "numba":
+ with pytest.raises(RuntimeError):
+ rhs = eq.make_pde_rhs(field, backend=backend)
def test_pde_user_funcs(rng):
@@ -284,11 +287,15 @@ def test_pde_consts():
np.testing.assert_allclose(eq.evolution_rate(field).data, 0)
eq = PDE({"a": "laplace(b)"}, consts={"b": 3})
- with pytest.raises(Exception):
+ with pytest.raises(
+ AttributeError
+ if os.environ.get("NUMBA_DISABLE_JIT", "0") == "1"
+ else nb.TypingError
+ ):
eq.evolution_rate(field)
eq = PDE({"a": "laplace(b)"}, consts={"b": field.data})
- with pytest.raises(Exception):
+ with pytest.raises(TypeError):
eq.evolution_rate(field)
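Replacing the broad `pytest.raises(Exception)` with concrete exception types follows flake8-bugbear's blind-exception check (presumably B017; the related PT011 check is explicitly ignored in the configuration above). A small sketch with a hypothetical function:

    import pytest

    def parse_positive(text: str) -> int:
        value = int(text)  # raises ValueError for non-numeric input
        if value <= 0:
            raise ValueError("expected a positive number")
        return value

    def test_parse_positive_rejects_garbage():
        with pytest.raises(ValueError):  # the expected type is named explicitly
            parse_positive("not a number")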
diff --git a/tests/solvers/test_explicit_solvers.py b/tests/solvers/test_explicit_solvers.py
index 4b9da38f..822b1f8e 100644
--- a/tests/solvers/test_explicit_solvers.py
+++ b/tests/solvers/test_explicit_solvers.py
@@ -122,9 +122,9 @@ def test_stochastic_adaptive_solver(caplog, rng):
field = ScalarField.random_uniform(UnitGrid([16]), -1, 1, rng=rng)
eq = DiffusionPDE(noise=1e-6)
+ solver = ExplicitSolver(eq, backend="numpy", adaptive=True)
+ c = Controller(solver, t_range=1, tracker=None)
with pytest.raises(RuntimeError):
- solver = ExplicitSolver(eq, backend="numpy", adaptive=True)
- c = Controller(solver, t_range=1, tracker=None)
c.run(field, dt=1e-2)
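Hoisting the solver and controller construction out of the `with pytest.raises(...)` block means the test only passes when `c.run()` itself raises, not when setup happens to fail. The same idea with hypothetical helpers:

    import pytest

    def make_controller():
        return object()  # hypothetical setup step that must not raise

    def run_simulation(controller):
        raise RuntimeError("adaptive stepping is not supported")  # hypothetical failing call

    def test_run_raises():
        controller = make_controller()  # construction stays outside the raises block
        with pytest.raises(RuntimeError):
            run_simulation(controller)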
diff --git a/tests/storage/test_movie_storages.py b/tests/storage/test_movie_storages.py
index 755271c9..0ea68901 100644
--- a/tests/storage/test_movie_storages.py
+++ b/tests/storage/test_movie_storages.py
@@ -177,13 +177,15 @@ def test_complex_data(tmp_path, rng):
@pytest.mark.skipif(not module_available("ffmpeg"), reason="requires `ffmpeg-python`")
def test_wrong_format():
"""Test how wrong files are dealt with."""
+ from ffmpeg._run import Error as FFmpegError
+
reader = MovieStorage(RESOURCES_PATH / "does_not_exist.avi")
with pytest.raises(OSError):
- reader.times
+ print(reader.times)
reader = MovieStorage(RESOURCES_PATH / "empty.avi")
- with pytest.raises(Exception):
- reader.times
+ with pytest.raises(FFmpegError):
+ print(reader.times)
reader = MovieStorage(RESOURCES_PATH / "no_metadata.avi")
np.testing.assert_allclose(reader.times, [0, 1])
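Wrapping the bare `reader.times` access in `print()` consumes the value, which I take to be a way of satisfying the useless-expression check (flake8-bugbear B018) while still triggering the property inside `pytest.raises`. The pattern in isolation:

    import pytest

    class Reader:
        @property
        def times(self):
            raise OSError("cannot read movie file")  # hypothetical failure mode

    def test_times_raises():
        reader = Reader()
        with pytest.raises(OSError):
            print(reader.times)  # consume the value so the statement is not "useless"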
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 630fe9f0..8d9cbab0 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -53,14 +53,14 @@ def test_example_scripts(path):
# delete files that might be created by the test
try:
- os.remove(PACKAGE_PATH / "diffusion.mov")
- os.remove(PACKAGE_PATH / "allen_cahn.avi")
- os.remove(PACKAGE_PATH / "allen_cahn.hdf")
+ (PACKAGE_PATH / "diffusion.mov").unlink()
+ (PACKAGE_PATH / "allen_cahn.avi").unlink()
+ (PACKAGE_PATH / "allen_cahn.hdf").unlink()
except OSError:
pass
# prepare output
- msg = "Script `%s` failed with following output:" % path
+ msg = f"Script `{path}` failed with following output:"
if outs:
msg = f"{msg}\nSTDOUT:\n{outs}"
if errs:
@@ -86,7 +86,7 @@ def test_jupyter_notebooks(path, tmp_path):
my_env = os.environ.copy()
my_env["PYTHONPATH"] = str(PACKAGE_PATH) + ":" + my_env.get("PYTHONPATH", "")
- outfile = tmp_path / os.path.basename(path)
+ outfile = tmp_path / path.name
if jupyter_notebook.__version__.startswith("6"):
        # older version of running jupyter notebook
# deprecated on 2023-07-31
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 51f751c5..f538d9ce 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -254,7 +254,8 @@ def test_modelrunner_storage_one(tmp_path, capsys):
assert output.is_file()
print("=" * 40)
- print(open(output).read())
+ with Path(output).open() as fp:
+ print(fp.read())
print("=" * 40)
# read storage manually
@@ -298,9 +299,9 @@ def test_modelrunner_storage_many(tmp_path):
for path in tmp_path.iterdir():
if path.is_file() and not path.name.endswith("txt"):
with mr.open_storage(path) as storage:
- assert "initial_state" in storage["storage"].keys()
- assert "trajectory" in storage["storage"].keys()
- assert "result" in storage.keys()
+ assert "initial_state" in storage["storage"]
+ assert "trajectory" in storage["storage"]
+ assert "result" in storage
# read result using ResultCollection
results = mr.ResultCollection.from_folder(tmp_path)
diff --git a/tests/tools/test_cache.py b/tests/tools/test_cache.py
index b94605fe..4e961651 100644
--- a/tests/tools/test_cache.py
+++ b/tests/tools/test_cache.py
@@ -365,7 +365,7 @@ def cached_kwarg(self, a=0, b=0):
assert method(1) == 1
assert obj.counter == 3
else:
- raise ValueError("Unknown cache_factory `%s`" % cache_factory)
+ raise ValueError(f"Unknown cache_factory `{cache_factory}`")
obj.counter = 0
# clear cache to test the second run
diff --git a/tests/tools/test_config.py b/tests/tools/test_config.py
index ac9c01fe..9cd7a94e 100644
--- a/tests/tools/test_config.py
+++ b/tests/tools/test_config.py
@@ -19,8 +19,8 @@ def test_config():
assert c["numba.multithreading_threshold"] > 0
assert "numba.multithreading_threshold" in c
- assert any("numba.multithreading_threshold" == k for k in c)
- assert any("numba.multithreading_threshold" == k and v > 0 for k, v in c.items())
+ assert any(k == "numba.multithreading_threshold" for k in c)
+ assert any(k == "numba.multithreading_threshold" and v > 0 for k, v in c.items())
assert "numba.multithreading_threshold" in c.to_dict()
assert isinstance(repr(c), str)
@@ -90,4 +90,5 @@ def test_packages_from_requirements():
"""Test the packages_from_requirements function."""
results = packages_from_requirements("file_not_existing")
assert len(results) == 1
- assert "Could not open" in results[0] and "file_not_existing" in results[0]
+ assert "Could not open" in results[0]
+ assert "file_not_existing" in results[0]
diff --git a/tests/tools/test_cuboid.py b/tests/tools/test_cuboid.py
index b7b4f004..19357c59 100644
--- a/tests/tools/test_cuboid.py
+++ b/tests/tools/test_cuboid.py
@@ -58,8 +58,8 @@ def test_cuboid_2d():
np.testing.assert_array_equal(c.contains_point([[1, 3], [3, 1]]), [False, False])
np.testing.assert_array_equal(c.contains_point([[1, -1], [-1, 1]]), [False, False])
+ c.mutable = False
with pytest.raises(ValueError):
- c.mutable = False
c.centroid = [0, 0]
# test surface area
diff --git a/tests/tools/test_expressions.py b/tests/tools/test_expressions.py
index ba583eb3..7d95d0b2 100644
--- a/tests/tools/test_expressions.py
+++ b/tests/tools/test_expressions.py
@@ -60,7 +60,7 @@ def test_const(expr):
assert e.get_compiled()() == val
assert not e.depends_on("a")
assert e.differentiate("a").value == 0
- assert e.shape == tuple()
+ assert e.shape == ()
assert e.rank == 0
assert bool(e) == (val != 0)
assert e.is_zero == (val == 0)
@@ -89,7 +89,7 @@ def test_wrong_const(caplog):
assert e() == field
assert "field" in caplog.text
if not nb.config.DISABLE_JIT: # @UndefinedVariable
- with pytest.raises(Exception):
+ with pytest.raises(nb.TypingError):
e.get_compiled()()
@@ -102,14 +102,14 @@ def test_single_arg(rng):
assert e.get_compiled()(4) == 8
assert e.differentiate("a").value == 2
assert e.differentiate("b").value == 0
- assert e.shape == tuple()
+ assert e.shape == ()
assert e.rank == 0
assert bool(e)
assert not e.is_zero
assert e == ScalarExpression(e.expression)
with pytest.raises(TypeError):
- e.value
+ print(e.value)
arr = rng.random(5)
np.testing.assert_allclose(e(arr), 2 * arr)
@@ -135,7 +135,7 @@ def test_two_args(rng):
assert e.differentiate("a")(4, 2) == 16
assert e.differentiate("b")(4, 2) == pytest.approx(32 * np.log(4))
assert e.differentiate("c").value == 0
- assert e.shape == tuple()
+ assert e.shape == ()
assert e.rank == 0
assert e == ScalarExpression(e.expression)
@@ -160,7 +160,8 @@ def test_two_args(rng):
def test_derivatives():
"""Test vector expressions."""
e = ScalarExpression("a * b**2")
- assert e.depends_on("a") and e.depends_on("b")
+ assert e.depends_on("a")
+ assert e.depends_on("b")
assert not e.constant
assert e.rank == 0
@@ -199,7 +200,7 @@ def test_indexed():
with pytest.raises(RuntimeError):
e.differentiate("a")
with pytest.raises(RuntimeError):
- e.derivatives
+ print(e.derivatives)
def test_synonyms(caplog):
@@ -217,7 +218,7 @@ def test_tensor_expression():
assert e.rank == 2
assert e.constant
np.testing.assert_allclose(e.get_compiled_array()(), [[0, 1], [2, 3]])
- np.testing.assert_allclose(e.get_compiled_array()(tuple()), [[0, 1], [2, 3]])
+ np.testing.assert_allclose(e.get_compiled_array()(()), [[0, 1], [2, 3]])
assert e.differentiate("a") == TensorExpression("[[0, 0], [0, 0]]")
np.testing.assert_allclose(e.value, np.arange(4).reshape(2, 2))
@@ -229,7 +230,7 @@ def test_tensor_expression():
assert not e.constant
np.testing.assert_allclose(e.differentiate("a").value, np.array([1, 2]))
with pytest.raises(TypeError):
- e.value
+ print(e.value)
assert e[0] == ScalarExpression("a")
assert e[1] == ScalarExpression("2*a")
assert e[0:1] == TensorExpression("[a]")
@@ -331,7 +332,8 @@ def test_expression_consts():
expr = ScalarExpression("a + b", consts={"a": 1})
assert not expr.constant
- assert not expr.depends_on("a") and expr.depends_on("b")
+ assert not expr.depends_on("a")
+ assert expr.depends_on("b")
assert expr(2) == 3
assert expr.get_compiled()(2) == 3
@@ -433,8 +435,8 @@ def test_evaluate_func_collection():
assert isinstance(evaluate("gradient(a)", col), VectorField)
assert isinstance(evaluate("a * v", col), VectorField)
+ col.labels = ["a", "a"]
with pytest.raises(RuntimeError):
- col.labels = ["a", "a"]
evaluate("1", col)
diff --git a/tests/tools/test_misc.py b/tests/tools/test_misc.py
index 2f8fefd6..be655174 100644
--- a/tests/tools/test_misc.py
+++ b/tests/tools/test_misc.py
@@ -4,6 +4,7 @@
import json
import os
+from pathlib import Path
import numpy as np
import pytest
@@ -23,7 +24,7 @@ def test_ensure_directory_exists(tmp_path):
misc.ensure_directory_exists(path)
assert path.is_dir()
# remove the folder again
- os.rmdir(path)
+ Path.rmdir(path)
assert not path.exists()
@@ -105,8 +106,7 @@ def test_hdf_write_attributes(tmp_path):
assert data2 == {"a": 1}
# test raising problematic items
- with h5py.File(path, "w") as hdf_file:
- with pytest.raises(TypeError):
- misc.hdf_write_attributes(
- hdf_file, {"a": object()}, raise_serialization_error=True
- )
+ with h5py.File(path, "w") as hdf_file, pytest.raises(TypeError):
+ misc.hdf_write_attributes(
+ hdf_file, {"a": object()}, raise_serialization_error=True
+ )
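Merging the nested `with` blocks into a single statement matches flake8-simplify's multiple-with rule (presumably SIM117); a self-contained example of the combined form:

    import tempfile

    import pytest

    def test_write_requires_str():
        # one `with` statement holding both context managers instead of two nested blocks
        with tempfile.TemporaryFile("w") as fp, pytest.raises(TypeError):
            fp.write(object())  # writing a non-string raises TypeError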
diff --git a/tests/tools/test_parameters.py b/tests/tools/test_parameters.py
index 9647cf12..837cca76 100644
--- a/tests/tools/test_parameters.py
+++ b/tests/tools/test_parameters.py
@@ -108,7 +108,10 @@ class TestHelp2(TestHelp1):
t = TestHelp2()
for in_jupyter in [False, True]:
- monkeypatch.setattr("pde.tools.output.in_jupyter_notebook", lambda: in_jupyter)
+ monkeypatch.setattr(
+ "pde.tools.output.in_jupyter_notebook",
+ lambda: in_jupyter, # noqa: B023
+ )
for flags in itertools.combinations_with_replacement([True, False], 3):
TestHelp2.show_parameters(*flags)
diff --git a/tests/tools/test_spectral.py b/tests/tools/test_spectral.py
index a96f07dd..b964c91e 100644
--- a/tests/tools/test_spectral.py
+++ b/tests/tools/test_spectral.py
@@ -67,18 +67,17 @@ def test_noise_scaling(rng):
def noise_div():
return div(rng.normal(size=shape))
+ def get_noise(noise_func):
+ k, density = spectral_density(data=noise_func(), dx=grid.discretization[0])
+ assert k[0] == 0
+ assert density[0] == pytest.approx(0)
+ return np.log(density[1]) # log of spectral density
+
# calculate spectral densities of the two noises
result = []
for noise_func in [noise_colored, noise_div]:
-
- def get_noise():
- k, density = spectral_density(data=noise_func(), dx=grid.discretization[0])
- assert k[0] == 0
- assert density[0] == pytest.approx(0)
- return np.log(density[1]) # log of spectral density
-
# average spectral density of longest length scale
- mean = np.mean([get_noise() for _ in range(64)])
+ mean = np.mean([get_noise(noise_func=noise_func) for _ in range(64)])
result.append(mean)
np.testing.assert_allclose(*result, rtol=0.5)
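Defining `get_noise` once, outside the loop, and passing `noise_func` in explicitly avoids the late-binding closure problem that flake8-bugbear flags as B023 (the same code appears in the noqa comments above). The same fix in isolation:

    def evaluate(func, value):
        # the loop variable is passed in explicitly instead of being captured by a closure
        return func(value)

    def apply_all(funcs, value):
        return [evaluate(func, value) for func in funcs]

    print(apply_all([abs, float], -3))  # -> [3, -3.0]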
diff --git a/tests/trackers/test_trackers.py b/tests/trackers/test_trackers.py
index de4a7920..aec51478 100644
--- a/tests/trackers/test_trackers.py
+++ b/tests/trackers/test_trackers.py
@@ -76,23 +76,21 @@ def store_time(state, t):
def get_sparse_matrix_data(state):
return {"integral": state.integral}
- devnull = open(os.devnull, "w")
- data = trackers.DataTracker(get_sparse_matrix_data, interrupts=0.1)
- tracker_list = [
- trackers.PrintTracker(interrupts=0.1, stream=devnull),
- trackers.CallbackTracker(store_time, interrupts=0.1),
- None, # should be ignored
- data,
- ]
- if module_available("matplotlib"):
- tracker_list.append(trackers.PlotTracker(interrupts=0.1, show=False))
-
- grid = UnitGrid([16, 16])
- state = ScalarField.random_uniform(grid, 0.2, 0.3, rng=rng)
- eq = DiffusionPDE()
- eq.solve(state, t_range=1, dt=0.005, tracker=tracker_list)
-
- devnull.close()
+ with open(os.devnull, "w") as devnull: # noqa: PTH123
+ data = trackers.DataTracker(get_sparse_matrix_data, interrupts=0.1)
+ tracker_list = [
+ trackers.PrintTracker(interrupts=0.1, stream=devnull),
+ trackers.CallbackTracker(store_time, interrupts=0.1),
+ None, # should be ignored
+ data,
+ ]
+ if module_available("matplotlib"):
+ tracker_list.append(trackers.PlotTracker(interrupts=0.1, show=False))
+
+ grid = UnitGrid([16, 16])
+ state = ScalarField.random_uniform(grid, 0.2, 0.3, rng=rng)
+ eq = DiffusionPDE()
+ eq.solve(state, t_range=1, dt=0.005, tracker=tracker_list)
assert times == data.times
if module_available("pandas"):