diff --git a/docs/source/conf.py b/docs/source/conf.py index a17e4567..cbe98263 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,9 +15,9 @@ import sys sys.path.insert(0, ".") -sys.path.insert(0, os.path.abspath("../..")) -sys.path.insert(0, os.path.abspath("../../scripts")) -sys.path.insert(0, os.path.abspath("../sphinx_ext/")) +sys.path.insert(0, os.path.abspath("../..")) # noqa: PTH100 +sys.path.insert(0, os.path.abspath("../../scripts")) # noqa: PTH100 +sys.path.insert(0, os.path.abspath("../sphinx_ext/")) # noqa: PTH100 from datetime import date @@ -26,7 +26,7 @@ project = "py-pde" module_name = "pde" author = "Zwicker Group" -copyright = f"{date.today().year}, {author}" # @ReservedAssignment +copyright = f"{date.today().year}, {author}" # @ReservedAssignment # noqa: A001 html_logo = "_images/logo_small.png" # Determine the version from the actual package diff --git a/docs/source/manual/contributing.rst b/docs/source/manual/contributing.rst index 28a6ad73..c1f182cd 100644 --- a/docs/source/manual/contributing.rst +++ b/docs/source/manual/contributing.rst @@ -55,8 +55,9 @@ where ``vector_component`` is either 0 or 1. Coding style """""""""""" -The coding style is enforced using `isort `_ -and `black `_. Moreover, we use `Google Style docstrings +The coding style is enforced using `ruff `_, based on the +styles suggest by `isort `_ and +`black `_. Moreover, we use `Google Style docstrings `_, which might be best `learned by example `_. diff --git a/docs/source/run_autodoc.py b/docs/source/run_autodoc.py index ae7fc760..c810af80 100755 --- a/docs/source/run_autodoc.py +++ b/docs/source/run_autodoc.py @@ -4,6 +4,7 @@ import logging import os import subprocess as sp +from pathlib import Path logging.basicConfig(level=logging.INFO) @@ -31,21 +32,21 @@ def replace_in_file(infile, replacements, outfile=None): if outfile is None: outfile = infile - with open(infile) as fp: + with Path(infile).open() as fp: content = fp.read() for key, value in replacements.items(): content = content.replace(key, value) - with open(outfile, "w") as fp: + with Path(outfile).open("w") as fp: fp.write(content) def main(): # remove old files - for path in glob.glob(f"{OUTPUT_PATH}/*.rst"): + for path in Path(OUTPUT_PATH).glob("*.rst"): logging.info("Remove file `%s`", path) - os.remove(path) + path.unlink() # run sphinx-apidoc sp.check_call( @@ -65,7 +66,7 @@ def main(): ) # replace unwanted information - for path in glob.glob(f"{OUTPUT_PATH}/*.rst"): + for path in Path(OUTPUT_PATH).glob("*.rst"): logging.info("Patch file `%s`", path) replace_in_file(path, REPLACEMENTS) diff --git a/examples/pde_brusselator_class.py b/examples/pde_brusselator_class.py index 47ffce9a..c0345500 100644 --- a/examples/pde_brusselator_class.py +++ b/examples/pde_brusselator_class.py @@ -27,11 +27,11 @@ class BrusselatorPDE(PDEBase): """Brusselator with diffusive mobility.""" - def __init__(self, a=1, b=3, diffusivity=[1, 0.1], bc="auto_periodic_neumann"): + def __init__(self, a=1, b=3, diffusivity=None, bc="auto_periodic_neumann"): super().__init__() self.a = a self.b = b - self.diffusivity = diffusivity # spatial mobility + self.diffusivity = [1, 0.1] if diffusivity is None else diffusivity self.bc = bc # boundary condition def get_initial_state(self, grid): diff --git a/pde/__init__.py b/pde/__init__.py index 1f0621bb..bfe5c750 100644 --- a/pde/__init__.py +++ b/pde/__init__.py @@ -4,7 +4,7 @@ # determine the package version try: # try reading version of the automatically generated module - from ._version import __version__ # type: ignore + from ._version import __version__ except ImportError: # determine version automatically from CVS information from importlib.metadata import PackageNotFoundError, version @@ -20,7 +20,8 @@ from .tools.config import Config, environment config = Config() # initialize the default configuration -del Config # clean name space + +import contextlib # import all other modules that should occupy the main name space from .fields import * # @UnusedWildImport @@ -32,7 +33,7 @@ from .trackers import * # @UnusedWildImport from .visualization import * # @UnusedWildImport -try: +with contextlib.suppress(ImportError): from .tools.modelrunner import * -except ImportError: - pass # modelrunner extensions are simply not loaded + +del contextlib, Config # clean name space diff --git a/pde/fields/datafield_base.py b/pde/fields/datafield_base.py index e48c0778..961689df 100644 --- a/pde/fields/datafield_base.py +++ b/pde/fields/datafield_base.py @@ -1131,7 +1131,7 @@ def get_vector_data(self, transpose: bool = False, **kwargs) -> dict[str, Any]: Returns: dict: Information useful for plotting an vector field """ - raise NotImplementedError() + raise NotImplementedError def _plot_line( self, diff --git a/pde/fields/scalar.py b/pde/fields/scalar.py index 7f8708d6..215582b1 100644 --- a/pde/fields/scalar.py +++ b/pde/fields/scalar.py @@ -372,7 +372,7 @@ def slice( raise ValueError( f"The axes {ax} is not contained in " f"{self.grid} with axes {self.grid.axes}" - ) + ) from None ax_remove.append(i) # check the position diff --git a/pde/fields/tensorial.py b/pde/fields/tensorial.py index 62e7aa79..9118c892 100644 --- a/pde/fields/tensorial.py +++ b/pde/fields/tensorial.py @@ -124,8 +124,8 @@ def _get_axes_index(self, key: tuple[int | str, int | str]) -> tuple[int, int]: try: if len(key) != 2: raise IndexError("Index must be given as two integers") - except TypeError: - raise IndexError("Index must be given as two values") + except TypeError as err: + raise IndexError("Index must be given as two values") from err return tuple(self.grid.get_axis_index(k) for k in key) # type: ignore def __getitem__(self, key: tuple[int | str, int | str]) -> ScalarField: diff --git a/pde/fields/vectorial.py b/pde/fields/vectorial.py index 272ae2bf..e6c4c16f 100644 --- a/pde/fields/vectorial.py +++ b/pde/fields/vectorial.py @@ -317,8 +317,8 @@ def check_rank(arr: nb.types.Type | nb.types.Optional) -> None: @register_jitable def calc(a: np.ndarray, b: np.ndarray, out: np.ndarray) -> np.ndarray: """Calculate outer product between fields `a` and `b`""" - for i in range(0, dim): - for j in range(0, dim): + for i in range(dim): + for j in range(dim): out[i, j, :] = a[i] * b[j] return out diff --git a/pde/grids/base.py b/pde/grids/base.py index f366bef0..f17a33fd 100644 --- a/pde/grids/base.py +++ b/pde/grids/base.py @@ -1021,7 +1021,7 @@ def contains_point( the grid """ cell_coords = self.transform(points, source=coords, target="cell", full=full) - return np.all((0 <= cell_coords) & (cell_coords <= self.shape), axis=-1) # type: ignore + return np.all((cell_coords >= 0) & (cell_coords <= self.shape), axis=-1) # type: ignore def iter_mirror_points( self, point: np.ndarray, with_self: bool = False, only_periodic: bool = True @@ -1222,6 +1222,7 @@ def register_operator(factor_func_arg: OperatorFactory): else: # method is used directly register_operator(factory_func) + return None @hybridmethod # type: ignore @property diff --git a/pde/grids/boundaries/axes.py b/pde/grids/boundaries/axes.py index cd4e743d..a323d0e3 100644 --- a/pde/grids/boundaries/axes.py +++ b/pde/grids/boundaries/axes.py @@ -264,7 +264,7 @@ def __setitem__(self, index, data) -> None: else: # handle all other cases, in particular integer indices - return super().__setitem__(index, data) + super().__setitem__(index, data) def get_mathematical_representation(self, field_name: str = "C") -> str: """Return mathematical representation of the boundary condition.""" @@ -303,7 +303,7 @@ def set_ghost_cells( for i, j in itertools.product([0, -1], [0, -1]): d[..., i, j] = (d[..., nxt[i], j] + d[..., i, nxt[j]]) / 2 - elif self.grid.num_axes >= 3: + elif self.grid.num_axes == 3: # iterate all edges for i, j in itertools.product([0, -1], [0, -1]): d[..., :, i, j] = (+d[..., :, nxt[i], j] + d[..., :, i, nxt[j]]) / 2 @@ -316,8 +316,9 @@ def set_ghost_cells( + d[..., i, nxt[j], k] + d[..., i, j, nxt[k]] ) / 3 - else: - logging.getLogger(self.__class__.__name__).warning( + + elif self.grid.num_axes > 3: + raise NotImplementedError( f"Can't interpolate corners for grid with {self.grid.num_axes} axes" ) diff --git a/pde/grids/boundaries/axis.py b/pde/grids/boundaries/axis.py index 5d928209..61805c45 100644 --- a/pde/grids/boundaries/axis.py +++ b/pde/grids/boundaries/axis.py @@ -287,7 +287,7 @@ def from_data(cls, grid: GridBase, axis: int, data, rank: int = 0) -> BoundaryPa # if len is not supported, the format must be wrong raise BCDataError( f"Unsupported boundary format: `{data}`. " + cls.get_help() - ) + ) from None else: if data_len == 2: # assume that data is given for each boundary diff --git a/pde/grids/boundaries/local.py b/pde/grids/boundaries/local.py index 161d6f54..a96c7719 100644 --- a/pde/grids/boundaries/local.py +++ b/pde/grids/boundaries/local.py @@ -426,7 +426,7 @@ def from_str( except KeyError: raise BCDataError( f"Boundary condition `{condition}` not defined. " + cls.get_help() - ) + ) from None # create the actual class return boundary_class(grid=grid, axis=axis, upper=upper, rank=rank, **kwargs) @@ -459,7 +459,7 @@ def from_dict( data = data.copy() # need to make a copy since we modify it below # parse all possible variants that could be given - if "type" in data.keys(): + if "type" in data: # type is given (optionally with a value) b_type = data.pop("type") return cls.from_str(grid, axis, upper, condition=b_type, rank=rank, **data) @@ -1224,14 +1224,13 @@ def __init__( except Exception as err: if self._is_func: raise BCDataError( - f"Could not evaluate BC function. Expected signature " - f"{signature}.\nEncountered error: {err}" - ) + f"Could not evaluate BC function. Expected signature {signature}." + ) from err else: raise BCDataError( f"Could not evaluate BC expression `{expression}` with signature " - f"{signature}.\nEncountered error: {err}" - ) + f"{signature}." + ) from err @property def _test_values(self) -> tuple[float, ...]: @@ -1275,7 +1274,7 @@ def value_func(*args): except nb.NumbaError: # if compilation fails, we simply fall back to pure-python mode - self._logger.warning(f"Cannot compile BC {self}") + self._logger.warning("Cannot compile BC %s", self) @register_jitable def value_func(*args): @@ -1360,7 +1359,7 @@ def _get_function_from_expression(self, do_jit: bool) -> Callable: except nb.NumbaError: # if compilation fails, we simply fall back to pure-python mode - self._logger.warning(f"Cannot compile BC {self._func_expression}") + self._logger.warning("Cannot compile BC %s", self._func_expression) # calculate the expected value to test this later (and fail early) expected = func(*self._test_values) @@ -1391,12 +1390,14 @@ def value_func(grid_value, dx, x, y, z, t): else: # cheap way to signal a problem - raise ValueError + raise ValueError from None # compile the actual functio and check the result result_compiled = value_func(*self._test_values) if not np.allclose(result_compiled, expected): - raise RuntimeError("Compiled function does not give same value") + raise RuntimeError( + "Compiled function does not give same value" + ) from None return value_func # type: ignore @@ -2974,4 +2975,4 @@ def registered_boundary_condition_names() -> dict[str, type[BCBase]]: Returns: dict: a dictionary with the names of the boundary conditions that can be used """ - return {cls_name: cls for cls_name, cls in BCBase._conditions.items()} + return dict(BCBase._conditions.items()) diff --git a/pde/grids/cartesian.py b/pde/grids/cartesian.py index ca6cc600..187561de 100644 --- a/pde/grids/cartesian.py +++ b/pde/grids/cartesian.py @@ -319,7 +319,7 @@ def _get_axis(axis): try: axis = self.axes.index(axis) except ValueError: - raise ValueError(f"Axis `{axis}` not defined") + raise ValueError(f"Axis `{axis}` not defined") from None return axis if extract == "auto": diff --git a/pde/grids/operators/cartesian.py b/pde/grids/operators/cartesian.py index e5a4907f..7032f962 100644 --- a/pde/grids/operators/cartesian.py +++ b/pde/grids/operators/cartesian.py @@ -986,7 +986,8 @@ def _make_divergence_scipy_nd( def divergence(arr: np.ndarray, out: np.ndarray) -> None: """Apply divergence operator to array `arr`""" - assert arr.shape[0] == len(data_shape) and arr.shape[1:] == data_shape + assert arr.shape[0] == len(data_shape) + assert arr.shape[1:] == data_shape # need to initialize with zeros since data is added later if out is None: diff --git a/pde/grids/operators/polar_sym.py b/pde/grids/operators/polar_sym.py index 1a25ee15..078c2120 100644 --- a/pde/grids/operators/polar_sym.py +++ b/pde/grids/operators/polar_sym.py @@ -300,12 +300,11 @@ def _get_laplace_matrix(bcs: Boundaries) -> tuple[np.ndarray, np.ndarray]: if r_min == 0: matrix[i, i + 1] = 2 * scale continue # the special case of the inner boundary is handled - else: - const, entries = bcs[0].get_sparse_matrix_data((-1,)) - factor = scale - scale_i - vector[i] += const * factor - for k, v in entries.items(): - matrix[i, k] += v * factor + const, entries = bcs[0].get_sparse_matrix_data((-1,)) + factor = scale - scale_i + vector[i] += const * factor + for k, v in entries.items(): + matrix[i, k] += v * factor else: matrix[i, i - 1] = scale - scale_i diff --git a/pde/grids/operators/spherical_sym.py b/pde/grids/operators/spherical_sym.py index fb0b9f80..f5972b47 100644 --- a/pde/grids/operators/spherical_sym.py +++ b/pde/grids/operators/spherical_sym.py @@ -57,7 +57,8 @@ def make_laplace(grid: SphericalSymGrid, *, conservative: bool = True) -> Operat # create a conservative spherical laplace operator rl = rs - dr / 2 # inner radii of spherical shells rh = rs + dr / 2 # outer radii - assert np.isclose(rl[0], r_min) and np.isclose(rh[-1], r_max) + assert np.isclose(rl[0], r_min) + assert np.isclose(rh[-1], r_max) volumes = (rh**3 - rl**3) / 3 # volume of the spherical shells factor_l = rl**2 / (dr * volumes) factor_h = rh**2 / (dr * volumes) @@ -496,7 +497,8 @@ def make_tensor_double_divergence( rl = rs - dr / 2 # inner radii of spherical shells rh = rs + dr / 2 # outer radii r_min, r_max = grid.axes_bounds[0] - assert np.isclose(rl[0], r_min) and np.isclose(rh[-1], r_max) + assert np.isclose(rl[0], r_min) + assert np.isclose(rh[-1], r_max) volumes = (rh**3 - rl**3) / 3 # volume of the spherical shells factor_l = rl / volumes factor_h = rh / volumes diff --git a/pde/pdes/base.py b/pde/pdes/base.py index 03f20c91..4ae79ad2 100644 --- a/pde/pdes/base.py +++ b/pde/pdes/base.py @@ -447,7 +447,6 @@ def noise_realization(state_data: np.ndarray, t: float) -> np.ndarray: @jit def noise_realization(state_data: np.ndarray, t: float) -> None: """Helper function returning a noise realization.""" - return None return noise_realization # type: ignore diff --git a/pde/pdes/laplace.py b/pde/pdes/laplace.py index e8d58837..8931f6f4 100644 --- a/pde/pdes/laplace.py +++ b/pde/pdes/laplace.py @@ -65,14 +65,14 @@ def solve_poisson_equation( result = ScalarField(rhs.grid, label=label) try: solver(rhs.data, result.data) - except RuntimeError: + except RuntimeError as err: magnitude = rhs.magnitude if magnitude > 1e-10: raise RuntimeError( "Could not solve the Poisson problem. One possible reason for this is " "that only periodic or Neumann conditions are applied although the " f"magnitude of the field is {magnitude} and thus non-zero." - ) + ) from err else: raise # another error occurred diff --git a/pde/pdes/pde.py b/pde/pdes/pde.py index 9345509e..19e00dbf 100644 --- a/pde/pdes/pde.py +++ b/pde/pdes/pde.py @@ -224,16 +224,16 @@ def __init__( else: raise ValueError(f'Cannot parse boundary condition "{key_str}"') if key in self.bcs: - self._logger.warning(f"Two boundary conditions for key {key}") + self._logger.warning("Two boundary conditions for key %s", key) self.bcs[key] = value # save information for easy inspection self.diagnostics["pde"] = { "variables": list(self.variables), - "constants": list(sorted(self.consts)), + "constants": sorted(self.consts), "explicit_time_dependence": explicit_time_dependence, "complex_valued_rhs": complex_valued, - "operators": list(sorted(set().union(*self._operators.values()))), + "operators": sorted(set().union(*self._operators.values())), } self._cache: dict[str, dict[str, Any]] = {} @@ -309,7 +309,7 @@ def _compile_rhs_single( expr._sympy_expr = expr._sympy_expr.replace( # only modify the relevant operator lambda expr: isinstance(expr.func, UndefinedFunction) - and expr.name == func + and expr.name == func # noqa: B023 # and do not modify it when the bc_args have already been set and not ( isinstance(expr.args[-1], Symbol) @@ -335,7 +335,7 @@ def _compile_rhs_single( else: # expression only depends on the actual variables - extra_args = tuple() # @UnusedVariable + extra_args = () # @UnusedVariable # check whether all variables are accounted for extra_vars = set(expr.vars) - set(signature) @@ -462,7 +462,7 @@ def _prepare_cache( # check whether there are boundary conditions that have not been used bcs_left = set(self.bcs.keys()) - self.diagnostics["pde"]["bcs_used"] - {"*:*"} if bcs_left: - self._logger.warning("Unused BCs: %s", list(sorted(bcs_left))) + self._logger.warning("Unused BCs: %s", sorted(bcs_left)) # add extra information for field collection if isinstance(state, FieldCollection): diff --git a/pde/solvers/base.py b/pde/solvers/base.py index c8a67233..a74c9126 100644 --- a/pde/solvers/base.py +++ b/pde/solvers/base.py @@ -30,7 +30,7 @@ class ConvergenceError(RuntimeError): """Indicates that an implicit step did not converge.""" -class SolverBase(metaclass=ABCMeta): +class SolverBase: """Base class for PDE solvers.""" dt_default: float = 1e-3 @@ -73,7 +73,7 @@ def __init_subclass__(cls, **kwargs): # @NoSelf cls._subclasses[cls.__name__] = cls if hasattr(cls, "name") and cls.name: if cls.name in cls._subclasses: - logging.warning(f"Solver with name {cls.name} is already registered") + logging.warning("Solver with name %s is already registered", cls.name) cls._subclasses[cls.name] = cls @classmethod @@ -108,14 +108,14 @@ def from_name(cls, name: str, pde: PDEBase, **kwargs) -> SolverBase: raise ValueError( f"Unknown solver method '{name}'. Registered solvers are " + ", ".join(solvers) - ) + ) from None return solver_class(pde, **kwargs) @classproperty def registered_solvers(cls) -> list[str]: # @NoSelf """list of str: the names of the registered solvers""" - return list(sorted(cls._subclasses.keys())) + return sorted(cls._subclasses.keys()) @property def _compiled(self) -> bool: @@ -250,7 +250,7 @@ def _make_pde_rhs( time. The function returns the deterministic evolution rate and (if applicable) a realization of the associated noise. """ - if getattr(self.pde, "is_sde"): + if getattr(self.pde, "is_sde", False): raise RuntimeError( f"Cannot create a deterministic stepper for a stochastic equation" ) @@ -379,8 +379,9 @@ def make_stepper( dt = self.dt_default self._logger.warning( "Explicit stepper with a fixed time step did not receive any " - f"initial value for `dt`. Using dt={dt}, but specifying a value or " - "enabling adaptive stepping is advisable." + "initial value for `dt`. Using dt=%g, but specifying a value or " + "enabling adaptive stepping is advisable.", + dt, ) dt_float = float(dt) # explicit casting to help type checking @@ -526,7 +527,7 @@ def _make_single_step_error_estimate( An example for the state from which the grid and other information can be extracted """ - if getattr(self.pde, "is_sde"): + if getattr(self.pde, "is_sde", False): raise RuntimeError("Cannot use adaptive stepper with stochastic equation") single_step = self._make_single_step_variable_dt(state) @@ -638,7 +639,7 @@ def adaptive_stepper( ) adaptive_stepper = jit(sig_adaptive)(adaptive_stepper) - self._logger.info(f"Initialized adaptive stepper") + self._logger.info("Initialized adaptive stepper") return adaptive_stepper def make_stepper( diff --git a/pde/solvers/controller.py b/pde/solvers/controller.py index bf53281c..972f9127 100644 --- a/pde/solvers/controller.py +++ b/pde/solvers/controller.py @@ -103,13 +103,13 @@ def t_range(self, value: TRangeType): # determine time range try: self._t_range: tuple[float, float] = (0, float(value)) # type: ignore - except TypeError: # assume a single number was given + except TypeError as err: # assume a single number was given if len(value) == 2: # type: ignore self._t_range = tuple(value) # type: ignore else: raise ValueError( "t_range must be set to a single number or a tuple of two numbers" - ) + ) from err def _get_stop_handler(self) -> Callable[[Exception, float], tuple[int, str]]: """Return function that handles messaging.""" @@ -196,7 +196,7 @@ def _run_main_process(self, state: TState, dt: float | None = None) -> None: # evolve the system from t_start to t_end t = t_start - self._logger.debug(f"Start simulation at t={t}") + self._logger.debug("Start simulation at t=%g", t) try: while t < t_end: # determine next time point with an action @@ -265,8 +265,10 @@ def _run_main_process(self, state: TState, dt: float | None = None) -> None: self._logger.log(msg_level, msg) if profiler["tracker"] > max(profiler["solver"], 1): self._logger.warning( - f"Spent more time on handling trackers ({profiler['tracker']:.3g}) " - f"than on the actual simulation ({profiler['solver']:.3g})" + "Spent more time on handling trackers (%.3g) than on the actual " + "simulation (%.3g)", + profiler["tracker"], + profiler["solver"], ) def _run_client_process(self, state: TState, dt: float | None = None) -> None: @@ -352,7 +354,7 @@ def _run_parallel(self, state: TState, dt: float | None = None) -> TState | None self._run_main_process(state, dt) except Exception as err: print(err) # simply print the exception to show some info - self._logger.error(f"Error in main node", exc_info=err) + self._logger.error("Error in main node", exc_info=err) time.sleep(0.5) # give some time for info to propagate MPI.COMM_WORLD.Abort() # abort all other nodes raise @@ -365,7 +367,7 @@ def _run_parallel(self, state: TState, dt: float | None = None) -> TState | None self._run_client_process(state, dt) except Exception as err: print(err) # simply print the exception to show some info - self._logger.error(f"Error in node {mpi.rank}", exc_info=err) + self._logger.error("Error in node %d", mpi.rank, exc_info=err) time.sleep(0.5) # give some time for info to propagate MPI.COMM_WORLD.Abort() # abort all other (and main) nodes raise @@ -394,7 +396,7 @@ def run(self, initial_state: TState, dt: float | None = None) -> TState | None: from ..tools import mpi # copy the initial state to not modify the supplied one - if getattr(self.solver, "pde") and self.solver.pde.complex_valued: + if hasattr(self.solver, "pde") and self.solver.pde.complex_valued: self._logger.info("Convert state to complex numbers") state: TState = initial_state.copy(dtype=complex) else: diff --git a/pde/solvers/explicit.py b/pde/solvers/explicit.py index cd141f41..4aceee8e 100644 --- a/pde/solvers/explicit.py +++ b/pde/solvers/explicit.py @@ -266,7 +266,7 @@ def adaptive_stepper( ) adaptive_stepper = jit(sig_adaptive)(adaptive_stepper) - self._logger.info(f"Init adaptive Euler stepper") + self._logger.info("Init adaptive Euler stepper") return adaptive_stepper def _make_single_step_error_estimate_rkf( diff --git a/pde/solvers/explicit_mpi.py b/pde/solvers/explicit_mpi.py index a5b511b5..2b63d881 100644 --- a/pde/solvers/explicit_mpi.py +++ b/pde/solvers/explicit_mpi.py @@ -155,8 +155,9 @@ def make_stepper( if not self.adaptive: self._logger.warning( "Explicit stepper with a fixed time step did not receive any " - f"initial value for `dt`. Using dt={dt}, but specifying a value or " - "enabling adaptive stepping is advisable." + "initial value for `dt`. Using dt=%g, but specifying a value or " + "enabling adaptive stepping is advisable.", + dt, ) self.info["dt"] = dt diff --git a/pde/solvers/scipy.py b/pde/solvers/scipy.py index 579845f1..4eada32c 100644 --- a/pde/solvers/scipy.py +++ b/pde/solvers/scipy.py @@ -99,7 +99,7 @@ def stepper(state: FieldBase, t_start: float, t_end: float) -> float: return sol.t[0] # type: ignore if dt: - self._logger.info(f"Init {self.__class__.__name__} stepper with dt=%g", dt) + self._logger.info("Init %s stepper with dt=%g", self.__class__.__name__, dt) else: - self._logger.info(f"Init {self.__class__.__name__} stepper") + self._logger.info("Init %s stepper", self.__class__.__name__) return stepper diff --git a/pde/storage/__init__.py b/pde/storage/__init__.py index ef2c57d8..322f3697 100644 --- a/pde/storage/__init__.py +++ b/pde/storage/__init__.py @@ -12,11 +12,11 @@ .. codeauthor:: David Zwicker """ +import contextlib + from .file import FileStorage from .memory import MemoryStorage, get_memory_storage from .movie import MovieStorage -try: +with contextlib.suppress(ImportError): from .modelrunner import ModelrunnerStorage -except ImportError: - ... # ModelrunnerStorage is only available when py-modelrunner is available diff --git a/pde/storage/base.py b/pde/storage/base.py index 4bcd4e83..6521b9d5 100644 --- a/pde/storage/base.py +++ b/pde/storage/base.py @@ -176,8 +176,8 @@ def grid(self) -> GridBase | None: self._grid = attrs["fields"][0]["grid"] else: self._logger.warning( - "`grid` attribute was not stored. Available attributes: " - + ", ".join(sorted(attrs.keys())) + "`grid` attribute was not stored. Available attributes: %s", + ", ".join(sorted(attrs.keys())), ) else: @@ -219,8 +219,8 @@ def _init_field(self) -> None: f"{local_shape} could not be interpreted automatically" ) self._logger.warning( - "`field` attribute was not stored. We guessed that the data is of " - f"type {self._field.__class__.__name__}." + "`field` attribute was not stored. Assume data is of type %s.", + self._field.__class__.__name__, ) def _get_field(self, t_index: int) -> FieldBase: diff --git a/pde/storage/file.py b/pde/storage/file.py index 05f0b6ca..d4af8137 100644 --- a/pde/storage/file.py +++ b/pde/storage/file.py @@ -74,7 +74,7 @@ def __init__( self._data_length: int = None # type: ignore self._max_length: int | None = max_length - if not self.check_mpi or mpi.is_main: + if not self.check_mpi or mpi.is_main: # noqa: SIM102 # we are on the main process and can thus open the file directly if self.filename.is_file() and self.filename.stat().st_size > 0: try: @@ -82,7 +82,7 @@ def __init__( except (OSError, KeyError): self.close() self._logger.warning( - f"File `{filename}` could not be opened for reading" + "File `%s` could not be opened for reading", filename ) def __del__(self): @@ -103,7 +103,7 @@ def _file_state(self) -> str: def close(self) -> None: """Close the currently opened file.""" if self._file is not None: - self._logger.info(f"Close file `{self.filename}`") + self._logger.info("Close file `%s`", self.filename) self._file.close() self._file = None self._data_length = None # type: ignore @@ -117,7 +117,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback): def _create_hdf_dataset( self, name: str, - shape: tuple[int, ...] = tuple(), + shape: tuple[int, ...] = (), dtype: DTypeLike = np.double, ): """Create a hdf5 dataset with the given name and data_shape. @@ -176,7 +176,7 @@ def _open( # close file to open it again for reading or appending if self._file: self._file.close() - self._logger.info(f"Open file `{self.filename}` for reading") + self._logger.info("Open file `%s` for reading", self.filename) self._file = h5py.File(self.filename, mode="r") self._times = self._file["times"] self._data = self._file["data"] @@ -200,7 +200,7 @@ def _open( self.close() # open file for reading or appending - self._logger.info(f"Open file `{self.filename}` for appending") + self._logger.info("Open file `%s` for appending", self.filename) self._file = h5py.File(self.filename, mode="a") if "times" in self._file and "data" in self._file: @@ -234,7 +234,7 @@ def _open( self.close() else: ensure_directory_exists(self.filename.parent) - self._logger.info(f"Open file `{self.filename}` for writing") + self._logger.info("Open file `%s` for writing", self.filename) self._file = h5py.File(self.filename, "w") self._times = self._create_hdf_dataset("times") self._data = self._create_hdf_dataset( @@ -337,7 +337,7 @@ def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None: super().start_writing(field, info=info) # initialize the file for writing with the correct mode - self._logger.debug(f"Start writing with mode `{self.write_mode}`") + self._logger.debug("Start writing with mode '%s'", self.write_mode) if self.write_mode == "truncate_once": self._open("writing", info) self.write_mode = "append" # do not truncate for next writing diff --git a/pde/storage/modelrunner.py b/pde/storage/modelrunner.py index 28cb240a..d79fe7a7 100644 --- a/pde/storage/modelrunner.py +++ b/pde/storage/modelrunner.py @@ -130,7 +130,7 @@ def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None: super().start_writing(field, info=info) # initialize the file for writing with the correct mode - self._logger.debug(f"Start writing with mode `{self.write_mode}`") + self._logger.debug("Start writing with mode '%s'", self.write_mode) if self.write_mode == "truncate_once": self.write_mode = "append" # do not truncate for next writing elif self.write_mode == "readonly": diff --git a/pde/storage/movie.py b/pde/storage/movie.py index a32c8e7c..8c80f5ae 100644 --- a/pde/storage/movie.py +++ b/pde/storage/movie.py @@ -144,7 +144,7 @@ def __del__(self): def close(self) -> None: """Close the currently opened file.""" if self._ffmpeg is not None: - self._logger.info(f"Close movie file `{self.filename}`") + self._logger.info("Close movie file '%s'", self.filename) if self._state == "writing": self._ffmpeg.stdin.close() self._ffmpeg.wait() @@ -191,7 +191,7 @@ def _read_metadata(self) -> None: # sanity checks on the video nb_streams = info["format"]["nb_streams"] if nb_streams != 1: - self._logger.warning(f"Only using first of {nb_streams} streams") + self._logger.warning("Only using first of %d streams", nb_streams) tags = info["format"].get("tags", {}) # read comment field, which can be either lower case or upper case @@ -202,7 +202,7 @@ def _read_metadata(self) -> None: version = metadata.pop("version", 1) if version != 1: - self._logger.warning(f"Unknown metadata version `{version}`") + self._logger.warning("Unknown metadata version `%d`", version) self.vmin = metadata.pop("vmin", 0) self.vmax = metadata.pop("vmax", 1) self.info.update(metadata) @@ -217,8 +217,8 @@ def _read_metadata(self) -> None: try: fps = Fraction(stream.get("avg_frame_rate", None)) duration = parse_duration(stream.get("tags", {}).get("DURATION")) - except TypeError: - raise RuntimeError("Frame count could not be read from video") + except TypeError as err: + raise RuntimeError("Frame count could not be read from video") from err else: self.info["num_frames"] = int(duration.total_seconds() * float(fps)) self.info["width"] = stream["width"] @@ -234,12 +234,13 @@ def _read_metadata(self) -> None: try: self._format = FFmpeg.formats[video_format] except KeyError: - self._logger.warning(f"Unknown pixel format `{video_format}`") + self._logger.warning("Unknown pixel format `%s`", video_format) else: if self._format.pix_fmt_file != stream.get("pix_fmt"): self._logger.info( - "Pixel format differs from requested one: " - f"{self._format.pix_fmt_file} != {stream.get('pix_fmt')}" + "Pixel format differs from requested one: %s != %s", + self._format.pix_fmt_file, + stream.get("pix_fmt"), ) def _init_normalization(self, field: FieldBase) -> None: @@ -332,7 +333,7 @@ def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None: self._init_normalization(field) # set input - self._logger.debug(f"Start ffmpeg process for `{self.filename}`") + self._logger.debug("Start ffmpeg process for `%s`", self.filename) input_args = { "format": "rawvideo", "s": f"{width}x{height}", @@ -358,7 +359,7 @@ def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None: self._ffmpeg = f_output.run_async(pipe_stdin=True) # start process if self.write_times: - self._times_file = open(self._filename_times, "w") + self._times_file = self._filename_times.open("w") # noqa: SIM115 self.info["num_frames"] = 0 self._warned_normalization = False @@ -390,9 +391,9 @@ def _append_data(self, data: np.ndarray, time: float) -> None: t_start = 0 dt = self.info.get("dt", 1) time_expect = t_start + dt * self.info["num_frames"] - if not np.isclose(time, time_expect): - if not self.info.get("time_mismatch", False): - self._logger.warning(f"Time mismatch: {time} != {time_expect}") + if not np.isclose(time, time_expect): # discrepancy in time # noqa: SIM102 + if not self.info.get("time_mismatch", False): # not yet warned + self._logger.warning("Time mismatch: %g != %g", time, time_expect) self.info["time_mismatch"] = True # make sure there are two spatial dimensions @@ -421,8 +422,9 @@ def _append_data(self, data: np.ndarray, time: float) -> None: if not self._warned_normalization: if np.any(data[i, ...] < norm.vmin) or np.any(data[i, ...] > norm.vmax): self._logger.warning( - f"Data outside range specified by `vmin={norm.vmin}` and " - f"`vmax={norm.vmax}`" + "Data outside range specified by `vmin=%g` and `vmax=%g`", + norm.vmin, + norm.vmax, ) self._warned_normalization = True # only warn once data_norm = norm(data[i, ...]) @@ -434,7 +436,7 @@ def _append_data(self, data: np.ndarray, time: float) -> None: def end_writing(self) -> None: """Finalize the storage after writing.""" - if not self._state == "writing": + if self._state != "writing": self._logger.warning("Writing was already terminated") return # writing mode was already ended self._logger.debug("End writing") @@ -459,8 +461,9 @@ def times(self): times = np.loadtxt(self._filename_times) except OSError: self._logger.warning( - f"Could not read time stamps from file `{self._filename_times}`. " - "Return equidistant times instead." + "Could not read time stamps from file `%s`. " + "Return equidistant times instead.", + self._filename_times, ) if times is None: @@ -630,7 +633,7 @@ def add_to_state(state): if not (self.write_times or isinstance(interrupts, ConstantInterrupts)): self._logger.warning( - f"Use `write_times=True` to write times for complex interrupts" + "Use `write_times=True` to write times for complex interrupts" ) # store data for common case of constant intervals self.info["dt"] = getattr(interrupts, "dt", 1) diff --git a/pde/tools/cache.py b/pde/tools/cache.py index ba0c23c6..744a69f4 100644 --- a/pde/tools/cache.py +++ b/pde/tools/cache.py @@ -244,7 +244,7 @@ def make_serializer(method: SerializerMethod) -> Callable: return lambda s: yaml.dump(s).encode("utf-8") - raise ValueError("Unknown serialization method `%s`" % method) + raise ValueError(f"Unknown serialization method `{method}`") def make_unserializer(method: SerializerMethod) -> Callable: @@ -285,7 +285,7 @@ def make_unserializer(method: SerializerMethod) -> Callable: return yaml.unsafe_load - raise ValueError("Unknown serialization method `%s`" % method) + raise ValueError(f"Unknown serialization method `{method}`") class DictFiniteCapacity(collections.OrderedDict): diff --git a/pde/tools/config.py b/pde/tools/config.py index d4fbacae..0b6e3c91 100644 --- a/pde/tools/config.py +++ b/pde/tools/config.py @@ -103,10 +103,10 @@ def __setitem__(self, key: str, value): elif self.mode == "update": try: self[key] # test whether the key already exist (including magic keys) - except KeyError: + except KeyError as err: raise KeyError( f"{key} is not present and config is not in `insert` mode" - ) + ) from err self.data[key] = value elif self.mode == "locked": @@ -128,7 +128,7 @@ def to_dict(self) -> dict[str, Any]: Returns: dict: A representation of the configuration in a normal :class:`dict`. """ - return {k: v for k, v in self.items()} + return dict(self.items()) def __repr__(self) -> str: """Represent the configuration as a string.""" @@ -179,10 +179,8 @@ def parse_version_str(ver_str: str) -> list[int]: """Helper function converting a version string into a list of integers.""" result = [] for token in ver_str.split(".")[:3]: - try: + with contextlib.suppress(ValueError): result.append(int(token)) - except ValueError: - pass return result @@ -215,7 +213,7 @@ def packages_from_requirements(requirements_file: Path | str) -> list[str]: """ result = [] try: - with open(requirements_file) as fp: + with Path(requirements_file).open() as fp: for line in fp: line_s = line.strip() if line_s.startswith("#"): diff --git a/pde/tools/cuboid.py b/pde/tools/cuboid.py index 933b30e8..4eaa2558 100644 --- a/pde/tools/cuboid.py +++ b/pde/tools/cuboid.py @@ -120,9 +120,7 @@ def copy(self) -> Cuboid: return self.__class__(self.pos, self.size) def __repr__(self): - return "{cls}(pos={pos}, size={size})".format( - cls=self.__class__.__name__, pos=self.pos, size=self.size - ) + return f"{self.__class__.__name__}(pos={self.pos}, size={self.size})" def __add__(self, other: Cuboid) -> Cuboid: """The sum of two cuboids is the minimal cuboid enclosing both.""" diff --git a/pde/tools/docstrings.py b/pde/tools/docstrings.py index 888a5fc1..eba7716b 100644 --- a/pde/tools/docstrings.py +++ b/pde/tools/docstrings.py @@ -14,6 +14,7 @@ import re import textwrap +from functools import partial from typing import TypeVar DOCSTRING_REPLACEMENTS = { @@ -164,17 +165,16 @@ def fill_in_docstring(f: TFunc) -> TFunc: width=80, expand_tabs=True, replace_whitespace=True, drop_whitespace=True ) - for name, value in DOCSTRING_REPLACEMENTS.items(): - - def repl(matchobj) -> str: - """Helper function replacing token in docstring.""" - tw.initial_indent = tw.subsequent_indent = matchobj.group(1) - return tw.fill(textwrap.dedent(value)) + def repl(matchobj, value: str) -> str: + """Helper function replacing token in docstring.""" + tw.initial_indent = tw.subsequent_indent = matchobj.group(1) + return tw.fill(textwrap.dedent(value)) + for name, value in DOCSTRING_REPLACEMENTS.items(): token = "{" + name + "}" f.__doc__ = re.sub( f"^([ \t]*){token}", - repl, + partial(repl, value=value), f.__doc__, # type: ignore flags=re.MULTILINE, ) diff --git a/pde/tools/expressions.py b/pde/tools/expressions.py index c8a3c26d..8319999f 100644 --- a/pde/tools/expressions.py +++ b/pde/tools/expressions.py @@ -260,14 +260,13 @@ def __init__( for name, value in self.consts.items(): if isinstance(value, FieldBase): self._logger.warning( - f"Constant `{name}` is a field, but expressions usually require " - f"numerical arrays. Did you mean to use `{name}.data`?" + "Constant `%(name)s` is a field, but expressions usually require " + "numerical arrays. Did you mean to use `%(name)s.data`?", + {"name": name}, ) def __repr__(self): - return ( - f'{self.__class__.__name__}("{self.expression}", ' f"signature={self.vars})" - ) + return f'{self.__class__.__name__}("{self.expression}", signature={self.vars})' def __eq__(self, other): """Compare this expression to another one.""" @@ -324,9 +323,9 @@ def _check_signature(self, signature: Sequence[str | list[str]] | None = None): args = {str(s).split("[")[0] for s in self._free_symbols} if signature is None: # create signature from arguments - signature = list(sorted(args)) + signature = sorted(args) - self._logger.debug(f"Expression arguments: {args}") + self._logger.debug("Expression arguments: %s", args) # check whether variables are in signature self.vars: Any = [] @@ -345,7 +344,7 @@ def _check_signature(self, signature: Sequence[str | list[str]] | None = None): old = sympy.symbols(arg) new = sympy.symbols(arg_name) self._sympy_expr = self._sympy_expr.subs(old, new) - self._logger.info(f'Renamed variable "{old}"->"{new}"') + self._logger.info('Renamed variable "%s"->"%s"', old, new) found.add(arg) break @@ -514,7 +513,7 @@ def get_compiled(self, single_arg: bool = False) -> Callable[..., NumberOrArray] class ScalarExpression(ExpressionBase): """Describes a mathematical expression of a scalar quantity.""" - shape: tuple[int, ...] = tuple() + shape: tuple[int, ...] = () @fill_in_docstring def __init__( @@ -678,9 +677,8 @@ def differentiate(self, var: str) -> ScalarExpression: return ScalarExpression( expression=0, signature=self.vars, allow_indexed=self.allow_indexed ) - if self.allow_indexed: - if self._var_indexed(var): - raise NotImplementedError("Cannot differentiate with respect to vector") + if self.allow_indexed and self._var_indexed(var): + raise NotImplementedError("Cannot differentiate with respect to vector") # turn variable into sympy object and treat an indexed variable separately var_expr = self._prepare_expression(var) @@ -707,11 +705,10 @@ def derivatives(self) -> TensorExpression: expression = sympy.Array(np.zeros(dim), shape=(dim,)) return TensorExpression(expression=expression, signature=self.vars) - if self.allow_indexed: - if any(self._var_indexed(var) for var in self.vars): - raise RuntimeError( - "Cannot calculate gradient for expressions with indexed variables" - ) + if self.allow_indexed and any(self._var_indexed(var) for var in self.vars): + raise RuntimeError( + "Cannot calculate gradient for expressions with indexed variables" + ) grad = sympy.Array([self._sympy_expr.diff(sympy.Symbol(v)) for v in self.vars]) return TensorExpression( @@ -1090,7 +1087,7 @@ def evaluate( # check whether there are boundary conditions that have not been used bcs_left = set(bcs.keys()) - bcs_used - {"*:*", "*"} if bcs_left: - logger.warning("Unused BCs: %s", list(sorted(bcs_left))) + logger.warning("Unused BCs: %s", sorted(bcs_left)) # obtain the function to calculate the right hand side signature = tuple(fields_keys) + ("none", "bc_args") @@ -1108,7 +1105,7 @@ def evaluate( else: # expression only depends on the actual variables - extra_args = tuple() # @UnusedVariable + extra_args = () # @UnusedVariable # check whether all variables are accounted for extra_vars = set(expr.vars) - set(signature) diff --git a/pde/tools/ffmpeg.py b/pde/tools/ffmpeg.py index 1416cba0..0032e096 100644 --- a/pde/tools/ffmpeg.py +++ b/pde/tools/ffmpeg.py @@ -10,6 +10,8 @@ .. codeauthor:: David Zwicker """ +from __future__ import annotations + from dataclasses import dataclass from typing import Optional, Union @@ -48,7 +50,7 @@ def bytes_per_channel(self) -> int: return self.bits_per_channel // 8 @property - def max_value(self) -> Union[float, int]: + def max_value(self) -> float | int: """Maximal value stored in a color channel.""" if np.issubdtype(self.dtype, np.integer): return 2**self.bits_per_channel - 1 # type: ignore @@ -109,35 +111,11 @@ def data_from_frame(self, frame_data: np.ndarray): bits_per_channel=16, dtype=np.dtype(" Optional[str]: +def find_format(channels: int, bits_per_channel: int = 8) -> str | None: """Find a defined FFmpegFormat that satisifies the requirements. Args: @@ -153,13 +131,14 @@ def find_format(channels: int, bits_per_channel: int = 8) -> Optional[str]: """ n_best, f_best = None, None for n, f in formats.items(): # iterate through all defined formats - if f.channels >= channels and f.bits_per_channel >= bits_per_channel: - # this format satisfies the requirements - if ( + if ( + f.channels >= channels + and f.bits_per_channel >= bits_per_channel # satisfies the requirements + and ( f_best is None or f.bits_per_channel < f_best.bits_per_channel or f.channels < f_best.channels - ): - # the current format is better than the previous one - n_best, f_best = n, f + ) # the current format is better than the previous one + ): + n_best, f_best = n, f return n_best diff --git a/pde/tools/misc.py b/pde/tools/misc.py index b19dfc6a..3dcf9e4b 100644 --- a/pde/tools/misc.py +++ b/pde/tools/misc.py @@ -60,11 +60,8 @@ def ensure_directory_exists(folder: str | Path): Args: folder (str): path of the new folder """ - folder = str(folder) - if folder == "": - return try: - os.makedirs(folder) + Path(folder).mkdir(parents=True) except OSError as err: if err.errno != errno.EEXIST: raise diff --git a/pde/tools/mpi.py b/pde/tools/mpi.py index 7d3f36ed..d767ab3d 100644 --- a/pde/tools/mpi.py +++ b/pde/tools/mpi.py @@ -90,7 +90,7 @@ def __getattr__(self, name: str): try: return self._name_ids[name] except KeyError: - raise AttributeError + raise AttributeError(f"MPI operator `{name}` not registered") from None Operator = _OperatorRegistry() Operator.register("MAX", MPI.MAX) diff --git a/pde/tools/output.py b/pde/tools/output.py index af58b618..110c4c12 100644 --- a/pde/tools/output.py +++ b/pde/tools/output.py @@ -41,7 +41,7 @@ def get_progress_bar_class(fancy: bool = True): progress_bar_class = tqdm.tqdm else: # use the fancier version of the progress bar in jupyter - from tqdm.auto import tqdm as progress_bar_class # type: ignore + from tqdm.auto import tqdm as progress_bar_class else: # only import text progress bar progress_bar_class = tqdm.tqdm diff --git a/pde/tools/parameters.py b/pde/tools/parameters.py index 8d6bd38a..a56f932c 100644 --- a/pde/tools/parameters.py +++ b/pde/tools/parameters.py @@ -134,11 +134,11 @@ def convert(self, value=None): else: try: return self.cls(value) - except ValueError: + except ValueError as err: raise ValueError( f"Could not convert {value!r} to {self.cls.__name__} for parameter " f"'{self.name}'" - ) + ) from err class DeprecatedParameter(Parameter): @@ -228,9 +228,9 @@ def get_parameters( """ # collect the parameters from the class hierarchy parameters: dict[str, Parameter] = {} - for cls in reversed(cls.__mro__): - if hasattr(cls, "parameters_default"): - for p in cls.parameters_default: + for parent_cls in reversed(cls.__mro__): + if hasattr(parent_cls, "parameters_default"): + for p in parent_cls.parameters_default: if isinstance(p, HideParameter): if include_hidden: parameters[p.name].hidden = True @@ -504,19 +504,22 @@ def sphinx_display_parameters(app, what, name, obj, options, lines): app.connect('autodoc-process-docstring', sphinx_display_parameters) """ - if what == "class" and issubclass(obj, Parameterized): - if any(":param parameters:" in line for line in lines): - # parse parameters - parameters = obj.get_parameters(sort=False) - if parameters: - lines.append(".. admonition::") - lines.append(f" Parameters of {obj.__name__}:") - lines.append(" ") - for p in parameters.values(): - lines.append(f" {p.name}") - text = p.description.splitlines() - text.append(f"(Default value: :code:`{p.default_value!r}`)") - text = [" " + t for t in text] - lines.extend(text) - lines.append("") + if ( + what == "class" + and issubclass(obj, Parameterized) + and any(":param parameters:" in line for line in lines) + ): + # parse parameters + parameters = obj.get_parameters(sort=False) + if parameters: + lines.append(".. admonition::") + lines.append(f" Parameters of {obj.__name__}:") + lines.append(" ") + for p in parameters.values(): + lines.append(f" {p.name}") + text = p.description.splitlines() + text.append(f"(Default value: :code:`{p.default_value!r}`)") + text = [" " + t for t in text] + lines.extend(text) lines.append("") + lines.append("") diff --git a/pde/tools/plotting.py b/pde/tools/plotting.py index 50bf6646..16f90f86 100644 --- a/pde/tools/plotting.py +++ b/pde/tools/plotting.py @@ -106,14 +106,10 @@ def get_size(self, renderer): cbar = ax.figure.colorbar(axes_image, cax=cax, **kwargs) # disable the offset that matplotlib sometimes shows - try: + with contextlib.suppress(AttributeError): cax.get_xaxis().get_major_formatter().set_useOffset(False) - except AttributeError: - pass # can happen for logarithmically formatted axes - try: + with contextlib.suppress(AttributeError): cax.get_yaxis().get_major_formatter().set_useOffset(False) - except AttributeError: - pass # can happen for logarithmically formatted axes if label: cbar.set_label(label) @@ -297,7 +293,7 @@ def wrapper( if ax is None: # create new figure backend = mpl.get_backend() - if "backend_inline" in backend or "nbAgg" == backend: + if "backend_inline" in backend or backend == "nbAgg": plt.close("all") # close left over figures auto_show_figure = True # show this figure if action == 'auto' fig, ax = plt.subplots() @@ -492,7 +488,7 @@ def wrapper( if fig is None: # create new figure backend = mpl.get_backend() - if "backend_inline" in backend or "nbAgg" == backend: + if "backend_inline" in backend or backend == "nbAgg": plt.close("all") # close left over figures fig = plt.figure(constrained_layout=constrained_layout) @@ -593,7 +589,7 @@ def __init__(self, title: str | None = None, show: bool = True): self.initial_plot = True self.fig = None self._logger = logging.getLogger(__name__) - self._logger.info(f"Initialize {self.__class__.__name__}") + self._logger.info("Initialize %s", self.__class__.__name__) def __enter__(self): # start the plotting process @@ -729,10 +725,8 @@ def close(self): """Close the plot.""" super().close() # close ipython output - try: + with contextlib.suppress(Exception): self._ipython_out.close() - except Exception: - pass def get_plotting_context( @@ -845,7 +839,7 @@ def napari_add_layers( """adds layers to a `napari `__ viewer Args: - viewer (:class:`napar i.viewer.Viewer`): + viewer (:class:`napari.viewer.Viewer`): The napari application layers_data (dict): Data for all layers that will be added. @@ -855,7 +849,7 @@ def napari_add_layers( layer_type = layer_data.pop("type") try: add_layer = getattr(viewer, f"add_{layer_type}") - except AttributeError: - raise RuntimeError(f"Unknown layer type: {layer_type}") + except AttributeError as err: + raise RuntimeError(f"Unknown layer type: {layer_type}") from err else: add_layer(**layer_data) diff --git a/pde/tools/spectral.py b/pde/tools/spectral.py index 25fd10ff..1b076b8d 100644 --- a/pde/tools/spectral.py +++ b/pde/tools/spectral.py @@ -98,7 +98,6 @@ def noise_colored() -> np.ndarray: arr *= scaling # backwards transform - arr = np_irfftn(arr, shape) - return arr + return np_irfftn(arr, shape) # type: ignore return noise_colored diff --git a/pde/trackers/base.py b/pde/trackers/base.py index 838b5d11..0609e096 100644 --- a/pde/trackers/base.py +++ b/pde/trackers/base.py @@ -71,9 +71,9 @@ def from_data(cls, data: TrackerDataType, **kwargs) -> TrackerBase: elif isinstance(data, str): try: tracker_cls = cls._subclasses[data] - except KeyError: + except KeyError as err: trackers = sorted(cls._subclasses.keys()) - raise ValueError(f"Tracker `{data}` is not in {trackers}") + raise ValueError(f"Tracker `{data}` is not in {trackers}") from err return tracker_cls(**kwargs) else: raise ValueError(f"Unsupported tracker format: `{data}`.") diff --git a/pde/trackers/interactive.py b/pde/trackers/interactive.py index 68b7bbaa..01128380 100644 --- a/pde/trackers/interactive.py +++ b/pde/trackers/interactive.py @@ -6,6 +6,7 @@ from __future__ import annotations +import contextlib import logging import multiprocessing as mp import platform @@ -27,7 +28,7 @@ def napari_process( t_initial: float | None = None, viewer_args: dict[str, Any] | None = None, ): - """:mod:`multiprocessing.Process` running `napari `__ + """:mod:`multiprocessing.Process` running `napari120 `__ Args: data_channel (:class:`multiprocessing.Queue`): @@ -109,16 +110,16 @@ def update_listener(): update_data = data # continue running until the queue is empty else: - logger.warning(f"Unexpected action: {action}") + logger.warning("Unexpected action: %s", action) # update napari view when there is data if update_data is not None: - logger.debug(f"Update napari layer...") + logger.debug("Update napari layer...") layer_data, t = update_data if label is not None: label.setText(f"Time: {t}") - for name, layer_data in layer_data.items(): - viewer.layers[name].data = layer_data["data"] + for name, data in layer_data.items(): + viewer.layers[name].data = data["data"] yield @@ -193,10 +194,8 @@ def update(self, state: FieldBase, t: float): except queue.Full: pass # could not write data else: - try: + with contextlib.suppress(queue.Empty): self.data_channel.get(block=False) - except queue.Empty: - pass def close(self, force: bool = True): """Closes the napari process. @@ -208,10 +207,8 @@ def close(self, force: bool = True): """ if self.proc.is_alive() and force: # signal to napari process that it should be closed - try: + with contextlib.suppress(RuntimeError): self.data_channel.put(("close", None)) - except RuntimeError: - pass self.data_channel.close() self.data_channel.join_thread() diff --git a/pde/trackers/trackers.py b/pde/trackers/trackers.py index 16c5be9f..a96d478c 100644 --- a/pde/trackers/trackers.py +++ b/pde/trackers/trackers.py @@ -44,7 +44,7 @@ from .interrupts import InterruptData, RealtimeInterrupts if TYPE_CHECKING: - import pandas + import pandas # noqa: ICN001 class CallbackTracker(TrackerBase): @@ -450,7 +450,7 @@ def initialize(self, state: FieldBase, info: InfoDict | None = None) -> float: else: self._update_method = "replot" - self._logger.info(f'Update method: "{self._update_method}"') + self._logger.info('Update method: "%s"', self._update_method) self._last_update = time.monotonic() return super().initialize(state, info=info) @@ -709,11 +709,13 @@ def to_file(self, filename: str, **kwargs): \**kwargs: Additional parameters may be supported for some formats """ - extension = os.path.splitext(filename)[1].lower() + from pathlib import Path + + extension = Path(filename).suffix.lower() if extension == ".pickle": import pickle - with open(filename, "wb") as fp: + with Path(filename).open("wb") as fp: pickle.dump((self.times, self.data), fp, **kwargs) elif extension == ".csv": diff --git a/pde/visualization/movies.py b/pde/visualization/movies.py index e4e24a22..e7638e88 100644 --- a/pde/visualization/movies.py +++ b/pde/visualization/movies.py @@ -261,13 +261,13 @@ def movie( except AttributeError: fig = ref[0].ax.figure if show_time: - title = fig.suptitle("Time %g" % t) + title = fig.suptitle(f"Time {t:g}") else: # update the data in the figure field._update_plot(ref) if show_time: - title.set_text("Time %g" % t) + title.set_text(f"Time {t:g}") # add the current matplotlib figure to the movie movie.add_figure(fig) diff --git a/pde/visualization/plotting.py b/pde/visualization/plotting.py index 42f09ad1..8bb8aa65 100644 --- a/pde/visualization/plotting.py +++ b/pde/visualization/plotting.py @@ -844,7 +844,7 @@ def plot_interactive( raise RuntimeError("Storage did not contain information about the grid") # collect data from all time points - timecourse: dict[str, list[np.ndarray]] = dict() + timecourse: dict[str, list[np.ndarray]] = {} for field in storage: layer_data = field._get_napari_data(**kwargs) diff --git a/pyproject.toml b/pyproject.toml index b10c6964..fa1b4ef8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,28 @@ exclude = ["scripts/templates"] [tool.ruff.format] docstring-code-format = true +[tool.ruff.lint] +select = [ + "UP", # pyupgrade + "I", # isort + "A", # flake8-builtins + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "FA", # flake8-future-annotations + "ISC", # flake8-implicit-str-concat + "ICN", # flake8-import-conventions + "LOG", # flake8-logging + "G", # flake8-logging-format + "PIE", # flake8-pie + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RSE", # flake8-raise + "RET", # flake8-return + "SIM", # flake8-simplify + "PTH", # flake8-use-pathlib +] +ignore = ["B007", "B027", "B028", "SIM108", "ISC001", "PT006", "PT011", "RET504", "RET505", "RET506"] + [tool.ruff.lint.isort] section-order = ["future", "standard-library", "third-party", "first-party", "self", "local-folder"] diff --git a/scripts/_templates/pyproject.toml b/scripts/_templates/pyproject.toml index 87588e6d..2160e064 100644 --- a/scripts/_templates/pyproject.toml +++ b/scripts/_templates/pyproject.toml @@ -55,12 +55,34 @@ namespaces = false write_to = "pde/_version.py" [tool.ruff] -target-version = "py39" +target-version = "py$MIN_PYTHON_VERSION_NODOT" exclude = ["scripts/templates"] [tool.ruff.format] docstring-code-format = true +[tool.ruff.lint] +select = [ + "UP", # pyupgrade + "I", # isort + "A", # flake8-builtins + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "FA", # flake8-future-annotations + "ISC", # flake8-implicit-str-concat + "ICN", # flake8-import-conventions + "LOG", # flake8-logging + "G", # flake8-logging-format + "PIE", # flake8-pie + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RSE", # flake8-raise + "RET", # flake8-return + "SIM", # flake8-simplify + "PTH", # flake8-use-pathlib +] +ignore = ["B007", "B027", "B028", "SIM108", "ISC001", "PT006", "PT011", "RET504", "RET505", "RET506"] + [tool.ruff.lint.isort] section-order = ["future", "standard-library", "third-party", "first-party", "self", "local-folder"] diff --git a/scripts/create_requirements.py b/scripts/create_requirements.py index 497a89a0..45fef6cc 100755 --- a/scripts/create_requirements.py +++ b/scripts/create_requirements.py @@ -200,7 +200,7 @@ def write_requirements_txt( """ print(f"Write `{path}`") path.parent.mkdir(exist_ok=True, parents=True) # ensure path exists - with open(path, "w") as fp: + with path.open("w") as fp: if comment: fp.write(f"# {comment}\n") if ref_base: @@ -221,7 +221,7 @@ def write_requirements_csv( requirements (list): The requirements to be written """ print(f"Write `{path}`") - with open(path, "w") as fp: + with path.open("w") as fp: writer = csv.writer(fp) if incl_version: writer.writerow(["Package", "Minimal version", "Usage"]) @@ -245,7 +245,7 @@ def write_requirements_py(path: Path, requirements: list[Requirement]): # read user-created content of file content = [] - with open(path) as fp: + with path.open() as fp: for line in fp: if "GENERATED CODE" in line: content.append(line) @@ -260,7 +260,7 @@ def write_requirements_py(path: Path, requirements: list[Requirement]): content.append("del check_package_version\n") # write content back to file - with open(path, "w") as fp: + with path.open("w") as fp: fp.writelines(content) @@ -309,7 +309,7 @@ def write_from_template( content = template.substitute(substitutes) # write content to file - with open(path, "w") as fp: + with path.open("w") as fp: if add_warning: fp.writelines(SETUP_WARNING.format(template_name)) fp.writelines(content) diff --git a/scripts/format_code.sh b/scripts/format_code.sh index d608f699..8608a544 100755 --- a/scripts/format_code.sh +++ b/scripts/format_code.sh @@ -7,7 +7,7 @@ find . -name '*.py' -exec pyupgrade --py39-plus {} + popd > /dev/null echo "Formating import statements..." -ruff check --select I --fix --config=../pyproject.toml .. +ruff check --fix --config=../pyproject.toml .. echo "Formating docstrings..." docformatter --in-place --black --recursive .. diff --git a/scripts/performance_laplace.py b/scripts/performance_laplace.py index 6b0759a2..68201176 100755 --- a/scripts/performance_laplace.py +++ b/scripts/performance_laplace.py @@ -121,7 +121,7 @@ def laplace(arr, out=None): if out is None: out = np.empty((dim_r, dim_z)) - for j in range(0, dim_z): # iterate axial points + for j in range(dim_z): # iterate axial points jm = 0 if j == 0 else j - 1 jp = dim_z - 1 if j == dim_z - 1 else j + 1 diff --git a/scripts/run_tests.py b/scripts/run_tests.py index c4f65a69..4e49ccfb 100755 --- a/scripts/run_tests.py +++ b/scripts/run_tests.py @@ -121,7 +121,7 @@ def run_unit_tests( coverage: bool = False, nojit: bool = False, pattern: str = None, - pytest_args: list[str] = [], + pytest_args: list[str] | None = None, ) -> int: """Run the unit tests. @@ -140,6 +140,8 @@ def run_unit_tests( Returns: int: The return code indicating success or failure """ + if pytest_args is None: + pytest_args = [] # modify current environment env = os.environ.copy() env["PYTHONPATH"] = str(PACKAGE_PATH) + ":" + env.get("PYTHONPATH", "") @@ -177,8 +179,10 @@ def run_unit_tests( if use_mpi: try: import numba_mpi # @UnusedImport - except ImportError: - raise RuntimeError("Moduled `numba_mpi` is required to test with MPI") + except ImportError as err: + raise RuntimeError( + "Moduled `numba_mpi` is required to test with MPI" + ) from err args.append("--use_mpi") # only run tests requiring MPI multiprocessing # run tests using multiple cores? diff --git a/tests/conftest.py b/tests/conftest.py index 775122f0..4f93a2b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,8 @@ from pde.tools.numba import random_seed -@pytest.fixture(scope="function", autouse=True) -def setup_and_teardown(): +@pytest.fixture(autouse=True) +def _setup_and_teardown(): """Helper function adjusting environment before and after tests.""" # ensure we use the Agg backend, so figures are not displayed plt.switch_backend("agg") @@ -26,7 +26,7 @@ def setup_and_teardown(): plt.close("all") -@pytest.fixture(scope="function", autouse=False, name="rng") +@pytest.fixture(autouse=False, name="rng") def init_random_number_generators(): """Get a random number generator and set the seed of the random number generator. diff --git a/tests/fields/test_field_collections.py b/tests/fields/test_field_collections.py index 32c72d84..4ceecfa7 100644 --- a/tests/fields/test_field_collections.py +++ b/tests/fields/test_field_collections.py @@ -131,7 +131,8 @@ def test_collections_append(): c1 = FieldCollection([sf], labels=["scalar"]) c2 = c1.append(vf) - assert len(c2) == 2 and len(c1) == 1 + assert len(c2) == 2 + assert len(c1) == 1 data = np.r_[np.zeros(4), np.ones(8)] np.testing.assert_allclose(c2.data.flat, data) @@ -141,7 +142,8 @@ def test_collections_append(): assert list(c2.labels) == ["scalar", "vector"] c3 = c1.append(c1, label="new") - assert len(c3) == 2 and len(c1) == 1 + assert len(c3) == 2 + assert len(c1) == 1 np.testing.assert_allclose(c3.data.flat, np.zeros(8)) assert c1.data is not c3.data @@ -150,7 +152,8 @@ def test_collections_append(): assert c3.label == "new" c4 = c1.append(c1, vf) - assert len(c4) == 3 and len(c1) == 1 + assert len(c4) == 3 + assert len(c1) == 1 data = np.r_[np.zeros(8), np.ones(8)] np.testing.assert_allclose(c4.data.flat, data) diff --git a/tests/fields/test_generic_fields.py b/tests/fields/test_generic_fields.py index f3858992..9c395f7a 100644 --- a/tests/fields/test_generic_fields.py +++ b/tests/fields/test_generic_fields.py @@ -510,7 +510,7 @@ def test_smoothing(rng): np.testing.assert_allclose(out.data, expected) out.data = 0 # reset data - f1.smooth(sigma, out=out).data + print(f1.smooth(sigma, out=out).data) np.testing.assert_allclose(out.data, expected) # test one simple higher order smoothing diff --git a/tests/fields/test_tensorial_fields.py b/tests/fields/test_tensorial_fields.py index 238ae396..fdc75552 100644 --- a/tests/fields/test_tensorial_fields.py +++ b/tests/fields/test_tensorial_fields.py @@ -188,7 +188,8 @@ def test_complex_tensors(backend, rng): numbers = rng.random(shape) + rng.random(shape) * 1j t1 = Tensor2Field(grid, numbers[0]) t2 = Tensor2Field(grid, numbers[1]) - assert t1.is_complex and t2.is_complex + assert t1.is_complex + assert t2.is_complex dot_op = t1.make_dot_operator(backend) diff --git a/tests/fields/test_vectorial_fields.py b/tests/fields/test_vectorial_fields.py index 8dfd0ac0..1bbfc532 100644 --- a/tests/fields/test_vectorial_fields.py +++ b/tests/fields/test_vectorial_fields.py @@ -260,7 +260,8 @@ def test_complex_vectors(rng): numbers = rng.random(shape) + rng.random(shape) * 1j v1 = VectorField(grid, numbers[0]) v2 = VectorField(grid, numbers[1]) - assert v1.is_complex and v2.is_complex + assert v1.is_complex + assert v2.is_complex for backend in ["numpy", "numba"]: dot_op = v1.make_dot_operator(backend) diff --git a/tests/grids/boundaries/test_axes_boundaries.py b/tests/grids/boundaries/test_axes_boundaries.py index 19ec8693..2a889488 100644 --- a/tests/grids/boundaries/test_axes_boundaries.py +++ b/tests/grids/boundaries/test_axes_boundaries.py @@ -121,7 +121,8 @@ def test_bc_values(): """Test setting the values of boundary conditions.""" g = UnitGrid([5]) bc = g.get_boundary_conditions([{"value": 2}, {"derivative": 3}]) - assert bc[0].low.value == 2 and bc[0].high.value == 3 + assert bc[0].low.value == 2 + assert bc[0].high.value == 3 @pytest.mark.parametrize("dim", [1, 2, 3]) diff --git a/tests/grids/boundaries/test_axis_boundaries.py b/tests/grids/boundaries/test_axis_boundaries.py index 4ead63da..2e18cefc 100644 --- a/tests/grids/boundaries/test_axis_boundaries.py +++ b/tests/grids/boundaries/test_axis_boundaries.py @@ -36,9 +36,11 @@ def test_boundary_pair(): data = {"low": {"value": 1}, "high": {"derivative": 2}} bc1 = BoundaryPair.from_data(g, 0, data) bc2 = BoundaryPair.from_data(g, 0, data) - assert bc1 == bc2 and bc1 is not bc2 + assert bc1 == bc2 + assert bc1 is not bc2 bc2 = BoundaryPair.from_data(g, 1, data) - assert bc1 != bc2 and bc1 is not bc2 + assert bc1 != bc2 + assert bc1 is not bc2 # miscellaneous methods data = {"low": {"value": 0}, "high": {"derivative": 0}} @@ -59,7 +61,8 @@ def test_get_axis_boundaries(): b = get_boundary_axis(g, 0, data) assert str(b) == '"' + data + '"' b1, b2 = b.get_mathematical_representation("field") - assert "field" in b1 and "field" in b2 + assert "field" in b1 + assert "field" in b2 if "periodic" in data: assert b.periodic diff --git a/tests/grids/boundaries/test_local_boundaries.py b/tests/grids/boundaries/test_local_boundaries.py index 315d838b..9cddd945 100644 --- a/tests/grids/boundaries/test_local_boundaries.py +++ b/tests/grids/boundaries/test_local_boundaries.py @@ -343,7 +343,7 @@ def func(adjacent_value, dx, x, y, t): np.testing.assert_almost_equal(f_ref._data_full, f2._data_full) -@pytest.mark.parametrize("value_expr, const_expr", [["1", "1"], ["x", "y**2"]]) +@pytest.mark.parametrize("value_expr, const_expr", [("1", "1"), ("x", "y**2")]) def test_expression_bc_setting_mixed(value_expr, const_expr, rng): """Test boundary conditions that use an expression.""" grid = CartesianGrid([[0, 1], [0, 1]], 4) @@ -529,12 +529,12 @@ def test_expression_bc_specific_value(dim, compiled): if compiled: def set_bcs(): - bcs.make_ghost_cell_setter()(field._data_full) + bcs.make_ghost_cell_setter()(field._data_full) # noqa: B023 else: def set_bcs(): - field.set_ghost_cells(bcs) + field.set_ghost_cells(bcs) # noqa: B023 if i < -n or i > n - 1: # check ut-of-bounds errors diff --git a/tests/grids/operators/test_cylindrical_operators.py b/tests/grids/operators/test_cylindrical_operators.py index 28364671..7ad65b65 100644 --- a/tests/grids/operators/test_cylindrical_operators.py +++ b/tests/grids/operators/test_cylindrical_operators.py @@ -235,7 +235,7 @@ def test_examples_tensor_cyl(): np.testing.assert_allclose(res.data, expect.data, rtol=0.1, atol=0.1) -@pytest.mark.parametrize("r_inner", (0, 1)) +@pytest.mark.parametrize("r_inner", [0, 1]) def test_laplace_matrix(r_inner, rng): """Test laplace operator implemented using matrix multiplication.""" grid = CylindricalSymGrid((r_inner, 2), (2.5, 4.3), 16) @@ -253,7 +253,7 @@ def test_laplace_matrix(r_inner, rng): np.testing.assert_allclose(res1.data, res2) -@pytest.mark.parametrize("r_inner", (0, 1)) +@pytest.mark.parametrize("r_inner", [0, 1]) def test_poisson_solver_cylindrical(r_inner, rng): """Test the poisson solver on Cylindrical grids.""" grid = CylindricalSymGrid((r_inner, 2), (2.5, 4.3), 16) diff --git a/tests/grids/operators/test_polar_operators.py b/tests/grids/operators/test_polar_operators.py index 48ecf1a6..cda01c90 100644 --- a/tests/grids/operators/test_polar_operators.py +++ b/tests/grids/operators/test_polar_operators.py @@ -95,7 +95,7 @@ def test_grid_laplace_polar(): np.testing.assert_allclose(b_1d_2.data[i, i], b_2d.data[i, i], rtol=0.2, atol=0.2) -@pytest.mark.parametrize("r_inner", (0, 2 * np.pi)) +@pytest.mark.parametrize("r_inner", [0, 2 * np.pi]) def test_gradient_squared_polar(r_inner): """Compare gradient squared operator.""" grid = PolarSymGrid((r_inner, 4 * np.pi), 32) @@ -194,7 +194,7 @@ def test_examples_tensor_polar(): np.testing.assert_allclose(res.data, expect.data, rtol=0.1, atol=0.1) -@pytest.mark.parametrize("r_inner", (0, 1)) +@pytest.mark.parametrize("r_inner", [0, 1]) def test_laplace_matrix(r_inner, rng): """Test laplace operator implemented using matrix multiplication.""" grid = PolarSymGrid((r_inner, 2), 16) diff --git a/tests/grids/operators/test_spherical_operators.py b/tests/grids/operators/test_spherical_operators.py index 512d3029..d3ebc625 100644 --- a/tests/grids/operators/test_spherical_operators.py +++ b/tests/grids/operators/test_spherical_operators.py @@ -115,7 +115,7 @@ def test_grid_laplace(): ) -@pytest.mark.parametrize("r_inner", (0, 1)) +@pytest.mark.parametrize("r_inner", [0, 1]) def test_gradient_squared(r_inner, rng): """Compare gradient squared operator.""" grid = SphericalSymGrid((r_inner, 5), 64) @@ -288,7 +288,7 @@ def test_tensor_div_div(conservative): np.testing.assert_allclose(res.data[2:-2], est.data[2:-2], rtol=0.02, atol=1) -@pytest.mark.parametrize("r_inner", (0, 1)) +@pytest.mark.parametrize("r_inner", [0, 1]) def test_laplace_matrix(r_inner, rng): """Test laplace operator implemented using matrix multiplication.""" grid = SphericalSymGrid((r_inner, 2), 16) diff --git a/tests/pdes/test_pde_class.py b/tests/pdes/test_pde_class.py index cf5b050b..4796d26a 100644 --- a/tests/pdes/test_pde_class.py +++ b/tests/pdes/test_pde_class.py @@ -3,7 +3,8 @@ """ import logging - +import os +import numba as nb import numpy as np import pytest import sympy @@ -188,7 +189,6 @@ def test_pde_noise(backend, rng): with pytest.raises(ValueError): eq = PDE({"a": 0}, noise=[0.01, 2.0]) - eq.solve(ScalarField(grid), t_range=1, backend=backend, dt=1, tracker=None) @pytest.mark.parametrize("backend", ["numpy", "numba"]) @@ -207,9 +207,12 @@ def test_pde_spatial_args(backend): # test invalid spatial dependence eq = PDE({"a": "x + y"}) - with pytest.raises(RuntimeError): - rhs = eq.make_pde_rhs(field, backend=backend) - rhs(field.data, 0.0) + if backend == "numpy": + with pytest.raises(RuntimeError): + eq.evolution_rate(field) + elif backend == "numba": + with pytest.raises(RuntimeError): + rhs = eq.make_pde_rhs(field, backend=backend) def test_pde_user_funcs(rng): @@ -284,11 +287,15 @@ def test_pde_consts(): np.testing.assert_allclose(eq.evolution_rate(field).data, 0) eq = PDE({"a": "laplace(b)"}, consts={"b": 3}) - with pytest.raises(Exception): + with pytest.raises( + AttributeError + if os.environ.get("NUMBA_DISABLE_JIT", "0") == "1" + else nb.TypingError + ): eq.evolution_rate(field) eq = PDE({"a": "laplace(b)"}, consts={"b": field.data}) - with pytest.raises(Exception): + with pytest.raises(TypeError): eq.evolution_rate(field) diff --git a/tests/solvers/test_explicit_solvers.py b/tests/solvers/test_explicit_solvers.py index 4b9da38f..822b1f8e 100644 --- a/tests/solvers/test_explicit_solvers.py +++ b/tests/solvers/test_explicit_solvers.py @@ -122,9 +122,9 @@ def test_stochastic_adaptive_solver(caplog, rng): field = ScalarField.random_uniform(UnitGrid([16]), -1, 1, rng=rng) eq = DiffusionPDE(noise=1e-6) + solver = ExplicitSolver(eq, backend="numpy", adaptive=True) + c = Controller(solver, t_range=1, tracker=None) with pytest.raises(RuntimeError): - solver = ExplicitSolver(eq, backend="numpy", adaptive=True) - c = Controller(solver, t_range=1, tracker=None) c.run(field, dt=1e-2) diff --git a/tests/storage/test_movie_storages.py b/tests/storage/test_movie_storages.py index 755271c9..0ea68901 100644 --- a/tests/storage/test_movie_storages.py +++ b/tests/storage/test_movie_storages.py @@ -177,13 +177,15 @@ def test_complex_data(tmp_path, rng): @pytest.mark.skipif(not module_available("ffmpeg"), reason="requires `ffmpeg-python`") def test_wrong_format(): """Test how wrong files are dealt with.""" + from ffmpeg._run import Error as FFmpegError + reader = MovieStorage(RESOURCES_PATH / "does_not_exist.avi") with pytest.raises(OSError): - reader.times + print(reader.times) reader = MovieStorage(RESOURCES_PATH / "empty.avi") - with pytest.raises(Exception): - reader.times + with pytest.raises(FFmpegError): + print(reader.times) reader = MovieStorage(RESOURCES_PATH / "no_metadata.avi") np.testing.assert_allclose(reader.times, [0, 1]) diff --git a/tests/test_examples.py b/tests/test_examples.py index 630fe9f0..8d9cbab0 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -53,14 +53,14 @@ def test_example_scripts(path): # delete files that might be created by the test try: - os.remove(PACKAGE_PATH / "diffusion.mov") - os.remove(PACKAGE_PATH / "allen_cahn.avi") - os.remove(PACKAGE_PATH / "allen_cahn.hdf") + (PACKAGE_PATH / "diffusion.mov").unlink() + (PACKAGE_PATH / "allen_cahn.avi").unlink() + (PACKAGE_PATH / "allen_cahn.hdf").unlink() except OSError: pass # prepare output - msg = "Script `%s` failed with following output:" % path + msg = f"Script `{path}` failed with following output:" if outs: msg = f"{msg}\nSTDOUT:\n{outs}" if errs: @@ -86,7 +86,7 @@ def test_jupyter_notebooks(path, tmp_path): my_env = os.environ.copy() my_env["PYTHONPATH"] = str(PACKAGE_PATH) + ":" + my_env.get("PYTHONPATH", "") - outfile = tmp_path / os.path.basename(path) + outfile = tmp_path / path.name if jupyter_notebook.__version__.startswith("6"): # older version of running jypyter notebook # deprecated on 2023-07-31 diff --git a/tests/test_integration.py b/tests/test_integration.py index 51f751c5..f538d9ce 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -254,7 +254,8 @@ def test_modelrunner_storage_one(tmp_path, capsys): assert output.is_file() print("=" * 40) - print(open(output).read()) + with Path(output).open() as fp: + print(fp.read()) print("=" * 40) # read storage manually @@ -298,9 +299,9 @@ def test_modelrunner_storage_many(tmp_path): for path in tmp_path.iterdir(): if path.is_file() and not path.name.endswith("txt"): with mr.open_storage(path) as storage: - assert "initial_state" in storage["storage"].keys() - assert "trajectory" in storage["storage"].keys() - assert "result" in storage.keys() + assert "initial_state" in storage["storage"] + assert "trajectory" in storage["storage"] + assert "result" in storage # read result using ResultCollection results = mr.ResultCollection.from_folder(tmp_path) diff --git a/tests/tools/test_cache.py b/tests/tools/test_cache.py index b94605fe..4e961651 100644 --- a/tests/tools/test_cache.py +++ b/tests/tools/test_cache.py @@ -365,7 +365,7 @@ def cached_kwarg(self, a=0, b=0): assert method(1) == 1 assert obj.counter == 3 else: - raise ValueError("Unknown cache_factory `%s`" % cache_factory) + raise ValueError(f"Unknown cache_factory `{cache_factory}`") obj.counter = 0 # clear cache to test the second run diff --git a/tests/tools/test_config.py b/tests/tools/test_config.py index ac9c01fe..9cd7a94e 100644 --- a/tests/tools/test_config.py +++ b/tests/tools/test_config.py @@ -19,8 +19,8 @@ def test_config(): assert c["numba.multithreading_threshold"] > 0 assert "numba.multithreading_threshold" in c - assert any("numba.multithreading_threshold" == k for k in c) - assert any("numba.multithreading_threshold" == k and v > 0 for k, v in c.items()) + assert any(k == "numba.multithreading_threshold" for k in c) + assert any(k == "numba.multithreading_threshold" and v > 0 for k, v in c.items()) assert "numba.multithreading_threshold" in c.to_dict() assert isinstance(repr(c), str) @@ -90,4 +90,5 @@ def test_packages_from_requirements(): """Test the packages_from_requirements function.""" results = packages_from_requirements("file_not_existing") assert len(results) == 1 - assert "Could not open" in results[0] and "file_not_existing" in results[0] + assert "Could not open" in results[0] + assert "file_not_existing" in results[0] diff --git a/tests/tools/test_cuboid.py b/tests/tools/test_cuboid.py index b7b4f004..19357c59 100644 --- a/tests/tools/test_cuboid.py +++ b/tests/tools/test_cuboid.py @@ -58,8 +58,8 @@ def test_cuboid_2d(): np.testing.assert_array_equal(c.contains_point([[1, 3], [3, 1]]), [False, False]) np.testing.assert_array_equal(c.contains_point([[1, -1], [-1, 1]]), [False, False]) + c.mutable = False with pytest.raises(ValueError): - c.mutable = False c.centroid = [0, 0] # test surface area diff --git a/tests/tools/test_expressions.py b/tests/tools/test_expressions.py index ba583eb3..7d95d0b2 100644 --- a/tests/tools/test_expressions.py +++ b/tests/tools/test_expressions.py @@ -60,7 +60,7 @@ def test_const(expr): assert e.get_compiled()() == val assert not e.depends_on("a") assert e.differentiate("a").value == 0 - assert e.shape == tuple() + assert e.shape == () assert e.rank == 0 assert bool(e) == (val != 0) assert e.is_zero == (val == 0) @@ -89,7 +89,7 @@ def test_wrong_const(caplog): assert e() == field assert "field" in caplog.text if not nb.config.DISABLE_JIT: # @UndefinedVariable - with pytest.raises(Exception): + with pytest.raises(nb.TypingError): e.get_compiled()() @@ -102,14 +102,14 @@ def test_single_arg(rng): assert e.get_compiled()(4) == 8 assert e.differentiate("a").value == 2 assert e.differentiate("b").value == 0 - assert e.shape == tuple() + assert e.shape == () assert e.rank == 0 assert bool(e) assert not e.is_zero assert e == ScalarExpression(e.expression) with pytest.raises(TypeError): - e.value + print(e.value) arr = rng.random(5) np.testing.assert_allclose(e(arr), 2 * arr) @@ -135,7 +135,7 @@ def test_two_args(rng): assert e.differentiate("a")(4, 2) == 16 assert e.differentiate("b")(4, 2) == pytest.approx(32 * np.log(4)) assert e.differentiate("c").value == 0 - assert e.shape == tuple() + assert e.shape == () assert e.rank == 0 assert e == ScalarExpression(e.expression) @@ -160,7 +160,8 @@ def test_two_args(rng): def test_derivatives(): """Test vector expressions.""" e = ScalarExpression("a * b**2") - assert e.depends_on("a") and e.depends_on("b") + assert e.depends_on("a") + assert e.depends_on("b") assert not e.constant assert e.rank == 0 @@ -199,7 +200,7 @@ def test_indexed(): with pytest.raises(RuntimeError): e.differentiate("a") with pytest.raises(RuntimeError): - e.derivatives + print(e.derivatives) def test_synonyms(caplog): @@ -217,7 +218,7 @@ def test_tensor_expression(): assert e.rank == 2 assert e.constant np.testing.assert_allclose(e.get_compiled_array()(), [[0, 1], [2, 3]]) - np.testing.assert_allclose(e.get_compiled_array()(tuple()), [[0, 1], [2, 3]]) + np.testing.assert_allclose(e.get_compiled_array()(()), [[0, 1], [2, 3]]) assert e.differentiate("a") == TensorExpression("[[0, 0], [0, 0]]") np.testing.assert_allclose(e.value, np.arange(4).reshape(2, 2)) @@ -229,7 +230,7 @@ def test_tensor_expression(): assert not e.constant np.testing.assert_allclose(e.differentiate("a").value, np.array([1, 2])) with pytest.raises(TypeError): - e.value + print(e.value) assert e[0] == ScalarExpression("a") assert e[1] == ScalarExpression("2*a") assert e[0:1] == TensorExpression("[a]") @@ -331,7 +332,8 @@ def test_expression_consts(): expr = ScalarExpression("a + b", consts={"a": 1}) assert not expr.constant - assert not expr.depends_on("a") and expr.depends_on("b") + assert not expr.depends_on("a") + assert expr.depends_on("b") assert expr(2) == 3 assert expr.get_compiled()(2) == 3 @@ -433,8 +435,8 @@ def test_evaluate_func_collection(): assert isinstance(evaluate("gradient(a)", col), VectorField) assert isinstance(evaluate("a * v", col), VectorField) + col.labels = ["a", "a"] with pytest.raises(RuntimeError): - col.labels = ["a", "a"] evaluate("1", col) diff --git a/tests/tools/test_misc.py b/tests/tools/test_misc.py index 2f8fefd6..be655174 100644 --- a/tests/tools/test_misc.py +++ b/tests/tools/test_misc.py @@ -4,6 +4,7 @@ import json import os +from pathlib import Path import numpy as np import pytest @@ -23,7 +24,7 @@ def test_ensure_directory_exists(tmp_path): misc.ensure_directory_exists(path) assert path.is_dir() # remove the folder again - os.rmdir(path) + Path.rmdir(path) assert not path.exists() @@ -105,8 +106,7 @@ def test_hdf_write_attributes(tmp_path): assert data2 == {"a": 1} # test raising problematic items - with h5py.File(path, "w") as hdf_file: - with pytest.raises(TypeError): - misc.hdf_write_attributes( - hdf_file, {"a": object()}, raise_serialization_error=True - ) + with h5py.File(path, "w") as hdf_file, pytest.raises(TypeError): + misc.hdf_write_attributes( + hdf_file, {"a": object()}, raise_serialization_error=True + ) diff --git a/tests/tools/test_parameters.py b/tests/tools/test_parameters.py index 9647cf12..837cca76 100644 --- a/tests/tools/test_parameters.py +++ b/tests/tools/test_parameters.py @@ -108,7 +108,10 @@ class TestHelp2(TestHelp1): t = TestHelp2() for in_jupyter in [False, True]: - monkeypatch.setattr("pde.tools.output.in_jupyter_notebook", lambda: in_jupyter) + monkeypatch.setattr( + "pde.tools.output.in_jupyter_notebook", + lambda: in_jupyter, # noqa: B023 + ) for flags in itertools.combinations_with_replacement([True, False], 3): TestHelp2.show_parameters(*flags) diff --git a/tests/tools/test_spectral.py b/tests/tools/test_spectral.py index a96f07dd..b964c91e 100644 --- a/tests/tools/test_spectral.py +++ b/tests/tools/test_spectral.py @@ -67,18 +67,17 @@ def test_noise_scaling(rng): def noise_div(): return div(rng.normal(size=shape)) + def get_noise(noise_func): + k, density = spectral_density(data=noise_func(), dx=grid.discretization[0]) + assert k[0] == 0 + assert density[0] == pytest.approx(0) + return np.log(density[1]) # log of spectral density + # calculate spectral densities of the two noises result = [] for noise_func in [noise_colored, noise_div]: - - def get_noise(): - k, density = spectral_density(data=noise_func(), dx=grid.discretization[0]) - assert k[0] == 0 - assert density[0] == pytest.approx(0) - return np.log(density[1]) # log of spectral density - # average spectral density of longest length scale - mean = np.mean([get_noise() for _ in range(64)]) + mean = np.mean([get_noise(noise_func=noise_func) for _ in range(64)]) result.append(mean) np.testing.assert_allclose(*result, rtol=0.5) diff --git a/tests/trackers/test_trackers.py b/tests/trackers/test_trackers.py index de4a7920..aec51478 100644 --- a/tests/trackers/test_trackers.py +++ b/tests/trackers/test_trackers.py @@ -76,23 +76,21 @@ def store_time(state, t): def get_sparse_matrix_data(state): return {"integral": state.integral} - devnull = open(os.devnull, "w") - data = trackers.DataTracker(get_sparse_matrix_data, interrupts=0.1) - tracker_list = [ - trackers.PrintTracker(interrupts=0.1, stream=devnull), - trackers.CallbackTracker(store_time, interrupts=0.1), - None, # should be ignored - data, - ] - if module_available("matplotlib"): - tracker_list.append(trackers.PlotTracker(interrupts=0.1, show=False)) - - grid = UnitGrid([16, 16]) - state = ScalarField.random_uniform(grid, 0.2, 0.3, rng=rng) - eq = DiffusionPDE() - eq.solve(state, t_range=1, dt=0.005, tracker=tracker_list) - - devnull.close() + with open(os.devnull, "w") as devnull: # noqa: PTH123 + data = trackers.DataTracker(get_sparse_matrix_data, interrupts=0.1) + tracker_list = [ + trackers.PrintTracker(interrupts=0.1, stream=devnull), + trackers.CallbackTracker(store_time, interrupts=0.1), + None, # should be ignored + data, + ] + if module_available("matplotlib"): + tracker_list.append(trackers.PlotTracker(interrupts=0.1, show=False)) + + grid = UnitGrid([16, 16]) + state = ScalarField.random_uniform(grid, 0.2, 0.3, rng=rng) + eq = DiffusionPDE() + eq.solve(state, t_range=1, dt=0.005, tracker=tracker_list) assert times == data.times if module_available("pandas"):