Refactored synchronization of values in MPI runs
* Replaced the error-synchronization helpers with the generic mpi_allreduce
* Improved mpi_allreduce to handle operators given as strings in numba-compiled code (see the sketch after this list)
* Added additional tests for mpi_allreduce
* Improved exception handling in MPI simulations
* Multiple smaller improvements
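
For context, the new helper is called as mpi_allreduce(value, operator="SUM") or operator="MAX" throughout the diff, and it is also used inside numba-compiled steppers. Below is a purely illustrative sketch of what such a generic allreduce with string operators could look like; it is not the pde.tools.mpi implementation, and only the call signature is taken from the diff.

# Illustrative sketch only -- not the pde.tools.mpi implementation.
# Assumed behavior: reduce a value over all MPI processes, selecting the
# reduction via a string, and fall back to a no-op in serial runs.
import numpy as np

def allreduce_sketch(data, operator: str = "SUM"):
    """Reduce `data` across all MPI processes (serial fallback: return as-is)."""
    try:
        from mpi4py import MPI
    except ImportError:
        return data  # mpi4py not installed -> serial run
    if MPI.COMM_WORLD.size == 1:
        return data  # single process -> nothing to reduce
    op = {"SUM": MPI.SUM, "MAX": MPI.MAX, "MIN": MPI.MIN}[operator]
    if isinstance(data, np.ndarray):
        out = np.empty_like(data)
        MPI.COMM_WORLD.Allreduce(data, out, op=op)  # buffer-based reduction
        return out
    return MPI.COMM_WORLD.allreduce(data, op=op)  # works for plain Python scalars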
david-zwicker committed Aug 9, 2024
1 parent 05196a5 commit 64d433c
Showing 11 changed files with 188 additions and 78 deletions.
14 changes: 4 additions & 10 deletions pde/grids/base.py
@@ -26,6 +26,7 @@
from ..tools.cache import cached_method, cached_property
from ..tools.docstrings import fill_in_docstring
from ..tools.misc import Number, hybridmethod
from ..tools.mpi import mpi_allreduce
from ..tools.numba import jit
from ..tools.typing import (
CellVolume,
@@ -1572,14 +1573,9 @@ def integrate(
if self._mesh is None or len(self._mesh) == 1:
# standard case of a single integral
return integral # type: ignore

else:
# we are in a parallel run, so we need to gather the sub-integrals from all
from mpi4py.MPI import COMM_WORLD # @UnresolvedImport

integral_full = np.empty_like(integral)
COMM_WORLD.Allreduce(integral, integral_full)
return integral_full # type: ignore
# accumulate integrals from all subprocesses (necessary if MPI is used)
return mpi_allreduce(integral, operator="SUM") # type: ignore

@cached_method()
def make_normalize_point_compiled(
@@ -2135,8 +2131,6 @@ def integrate_global(arr: np.ndarray) -> NumberOrArray:
else:
# we are in a parallel run, so we need to gather the sub-integrals from all
# subgrids in the grid mesh
from ..tools.mpi import mpi_allreduce

@jit
def integrate_global(arr: np.ndarray) -> NumberOrArray:
"""Integrate data over MPI parallelized grid.
@@ -2145,7 +2139,7 @@ def integrate_global(arr: np.ndarray) -> NumberOrArray:
arr (:class:`~numpy.ndarray`): discretized data on grid
"""
integral = integrate_local(arr)
return mpi_allreduce(integral) # type: ignore
return mpi_allreduce(integral, operator="SUM") # type: ignore

return integrate_global # type: ignore

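In short: every MPI process integrates only its own subgrid, and the SUM reduction turns the local results into the global integral on every rank. A usage-level sketch (the grid and field below are illustrative, not taken from the diff):

# Usage sketch: the integral is identical on every rank, because
# grid.integrate() accumulates the per-subgrid contributions via a SUM
# reduction whenever the grid has been decomposed for an MPI run.
import pde

grid = pde.UnitGrid([64, 64])                 # illustrative grid
field = pde.ScalarField.random_uniform(grid)  # illustrative data
total = field.integral  # delegates to grid.integrate(), which sums over ranks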
19 changes: 12 additions & 7 deletions pde/solvers/base.py
@@ -22,6 +22,7 @@
from ..pdes.base import PDEBase
from ..tools.math import OnlineStatistics
from ..tools.misc import classproperty
from ..tools.mpi import mpi_allreduce
from ..tools.numba import is_jitted, jit
from ..tools.typing import BackendType, StepperHook

@@ -410,14 +411,19 @@ def __init__(
self.adaptive = adaptive
self.tolerance = tolerance

def _make_error_synchronizer(self) -> Callable[[float], float]:
"""Return function that synchronizes errors between multiple processes."""
def _make_error_synchronizer(self):
# Deprecated on 2024-08-09
warnings.warn(
"`_make_error_synchronizer` has been replaced by "
"`pde.tools.mpi.mpi_allreduce`",
DeprecationWarning,
)

@register_jitable
def synchronize_errors(error: float) -> float:
return error
def error_synchronizer(value):
return mpi_allreduce(value, operator="MAX")

return synchronize_errors # type: ignore
return error_synchronizer

def _make_dt_adjuster(self) -> Callable[[float, float], float]:
"""Return a function that can be used to adjust time steps."""
@@ -545,7 +551,6 @@ def _make_adaptive_stepper(self, state: FieldBase) -> Callable[
# obtain functions determining how the PDE is evolved
single_step_error = self._make_single_step_error_estimate(state)
post_step_hook = self._make_post_step_hook(state)
sync_errors = self._make_error_synchronizer()

# obtain auxiliary functions
adjust_dt = self._make_dt_adjuster()
@@ -578,7 +583,7 @@ def adaptive_stepper(

error_rel = error / tolerance # normalize error to given tolerance
# synchronize the error between all processes (necessary for MPI)
error_rel = sync_errors(error_rel)
error_rel = mpi_allreduce(error_rel, operator="MAX")

# do the step if the error is sufficiently small
if error_rel <= 1:
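The reduction with operator="MAX" replaces the old error synchronizer: a step may only be accepted if every process accepts it, so the largest relative error across all ranks has to govern the decision. Distilled into a small helper (the function name is ours, not the solver's):

from pde.tools.mpi import mpi_allreduce  # as imported in the diff above

def step_accepted(error: float, tolerance: float) -> bool:
    """Return True if the adaptive step is accepted consistently on all ranks."""
    error_rel = error / tolerance
    # after the reduction, every process sees the same (largest) relative error
    error_rel = mpi_allreduce(error_rel, operator="MAX")
    return error_rel <= 1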
32 changes: 25 additions & 7 deletions pde/solvers/controller.py
@@ -327,14 +327,32 @@ def run(self, initial_state: TState, dt: float | None = None) -> TState | None:
state = initial_state.copy()

# decide whether to call the main routine or whether this is an MPI client
if mpi.is_main:
# this node is the primary one
self._run_single(state, dt)
if mpi.size > 1:
self.info["parallel_run"] = True
self.info["process_count"] = mpi.size
else:
# multiple processes are used and this is one of the secondaries
self._run_mpi_client(state, dt)
self.info["process_rank"] = mpi.rank
return None # do not return anything in client processes
from mpi4py import MPI

if mpi.is_main: # this node is the primary one
try:
self._run_single(state, dt)
except Exception:
# found exception on the main node
MPI.COMM_WORLD.Abort() # abort all other nodes
raise
else: # this node is a secondary (client) node
try:
self._run_mpi_client(state, dt)
except Exception as exception:
print(exception) # simply print the exception to show some info
MPI.COMM_WORLD.Abort() # abort all other (and main) nodes
raise
else:
return None # do not return anything in client processes

else:
# serial run without MPI
self.info["parallel_run"] = False
self._run_single(state, dt)

return state
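
The try/except blocks above guard against deadlocks: if one rank raises while the others sit in a collective operation, the run would hang forever. Calling MPI.COMM_WORLD.Abort() tears down all ranks instead. The pattern in isolation (the helper name is illustrative):

from mpi4py import MPI

def run_guarded(work):
    """Run `work()` and abort the whole MPI job if it raises on this rank."""
    try:
        return work()
    except Exception:
        # terminate every rank, not just the failing one, so the remaining
        # processes do not block forever in pending collective operations
        MPI.COMM_WORLD.Abort()
        raise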
4 changes: 2 additions & 2 deletions pde/solvers/explicit.py
@@ -13,6 +13,7 @@
from ..fields.base import FieldBase
from ..pdes.base import PDEBase
from ..tools.math import OnlineStatistics
from ..tools.mpi import mpi_allreduce
from ..tools.numba import jit
from ..tools.typing import BackendType
from .base import AdaptiveSolverBase
@@ -182,7 +183,6 @@ def _make_adaptive_euler_stepper(self, state: FieldBase) -> Callable[
post_step_hook = self._make_post_step_hook(state)

# obtain auxiliary functions
sync_errors = self._make_error_synchronizer()
adjust_dt = self._make_dt_adjuster()
tolerance = self.tolerance
dt_min = self.dt_min
@@ -226,7 +226,7 @@ def adaptive_stepper(
error_rel = error / tolerance # normalize error to given tolerance

# synchronize the error between all processes (necessary for MPI)
error_rel = sync_errors(error_rel)
error_rel = mpi_allreduce(error_rel, operator="MAX")

if error_rel <= 1: # error is sufficiently small
try:
28 changes: 10 additions & 18 deletions pde/solvers/explicit_mpi.py
@@ -8,7 +8,6 @@
from typing import Callable, Literal

import numpy as np
from numba.extending import register_jitable

from ..fields.base import FieldBase
from ..grids._mesh import GridMesh
@@ -116,23 +115,6 @@ def __init__(
)
self.decomposition = decomposition

def _make_error_synchronizer(self) -> Callable[[float], float]:
"""Return function that synchronizes errors between multiple processes."""
# if mpi.parallel_run:
# in a parallel run, we need to return the maximal error
from ..tools.mpi import Operator, mpi_allreduce

operator_max_id = Operator.MAX

@register_jitable
def synchronize_errors(error: float) -> float:
"""Return maximal error accross all cores."""
return mpi_allreduce(error, operator_max_id) # type: ignore

return synchronize_errors # type: ignore
# else:
# return super()._make_error_synchronizer()

def make_stepper(
self, state: FieldBase, dt=None
) -> Callable[[FieldBase, float, float], float]:
@@ -178,12 +160,16 @@ def make_stepper(
# decompose the state into multiple cells
self.mesh = GridMesh.from_grid(state.grid, self.decomposition)
sub_state = self.mesh.extract_subfield(state)
print("GOT sub_state for node", mpi.rank)
self.info["grid_decomposition"] = self.mesh.shape

if self.adaptive:
# create stepper with adaptive steps
self.info["dt_statistics"] = OnlineStatistics()

print("MAKE adaptive stepper for rank", mpi.rank)
adaptive_stepper = self._make_adaptive_stepper(sub_state)
print("GOT adaptive stepper for rank", mpi.rank)
self.info["post_step_data"] = self._post_step_data_init

def wrapped_stepper(
@@ -196,13 +182,15 @@ def wrapped_stepper(
post_step_data = self.info["post_step_data"]

# distribute the end time and the field to all nodes
print("BROADCAST DATA")
t_end = self.mesh.broadcast(t_end)
substate_data = self.mesh.split_field_data_mpi(state.data)

# Evolve the sub-state on each individual node. The nodes synchronize
# field data via special boundary conditions and they synchronize the
# maximal error via the error synchronizer. Apart from that, all nodes
# work independently.
print("START ADAPTIVE STEPPER")
t_last, dt, steps = adaptive_stepper(
substate_data,
t_start,
@@ -211,16 +199,20 @@
self.info["dt_statistics"],
post_step_data,
)
print("t_last", t_last)

# check whether dt is the same for all processes
dt_list = self.mesh.allgather(dt)
if not np.isclose(min(dt_list), max(dt_list)):
# abort simulations in all nodes when they went out of sync
raise RuntimeError(f"Processes went out of sync: dt={dt_list}")
print("dt_list", dt_list)

# collect the data from all nodes
post_step_data_list = self.mesh.gather(post_step_data)
print("post_step_data_list", post_step_data_list)
self.mesh.combine_field_data_mpi(substate_data, out=state.data)
print("substate_data", state.data)

if mpi.is_main:
self.info["steps"] += steps
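The _make_error_synchronizer override removed above is no longer needed because mpi_allreduce now accepts the operator as a plain string even inside numba-compiled steppers. How that is achieved is not visible in this diff; one speculative mechanism is compile-time dispatch on a string literal via numba.extending.overload:

# Speculative sketch only -- not the pde.tools.mpi code. It merely illustrates
# how a string operator could be resolved at compile time under numba.
from numba import types
from numba.extending import overload

def allreduce(value, operator="SUM"):
    """Pure-Python fallback used outside compiled code (serial no-op)."""
    return value

@overload(allreduce, prefer_literal=True)
def _ol_allreduce(value, operator="SUM"):
    op_ids = {"SUM": 1, "MAX": 2, "MIN": 3}  # hypothetical ids for the MPI ops
    if isinstance(operator, types.StringLiteral):
        op_id = op_ids[operator.literal_value]  # resolved once, at compile time
    else:
        op_id = op_ids["SUM"]
    def impl(value, operator="SUM"):
        # a real implementation would invoke the MPI reduction matching op_id;
        # the sketch just returns the value to stay self-contained
        return value
    return impl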
5 changes: 4 additions & 1 deletion pde/tools/expressions.py
@@ -131,7 +131,10 @@ def np_heaviside(x1, x2):

# special functions that we want to support in expressions but that are not defined by
# sympy version 1.6 or have a different signature than expected by numba/numpy
SPECIAL_FUNCTIONS = {"Heaviside": _heaviside_implemention, "hypot": np.hypot}
SPECIAL_FUNCTIONS: dict[str, Callable] = {
"Heaviside": _heaviside_implemention,
"hypot": np.hypot,
}


class ListArrayPrinter(PythonCodePrinter):
Expand Down
(Diffs for the remaining changed files are not shown.)
