Add tests for adaptive MPI simulation of multiple PDEs (#588)
* Improved documentation for MPI simulations
* Removed deprecated skipUnlessModule
david-zwicker authored Aug 8, 2024
1 parent e9cbc35 commit 05196a5
Showing 11 changed files with 89 additions and 87 deletions.
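For orientation, the feature these changes test — adaptive time stepping with the MPI-parallel explicit solver for a collection of coupled PDEs — is used roughly as sketched below. The sketch mirrors the new test code further down; the grid size, time range, and launch command are illustrative, not prescriptive.

```python
# Minimal sketch of an adaptive MPI run with coupled PDEs (mirrors the new
# tests below). Launch with several processes, e.g. `mpiexec -n 2 python run.py`.
from pde import PDE, FieldCollection, UnitGrid
from pde.tools import mpi

grid = UnitGrid([8, 8], periodic=[True, False])
fields = FieldCollection.scalar_random_uniform(2, grid)
eq = PDE({"a": "laplace(a) - b", "b": "laplace(b) + a"})

# "explicit_mpi" distributes the grid over the MPI processes; with
# adaptive=True, the local error estimates are synchronized so that all
# ranks agree on the same time step.
result, info = eq.solve(
    fields,
    t_range=1.01,
    dt=0.1,
    adaptive=True,
    solver="explicit_mpi",
    tracker=None,
    ret_info=True,
)

if mpi.is_main:  # the collected result is only meaningful on the main process
    print(info["solver"]["steps"])
```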
9 changes: 4 additions & 5 deletions pde/solvers/base.py
@@ -393,7 +393,7 @@ def __init__(
"""
Args:
pde (:class:`~pde.pdes.base.PDEBase`):
The instance describing the pde that needs to be solved
The partial differential equation that should be solved
backend (str):
Determines how the function is created. Accepted values are 'numpy' and
'numba'. Alternatively, 'auto' lets the code decide for the most optimal
@@ -411,8 +411,7 @@ def __init__(
self.tolerance = tolerance

def _make_error_synchronizer(self) -> Callable[[float], float]:
"""Return helper function that synchronizes errors between multiple
processes."""
"""Return function that synchronizes errors between multiple processes."""

@register_jitable
def synchronize_errors(error: float) -> float:
@@ -578,7 +577,7 @@ def adaptive_stepper(
new_state, error = single_step_error(state_data, t, dt_step)

error_rel = error / tolerance # normalize error to given tolerance
# synchronize the error between all processes (if necessary)
# synchronize the error between all processes (necessary for MPI)
error_rel = sync_errors(error_rel)

# do the step if the error is sufficiently small
@@ -592,7 +591,7 @@ def adaptive_stepper(
dt_stats.add(dt_step)

if t < t_end:
# adjust the time step and continue
# adjust the time step and continue (happens in every MPI process)
dt_opt = adjust_dt(dt_step, error_rel)
else:
break # return to the controller
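The comments added in this hunk spell out why the relative error has to be synchronized: in an MPI run every process only sees its part of the grid, so each rank's local error must be reduced to a global maximum before the accept/reject decision, otherwise the ranks would choose different time steps. A minimal sketch of that reduction, written against mpi4py rather than py-pde's internal `mpi_allreduce` helper:

```python
# Sketch of the error synchronization performed by the adaptive stepper:
# every rank contributes its local error estimate and receives the global
# maximum, so all ranks accept or reject the step consistently.
from mpi4py import MPI

def synchronize_error(local_error: float) -> float:
    """Return the maximal error across all MPI processes."""
    return MPI.COMM_WORLD.allreduce(local_error, op=MPI.MAX)
```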
2 changes: 1 addition & 1 deletion pde/solvers/crank_nicolson.py
@@ -33,7 +33,7 @@ def __init__(
"""
Args:
pde (:class:`~pde.pdes.base.PDEBase`):
The instance describing the pde that needs to be solved
The partial differential equation that should be solved
maxiter (int):
The maximal number of iterations per step
maxerror (float):
4 changes: 2 additions & 2 deletions pde/solvers/explicit.py
@@ -35,7 +35,7 @@ def __init__(
"""
Args:
pde (:class:`~pde.pdes.base.PDEBase`):
The instance describing the pde that needs to be solved
The partial differential equation that should be solved
scheme (str):
Defines the explicit scheme to use. Supported values are 'euler' and
'runge-kutta' (or 'rk' for short).
@@ -225,7 +225,7 @@ def adaptive_stepper(
error = np.abs(step_large - step_small).max()
error_rel = error / tolerance # normalize error to given tolerance

# synchronize the error between all processes (if necessary)
# synchronize the error between all processes (necessary for MPI)
error_rel = sync_errors(error_rel)

if error_rel <= 1: # error is sufficiently small
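For readers following the adaptive stepper: the hunk above estimates the local error from the discrepancy between two step estimates of different accuracy (`step_large` and `step_small`) and normalizes it by the tolerance, so `error_rel <= 1` means "accept". The sketch below shows that accept/reject pattern with a common textbook step-size update; the exponent and safety factor are generic choices, not necessarily the exact formula py-pde uses.

```python
import numpy as np

def try_adaptive_step(step_large, step_small, dt, tolerance, safety=0.9):
    """Generic accept/reject logic for one adaptive step (illustrative)."""
    error = np.abs(step_large - step_small).max()  # worst-case local error
    error_rel = max(error / tolerance, 1e-15)      # normalize; avoid 0**negative
    accept = error_rel <= 1                        # small enough -> keep the step
    # shrink dt for large errors, grow it moderately for small ones
    dt_new = dt * min(2.0, max(0.1, safety * error_rel ** -0.2))
    return accept, dt_new
```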
36 changes: 20 additions & 16 deletions pde/solvers/explicit_mpi.py
@@ -89,7 +89,7 @@ def __init__(
"""
Args:
pde (:class:`~pde.pdes.base.PDEBase`):
The instance describing the pde that needs to be solved
The partial differential equation that should be solved
scheme (str):
Defines the explicit scheme to use. Supported values are 'euler' and
'runge-kutta' (or 'rk' for short).
@@ -117,22 +117,21 @@ def __init__(
self.decomposition = decomposition

def _make_error_synchronizer(self) -> Callable[[float], float]:
"""Return helper function that synchronizes errors between multiple
processes."""
if mpi.parallel_run:
# in a parallel run, we need to return the maximal error
from ..tools.mpi import Operator, mpi_allreduce
"""Return function that synchronizes errors between multiple processes."""
# if mpi.parallel_run:
# in a parallel run, we need to return the maximal error
from ..tools.mpi import Operator, mpi_allreduce

operator_max_id = Operator.MAX
operator_max_id = Operator.MAX

@register_jitable
def synchronize_errors(error: float) -> float:
"""Return maximal error accross all cores."""
return mpi_allreduce(error, operator_max_id) # type: ignore
@register_jitable
def synchronize_errors(error: float) -> float:
"""Return maximal error accross all cores."""
return mpi_allreduce(error, operator_max_id) # type: ignore

return synchronize_errors # type: ignore
else:
return super()._make_error_synchronizer()
return synchronize_errors # type: ignore
# else:
# return super()._make_error_synchronizer()

def make_stepper(
self, state: FieldBase, dt=None
@@ -200,7 +199,10 @@ def wrapped_stepper(
t_end = self.mesh.broadcast(t_end)
substate_data = self.mesh.split_field_data_mpi(state.data)

# evolve the sub state
# Evolve the sub-state on each individual node. The nodes synchronize
# field data via special boundary conditions and they synchronize the
# maximal error via the error synchronizer. Apart from that, all nodes
# work independently.
t_last, dt, steps = adaptive_stepper(
substate_data,
t_start,
@@ -244,7 +246,9 @@ def wrapped_stepper(
steps = self.mesh.broadcast(steps)
substate_data = self.mesh.split_field_data_mpi(state.data)

# evolve the sub state
# Evolve the sub-state on each individual node. The nodes synchronize
# field data via special boundary conditions. Apart from that, all nodes
# work independently.
t_last = fixed_stepper(substate_data, t_start, steps, post_step_data)

# check whether t_last is the same for all processes
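The `decomposition` argument handled by this solver (and exercised in the updated tests below) controls how the grid is split over the MPI processes: for a 2D grid, `[1, -1]` keeps the first axis on every rank and distributes the second one over all available ranks. A hedged usage sketch, reusing the values from the tests:

```python
# Sketch: requesting an explicit grid decomposition for the MPI solver.
# With two processes, decomposition=[1, -1] keeps axis 0 intact and splits
# axis 1, so info["solver"]["grid_decomposition"] ends up as [1, 2].
from pde import DiffusionPDE, ScalarField, UnitGrid

grid = UnitGrid([8, 8], periodic=[True, False])
field = ScalarField.random_uniform(grid)
eq = DiffusionPDE()

result, info = eq.solve(
    field,
    t_range=1.01,
    dt=0.1,
    adaptive=True,
    solver="explicit_mpi",
    decomposition=[1, -1],
    ret_info=True,
)
```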
3 changes: 2 additions & 1 deletion pde/solvers/implicit.py
@@ -24,14 +24,15 @@ class ImplicitSolver(SolverBase):
def __init__(
self,
pde: PDEBase,
*,
maxiter: int = 100,
maxerror: float = 1e-4,
backend: BackendType = "auto",
):
"""
Args:
pde (:class:`~pde.pdes.base.PDEBase`):
The instance describing the pde that needs to be solved
The partial differential equation that should be solved
maxiter (int):
The maximal number of iterations per step
maxerror (float):
4 changes: 2 additions & 2 deletions pde/solvers/scipy.py
@@ -25,11 +25,11 @@ class ScipySolver(SolverBase):

name = "scipy"

def __init__(self, pde: PDEBase, backend: BackendType = "auto", **kwargs):
def __init__(self, pde: PDEBase, *, backend: BackendType = "auto", **kwargs):
r"""
Args:
pde (:class:`~pde.pdes.base.PDEBase`):
The instance describing the pde that needs to be solved
The partial differential equation that should be solved
backend (str):
Determines how the function is created. Accepted values are
'numpy' and 'numba'. Alternatively, 'auto' lets the code decide
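Both this file and `pde/solvers/implicit.py` above add a bare `*` to the constructor signature, which makes the remaining parameters keyword-only. Existing code that passed them positionally needs to switch to keywords; the call sites below are hypothetical illustrations of the new form:

```python
from pde import DiffusionPDE
from pde.solvers import ImplicitSolver, ScipySolver

eq = DiffusionPDE()

# keyword-only after this change; ImplicitSolver(eq, 200) would raise a TypeError
implicit = ImplicitSolver(eq, maxiter=200, maxerror=1e-5)
scipy_solver = ScipySolver(eq, backend="numpy")
```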
2 changes: 1 addition & 1 deletion pde/tools/expressions.py
@@ -131,7 +131,7 @@ def np_heaviside(x1, x2):

# special functions that we want to support in expressions but that are not defined by
# sympy version 1.6 or have a different signature than expected by numba/numpy
SPECIAL_FUNCTIONS = {"Heaviside": _heaviside_implemention}
SPECIAL_FUNCTIONS = {"Heaviside": _heaviside_implemention, "hypot": np.hypot}


class ListArrayPrinter(PythonCodePrinter):
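Registering `np.hypot` in `SPECIAL_FUNCTIONS` should make `hypot` available inside string expressions parsed by this module. A small illustrative sketch (the grid and expression are arbitrary):

```python
# Sketch: using the newly registered hypot() in an expression-based field;
# hypot(x, y) is the Euclidean distance sqrt(x**2 + y**2) from the origin.
from pde import CartesianGrid, ScalarField

grid = CartesianGrid([[-1, 1], [-1, 1]], 32)
field = ScalarField.from_expression(grid, "hypot(x, y)")
```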
31 changes: 0 additions & 31 deletions pde/tools/misc.py
@@ -126,37 +126,6 @@ def new_decorator(*args, **kwargs):
return new_decorator


def skipUnlessModule(module_names: str | Sequence[str]) -> Callable[[TFunc], TFunc]:
"""Decorator that skips a test when a module is not available.
Args:
module_names (str): The name of the required module(s)
Returns:
A function, so this can be used as a decorator
"""
# deprecated since 2024-01-03
warnings.warn(
"`skipUnlessModule` is deprecated. Use "
'`@pytest.mark.skipif(not module_available("module"))` instead.',
DeprecationWarning,
)

if isinstance(module_names, str):
module_names = [module_names]

for module_name in module_names:
if not module_available(module_name):
# return decorator skipping test
return unittest.skip(f"requires {module_name}")

# return no-op decorator if all modules are available
def wrapper(f: TFunc) -> TFunc:
return f

return wrapper


def import_class(identifier: str):
"""Import a class or module given an identifier.
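The removed `skipUnlessModule` decorator's own deprecation message already named the replacement; migrating a test looks roughly like this (the module name is just an example):

```python
# Replacement for the removed skipUnlessModule decorator: let pytest skip
# the test when the optional dependency is missing.
import pytest

from pde.tools.misc import module_available

@pytest.mark.skipif(not module_available("h5py"), reason="requires h5py")
def test_needs_h5py():
    ...
```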
3 changes: 2 additions & 1 deletion pde/visualization/movies.py
@@ -57,7 +57,8 @@ def __init__(
The number of frames per second, which determines how fast the
movie will appear to run.
dpi (float):
The resolution of the resulting movie
The resolution of the resulting movie. The default value is controlled
by :mod:`matplotlib` and is usually set to 100.
\**kwargs:
Additional parameters are used to initialize
:class:`matplotlib.animation.FFMpegWriter`. Here, we can for instance
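For reference, the `dpi` documented above can simply be passed when constructing the movie writer to override matplotlib's default; a hedged sketch, assuming the surrounding class is `Movie` from this module and that its first argument is the output filename:

```python
# Hypothetical usage: raise the resolution of the rendered movie above the
# matplotlib default (typically 100 dpi).
from pde.visualization.movies import Movie

movie = Movie("diffusion.mp4", dpi=200)
```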
76 changes: 55 additions & 21 deletions tests/solvers/test_explicit_mpi_solvers.py
@@ -5,17 +5,22 @@
import numpy as np
import pytest

from pde import DiffusionPDE, ScalarField, UnitGrid
from pde import DiffusionPDE, FieldCollection, PDE, ScalarField, UnitGrid
from pde.solvers import Controller, ExplicitMPISolver
from pde.tools import mpi


@pytest.mark.multiprocessing
@pytest.mark.parametrize("backend", ["numpy", "numba"])
@pytest.mark.parametrize(
"scheme, decomposition",
[("euler", "auto"), ("euler", [1, -1]), ("runge-kutta", [-1, 1])],
"scheme, adaptive, decomposition",
[
("euler", False, "auto"),
("euler", True, [1, -1]),
("runge-kutta", False, [-1, 1]),
],
)
def test_simple_pde_mpi(scheme, decomposition, rng):
def test_simple_pde_mpi(backend, scheme, adaptive, decomposition, rng):
"""Test setting boundary conditions using numba."""
grid = UnitGrid([8, 8], periodic=[True, False])

@@ -26,32 +31,28 @@ def test_simple_pde_mpi(scheme, decomposition, rng):
"state": field,
"t_range": 1.01,
"dt": 0.1,
"adaptive": adaptive,
"scheme": scheme,
"tracker": None,
"ret_info": True,
}
res1, info1 = eq.solve(
backend="numpy", solver="explicit_mpi", decomposition=decomposition, **args
)
res2, info2 = eq.solve(
backend="numba", solver="explicit_mpi", decomposition=decomposition, **args
res_mpi, info_mpi = eq.solve(
backend=backend, solver="explicit_mpi", decomposition=decomposition, **args
)

if mpi.is_main:
# check results in the main process
expect, _ = eq.solve(backend="numpy", solver="explicit", **args)
np.testing.assert_allclose(res1.data, expect.data)
np.testing.assert_allclose(res2.data, expect.data)
expect, info2 = eq.solve(backend="numpy", solver="explicit", **args)
np.testing.assert_allclose(res_mpi.data, expect.data)

for info in [info1, info2]:
assert info["solver"]["steps"] == 11
assert info["solver"]["use_mpi"]
if decomposition != "auto":
for i in range(2):
if decomposition[i] == 1:
assert info["solver"]["grid_decomposition"][i] == 1
else:
assert info["solver"]["grid_decomposition"][i] == mpi.size
assert info_mpi["solver"]["steps"] == info2["solver"]["steps"]
assert info_mpi["solver"]["use_mpi"]
if decomposition != "auto":
for i in range(2):
if decomposition[i] == 1:
assert info_mpi["solver"]["grid_decomposition"][i] == 1
else:
assert info_mpi["solver"]["grid_decomposition"][i] == mpi.size


@pytest.mark.multiprocessing
@@ -77,3 +78,36 @@ def test_stochastic_mpi_solvers(backend, rng):

assert not solver1.info["dt_adaptive"]
assert not solver2.info["dt_adaptive"]


@pytest.mark.multiprocessing
@pytest.mark.parametrize("backend", ["numpy", "numba"])
def test_multiple_pdes_mpi(backend, rng):
"""Test setting boundary conditions using numba."""
grid = UnitGrid([8, 8], periodic=[True, False])

fields = FieldCollection.scalar_random_uniform(2, grid, rng=rng)
eq = PDE({"a": "laplace(a) - b", "b": "laplace(b) + a"})

args = {
"state": fields,
"t_range": 1.01,
"dt": 0.1,
"adaptive": True,
"scheme": "euler",
"tracker": None,
"ret_info": True,
}
res_mpi, info_mpi = eq.solve(backend=backend, solver="explicit_mpi", **args)

if mpi.is_main:
# check results in the main process
expect, info2 = eq.solve(backend="numpy", solver="explicit", **args)
np.testing.assert_allclose(res_mpi.data, expect.data)

assert info_mpi["solver"]["steps"] == info2["solver"]["steps"]
assert info_mpi["solver"]["use_mpi"]
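A note on running these tests: they are marked `multiprocessing` and guard their result checks with `mpi.is_main`, so they also pass in a serial run. To exercise the MPI code paths, the test module presumably has to be launched under MPI; the command below is illustrative and may differ from the repository's own tooling.

```python
# Illustrative invocation with two processes (not part of the repository's
# documented tooling):
#
#   mpiexec -n 2 python -m pytest -m multiprocessing tests/solvers/test_explicit_mpi_solvers.py
#
# Inside the tests, assertions on the collected result run only on rank 0:
from pde.tools import mpi

if mpi.is_main:
    print(f"running on {mpi.size} MPI process(es)")
```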
6 changes: 0 additions & 6 deletions tests/tools/test_misc.py
@@ -110,9 +110,3 @@ def test_hdf_write_attributes(tmp_path):
misc.hdf_write_attributes(
hdf_file, {"a": object()}, raise_serialization_error=True
)


@misc.skipUnlessModule("undefined_module_name")
def test_skipUnlessModule():
"""Test skipUnlessModule decorator."""
raise RuntimeError # test should never run
