diff --git a/pde/solvers/base.py b/pde/solvers/base.py index 3acef146..5731542a 100644 --- a/pde/solvers/base.py +++ b/pde/solvers/base.py @@ -393,7 +393,7 @@ def __init__( """ Args: pde (:class:`~pde.pdes.base.PDEBase`): - The instance describing the pde that needs to be solved + The partial differential equation that should be solved backend (str): Determines how the function is created. Accepted values are 'numpy` and 'numba'. Alternatively, 'auto' lets the code decide for the most optimal @@ -411,8 +411,7 @@ def __init__( self.tolerance = tolerance def _make_error_synchronizer(self) -> Callable[[float], float]: - """Return helper function that synchronizes errors between multiple - processes.""" + """Return function that synchronizes errors between multiple processes.""" @register_jitable def synchronize_errors(error: float) -> float: @@ -578,7 +577,7 @@ def adaptive_stepper( new_state, error = single_step_error(state_data, t, dt_step) error_rel = error / tolerance # normalize error to given tolerance - # synchronize the error between all processes (if necessary) + # synchronize the error between all processes (necessary for MPI) error_rel = sync_errors(error_rel) # do the step if the error is sufficiently small @@ -592,7 +591,7 @@ def adaptive_stepper( dt_stats.add(dt_step) if t < t_end: - # adjust the time step and continue + # adjust the time step and continue (happens in every MPI process) dt_opt = adjust_dt(dt_step, error_rel) else: break # return to the controller diff --git a/pde/solvers/crank_nicolson.py b/pde/solvers/crank_nicolson.py index acbee2a3..d12b6ec0 100644 --- a/pde/solvers/crank_nicolson.py +++ b/pde/solvers/crank_nicolson.py @@ -33,7 +33,7 @@ def __init__( """ Args: pde (:class:`~pde.pdes.base.PDEBase`): - The instance describing the pde that needs to be solved + The partial differential equation that should be solved maxiter (int): The maximal number of iterations per step maxerror (float): diff --git a/pde/solvers/explicit.py b/pde/solvers/explicit.py index 029d8588..a2471e0c 100644 --- a/pde/solvers/explicit.py +++ b/pde/solvers/explicit.py @@ -35,7 +35,7 @@ def __init__( """ Args: pde (:class:`~pde.pdes.base.PDEBase`): - The instance describing the pde that needs to be solved + The partial differential equation that should be solved scheme (str): Defines the explicit scheme to use. Supported values are 'euler' and 'runge-kutta' (or 'rk' for short). @@ -225,7 +225,7 @@ def adaptive_stepper( error = np.abs(step_large - step_small).max() error_rel = error / tolerance # normalize error to given tolerance - # synchronize the error between all processes (if necessary) + # synchronize the error between all processes (necessary for MPI) error_rel = sync_errors(error_rel) if error_rel <= 1: # error is sufficiently small diff --git a/pde/solvers/explicit_mpi.py b/pde/solvers/explicit_mpi.py index 9751de50..ac792be9 100644 --- a/pde/solvers/explicit_mpi.py +++ b/pde/solvers/explicit_mpi.py @@ -89,7 +89,7 @@ def __init__( """ Args: pde (:class:`~pde.pdes.base.PDEBase`): - The instance describing the pde that needs to be solved + The partial differential equation that should be solved scheme (str): Defines the explicit scheme to use. Supported values are 'euler' and 'runge-kutta' (or 'rk' for short). @@ -117,22 +117,21 @@ def __init__( self.decomposition = decomposition def _make_error_synchronizer(self) -> Callable[[float], float]: - """Return helper function that synchronizes errors between multiple - processes.""" - if mpi.parallel_run: - # in a parallel run, we need to return the maximal error - from ..tools.mpi import Operator, mpi_allreduce + """Return function that synchronizes errors between multiple processes.""" + # if mpi.parallel_run: + # in a parallel run, we need to return the maximal error + from ..tools.mpi import Operator, mpi_allreduce - operator_max_id = Operator.MAX + operator_max_id = Operator.MAX - @register_jitable - def synchronize_errors(error: float) -> float: - """Return maximal error accross all cores.""" - return mpi_allreduce(error, operator_max_id) # type: ignore + @register_jitable + def synchronize_errors(error: float) -> float: + """Return maximal error accross all cores.""" + return mpi_allreduce(error, operator_max_id) # type: ignore - return synchronize_errors # type: ignore - else: - return super()._make_error_synchronizer() + return synchronize_errors # type: ignore + # else: + # return super()._make_error_synchronizer() def make_stepper( self, state: FieldBase, dt=None @@ -200,7 +199,10 @@ def wrapped_stepper( t_end = self.mesh.broadcast(t_end) substate_data = self.mesh.split_field_data_mpi(state.data) - # evolve the sub state + # Evolve the sub-state on each individual node. The nodes synchronize + # field data via special boundary conditions and they synchronize the + # maximal error via the error synchronizer. Apart from that, all nodes + # work independently. t_last, dt, steps = adaptive_stepper( substate_data, t_start, @@ -244,7 +246,9 @@ def wrapped_stepper( steps = self.mesh.broadcast(steps) substate_data = self.mesh.split_field_data_mpi(state.data) - # evolve the sub state + # Evolve the sub-state on each individual node. The nodes synchronize + # field data via special boundary conditions. Apart from that, all nodes + # work independently. t_last = fixed_stepper(substate_data, t_start, steps, post_step_data) # check whether t_last is the same for all processes diff --git a/pde/solvers/implicit.py b/pde/solvers/implicit.py index 364e493b..e5d2a353 100644 --- a/pde/solvers/implicit.py +++ b/pde/solvers/implicit.py @@ -24,6 +24,7 @@ class ImplicitSolver(SolverBase): def __init__( self, pde: PDEBase, + *, maxiter: int = 100, maxerror: float = 1e-4, backend: BackendType = "auto", @@ -31,7 +32,7 @@ def __init__( """ Args: pde (:class:`~pde.pdes.base.PDEBase`): - The instance describing the pde that needs to be solved + The partial differential equation that should be solved maxiter (int): The maximal number of iterations per step maxerror (float): diff --git a/pde/solvers/scipy.py b/pde/solvers/scipy.py index 60921a11..579845f1 100644 --- a/pde/solvers/scipy.py +++ b/pde/solvers/scipy.py @@ -25,11 +25,11 @@ class ScipySolver(SolverBase): name = "scipy" - def __init__(self, pde: PDEBase, backend: BackendType = "auto", **kwargs): + def __init__(self, pde: PDEBase, *, backend: BackendType = "auto", **kwargs): r""" Args: pde (:class:`~pde.pdes.base.PDEBase`): - The instance describing the pde that needs to be solved + The partial differential equation that should be solved backend (str): Determines how the function is created. Accepted values are 'numpy` and 'numba'. Alternatively, 'auto' lets the code decide diff --git a/pde/tools/expressions.py b/pde/tools/expressions.py index 88b6fa26..411c9242 100644 --- a/pde/tools/expressions.py +++ b/pde/tools/expressions.py @@ -131,7 +131,7 @@ def np_heaviside(x1, x2): # special functions that we want to support in expressions but that are not defined by # sympy version 1.6 or have a different signature than expected by numba/numpy -SPECIAL_FUNCTIONS = {"Heaviside": _heaviside_implemention} +SPECIAL_FUNCTIONS = {"Heaviside": _heaviside_implemention, "hypot": np.hypot} class ListArrayPrinter(PythonCodePrinter): diff --git a/pde/tools/misc.py b/pde/tools/misc.py index e405ded5..117af350 100644 --- a/pde/tools/misc.py +++ b/pde/tools/misc.py @@ -126,37 +126,6 @@ def new_decorator(*args, **kwargs): return new_decorator -def skipUnlessModule(module_names: str | Sequence[str]) -> Callable[[TFunc], TFunc]: - """Decorator that skips a test when a module is not available. - - Args: - module_names (str): The name of the required module(s) - - Returns: - A function, so this can be used as a decorator - """ - # deprecated since 2024-01-03 - warnings.warn( - "`skipUnlessModule` is deprecated. Use " - '`@pytest.mark.skipif(not module_available("module"))` instead.', - DeprecationWarning, - ) - - if isinstance(module_names, str): - module_names = [module_names] - - for module_name in module_names: - if not module_available(module_name): - # return decorator skipping test - return unittest.skip(f"requires {module_name}") - - # return no-op decorator if all modules are available - def wrapper(f: TFunc) -> TFunc: - return f - - return wrapper - - def import_class(identifier: str): """Import a class or module given an identifier. diff --git a/pde/visualization/movies.py b/pde/visualization/movies.py index f6db4850..e4e24a22 100644 --- a/pde/visualization/movies.py +++ b/pde/visualization/movies.py @@ -57,7 +57,8 @@ def __init__( The number of frames per second, which determines how fast the movie will appear to run. dpi (float): - The resolution of the resulting movie + The resolution of the resulting movie. The default value is controlled + by :mod:`matplotlib` and is usally set to 100. \**kwargs: Additional parameters are used to initialize :class:`matplotlib.animation.FFMpegWriter`. Here, we can for instance diff --git a/tests/solvers/test_explicit_mpi_solvers.py b/tests/solvers/test_explicit_mpi_solvers.py index 67918603..24bd6126 100644 --- a/tests/solvers/test_explicit_mpi_solvers.py +++ b/tests/solvers/test_explicit_mpi_solvers.py @@ -5,17 +5,22 @@ import numpy as np import pytest -from pde import DiffusionPDE, ScalarField, UnitGrid +from pde import DiffusionPDE, ScalarField, UnitGrid, PDE, FieldCollection from pde.solvers import Controller, ExplicitMPISolver from pde.tools import mpi @pytest.mark.multiprocessing +@pytest.mark.parametrize("backend", ["numpy", "numba"]) @pytest.mark.parametrize( - "scheme, decomposition", - [("euler", "auto"), ("euler", [1, -1]), ("runge-kutta", [-1, 1])], + "scheme, adaptive, decomposition", + [ + ("euler", False, "auto"), + ("euler", True, [1, -1]), + ("runge-kutta", False, [-1, 1]), + ], ) -def test_simple_pde_mpi(scheme, decomposition, rng): +def test_simple_pde_mpi(backend, scheme, adaptive, decomposition, rng): """Test setting boundary conditions using numba.""" grid = UnitGrid([8, 8], periodic=[True, False]) @@ -26,32 +31,28 @@ def test_simple_pde_mpi(scheme, decomposition, rng): "state": field, "t_range": 1.01, "dt": 0.1, + "adaptive": adaptive, "scheme": scheme, "tracker": None, "ret_info": True, } - res1, info1 = eq.solve( - backend="numpy", solver="explicit_mpi", decomposition=decomposition, **args - ) - res2, info2 = eq.solve( - backend="numba", solver="explicit_mpi", decomposition=decomposition, **args + res_mpi, info_mpi = eq.solve( + backend=backend, solver="explicit_mpi", decomposition=decomposition, **args ) if mpi.is_main: # check results in the main process - expect, _ = eq.solve(backend="numpy", solver="explicit", **args) - np.testing.assert_allclose(res1.data, expect.data) - np.testing.assert_allclose(res2.data, expect.data) + expect, info2 = eq.solve(backend="numpy", solver="explicit", **args) + np.testing.assert_allclose(res_mpi.data, expect.data) - for info in [info1, info2]: - assert info["solver"]["steps"] == 11 - assert info["solver"]["use_mpi"] - if decomposition != "auto": - for i in range(2): - if decomposition[i] == 1: - assert info["solver"]["grid_decomposition"][i] == 1 - else: - assert info["solver"]["grid_decomposition"][i] == mpi.size + assert info_mpi["solver"]["steps"] == info2["solver"]["steps"] + assert info_mpi["solver"]["use_mpi"] + if decomposition != "auto": + for i in range(2): + if decomposition[i] == 1: + assert info_mpi["solver"]["grid_decomposition"][i] == 1 + else: + assert info_mpi["solver"]["grid_decomposition"][i] == mpi.size @pytest.mark.multiprocessing @@ -77,3 +78,36 @@ def test_stochastic_mpi_solvers(backend, rng): assert not solver1.info["dt_adaptive"] assert not solver2.info["dt_adaptive"] + + +@pytest.mark.multiprocessing +@pytest.mark.parametrize("backend", ["numpy", "numba"]) +def test_multiple_pdes_mpi(backend, rng): + """Test setting boundary conditions using numba.""" + grid = UnitGrid([8, 8], periodic=[True, False]) + + fields = FieldCollection.scalar_random_uniform(2, grid, rng=rng) + eq = PDE({"a": "laplace(a) - b", "b": "laplace(b) + a"}) + + args = { + "state": fields, + "t_range": 1.01, + "dt": 0.1, + "adaptive": True, + "scheme": "euler", + "tracker": None, + "ret_info": True, + } + res_mpi, info_mpi = eq.solve(backend=backend, solver="explicit_mpi", **args) + + if mpi.is_main: + # check results in the main process + expect, info2 = eq.solve(backend="numpy", solver="explicit", **args) + np.testing.assert_allclose(res_mpi.data, expect.data) + + assert info_mpi["solver"]["steps"] == info2["solver"]["steps"] + assert info_mpi["solver"]["use_mpi"] + from pprint import pprint + + pprint(info_mpi) + print(info_mpi) diff --git a/tests/tools/test_misc.py b/tests/tools/test_misc.py index 63859518..2f8fefd6 100644 --- a/tests/tools/test_misc.py +++ b/tests/tools/test_misc.py @@ -110,9 +110,3 @@ def test_hdf_write_attributes(tmp_path): misc.hdf_write_attributes( hdf_file, {"a": object()}, raise_serialization_error=True ) - - -@misc.skipUnlessModule("undefined_module_name") -def test_skipUnlessModule(): - """Test skipUnlessModule decorator.""" - raise RuntimeError # test should never run