Improve mpi.allreduce and error synchronizers (#590)
* Improve mpi.allreduce to allow specifying the operator as a string
* Made the error_synchronizer available to all solvers
* Introduced `_mpi_synchronization` flags for PDEs and solvers

The flag allows us to control explicitly whether synchronization is
required. This is particularly important in situations where MPI is
used but no synchronization is required (e.g., because PDEs are solved
independently on each node). Setting the flag explicitly allows us to
distinguish such situations.
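
As a rough usage sketch (not part of the commit message itself), the new interface lets callers pass the reduction operator as a plain string, which works the same in serial and MPI runs:

```python
# Sketch based on the diff below: pde.tools.mpi.mpi_allreduce now accepts the
# reduction operator as a string such as "MAX", "MIN", or "SUM".
from pde.tools.mpi import mpi_allreduce

local_error = 1e-4
# in a serial run the value is returned unchanged; in an MPI run all nodes
# receive the maximum over the local values
global_error = mpi_allreduce(local_error, operator="MAX")
```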
david-zwicker authored Aug 9, 2024
1 parent 05196a5 commit 21323ef
Showing 10 changed files with 198 additions and 81 deletions.
2 changes: 1 addition & 1 deletion pde/grids/base.py
@@ -2145,7 +2145,7 @@ def integrate_global(arr: np.ndarray) -> NumberOrArray:
arr (:class:`~numpy.ndarray`): discretized data on grid
"""
integral = integrate_local(arr)
return mpi_allreduce(integral) # type: ignore
return mpi_allreduce(integral, operator="SUM") # type: ignore

return integrate_global # type: ignore

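
For context, the `"SUM"` reduction used above is conceptually equivalent to the following mpi4py call (a sketch, not py-pde code):

```python
# Conceptual equivalent of mpi_allreduce(integral, operator="SUM"): every rank
# contributes its local integral and all ranks receive the global sum.
from mpi4py import MPI

local_integral = 2.5  # value integrated over this rank's subgrid
global_integral = MPI.COMM_WORLD.allreduce(local_integral, op=MPI.SUM)
```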
6 changes: 6 additions & 0 deletions pde/pdes/base.py
@@ -60,6 +60,12 @@ class PDEBase(metaclass=ABCMeta):
"""bool: Flag indicating whether the right hand side is a complex-valued PDE, which
requires all involved variables to have complex data type."""

_mpi_synchronization: bool = False
"""bool: Flag indicating whether the PDE will be solved on multiple nodes using MPI.
This flag will be set by the solver. If it is true and the PDE requires global
values in its evaluation, the synchronization between nodes needs to be handled. In
many cases, PDEs are defined locally and no such synchronization is necessary."""

def __init__(self, *, noise: ArrayLike = 0, rng: np.random.Generator | None = None):
"""
Args:
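
A hypothetical sketch of how a PDE that needs a global quantity could honor the new flag (the subclass and method below are illustrative, not part of the commit):

```python
from pde import PDEBase
from pde.tools import mpi
from pde.tools.mpi import mpi_allreduce

class MyGlobalPDE(PDEBase):  # hypothetical subclass, for illustration only
    def _global_mean(self, local_mean: float) -> float:
        """Combine per-node means only when the solver requests synchronization."""
        if self._mpi_synchronization:
            # assumes equal-sized subdomains, so the mean of means is the global mean
            return mpi_allreduce(local_mean, operator="SUM") / mpi.size
        return local_mean  # serial run, or PDEs solved independently per node
```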
44 changes: 35 additions & 9 deletions pde/solvers/base.py
@@ -39,6 +39,11 @@ class SolverBase(metaclass=ABCMeta):
_use_post_step_hook: bool = True
"""bool: flag choosing whether the post-step hook of the PDE is called"""

_mpi_synchronization: bool = False
"""bool: Flag indicating whether MPI synchronization is required. This is never the
case for serial solvers and even parallelized solvers might set this flag to False
if no synchronization between nodes is required"""

_subclasses: dict[str, type[SolverBase]] = {}
"""dict: dictionary of all inheriting classes"""

@@ -117,6 +122,36 @@ def _compiled(self) -> bool:
"""bool: indicates whether functions need to be compiled"""
return self.backend == "numba" and not nb.config.DISABLE_JIT

def _make_error_synchronizer(
self, operator: int | str = "MAX"
) -> Callable[[float], float]:
"""Return function that synchronizes errors between multiple processes.
Args:
operator (str or int):
Flag determining how the value from multiple nodes is combined.
Possible values include "MAX", "MIN", and "SUM".
Returns:
Function that can be used to synchronize errors across nodes
"""
if self._mpi_synchronization: # mpi.parallel_run:
# in a parallel run, we need to synchronize values
from ..tools.mpi import mpi_allreduce

@register_jitable
def synchronize_errors(error: float) -> float:
"""Return error synchronized accross all cores."""
return mpi_allreduce(error, operator=operator) # type: ignore

else:

@register_jitable
def synchronize_errors(value: float) -> float:
return value

return synchronize_errors # type: ignore

def _make_post_step_hook(self, state: FieldBase) -> StepperHook:
"""Create a function that calls the post-step hook of the PDE.
@@ -410,15 +445,6 @@ def __init__(
self.adaptive = adaptive
self.tolerance = tolerance

def _make_error_synchronizer(self) -> Callable[[float], float]:
"""Return function that synchronizes errors between multiple processes."""

@register_jitable
def synchronize_errors(error: float) -> float:
return error

return synchronize_errors # type: ignore

def _make_dt_adjuster(self) -> Callable[[float, float], float]:
"""Return a function that can be used to adjust time steps."""
dt_min = self.dt_min
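
A simplified sketch of how the shared synchronizer is meant to be used inside an adaptive stepper (the helper function below is illustrative, not part of the commit):

```python
from pde import DiffusionPDE
from pde.solvers import ExplicitSolver

solver = ExplicitSolver(DiffusionPDE(), adaptive=True, tolerance=1e-4)
synchronize_error = solver._make_error_synchronizer(operator="MAX")

def step_accepted(local_error: float) -> bool:
    # in an MPI run with _mpi_synchronization set, every node sees the same
    # (maximal) error and therefore makes the same accept/reject decision
    return synchronize_error(local_error) < solver.tolerance

step_accepted(1e-5)  # True in a serial run, where the error passes through unchanged
```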
93 changes: 76 additions & 17 deletions pde/solvers/controller.py
@@ -143,11 +143,12 @@ def _handle_stop_iteration(err: Exception, t: float) -> tuple[int, str]:

return _handle_stop_iteration

def _run_single(self, state: TState, dt: float | None = None) -> None:
"""Run the simulation.
def _run_main_process(self, state: TState, dt: float | None = None) -> None:
"""Run the main part of the simulation.
Diagnostic information about the solver procedure are available in the
`diagnostics` property of the instance after this function has been called.
This is either a serial run or the main node of an MPI run. Diagnostic
information about the solver procedure is available in the `diagnostics`
property of the instance after this function has been called.
Args:
state:
@@ -269,8 +270,8 @@ def _run_single(self, state: TState, dt: float | None = None) -> None:
f"than on the actual simulation ({profiler['solver']:.3g})"
)

def _run_mpi_client(self, state: TState, dt: float | None = None) -> None:
"""Loop for run the simulation on client nodes during an MPI run.
def _run_client_process(self, state: TState, dt: float | None = None) -> None:
"""Run the simulation on client nodes during an MPI run.
This function just loops the stepper advancing the sub field of the current node
in time. All other logic, including trackers, are done in the main node.
@@ -300,6 +301,72 @@ def _run_mpi_client(self, state: TState, dt: float | None = None) -> None:
while t < t_end:
t = stepper(state, t, t_end)

def _run_serial(self, state: TState, dt: float | None = None) -> TState:
"""Run the simulation in serial mode.
Diagnostic information about the solver is available in the
:attr:`~Controller.diagnostics` property after this function has been called.
Args:
state (:class:`~pde.fields.base.FieldBase`):
The initial state of the simulation.
dt (float):
Time step of the chosen stepping scheme. If `None`, a default value
based on the stepper will be chosen.
Returns:
The state at the final time point. If multiprocessing is used, only the main
node will return the state. All other nodes return None.
"""
self.info["mpi_run"] = False
self._run_main_process(state, dt)
return state

def _run_parallel(self, state: TState, dt: float | None = None) -> TState | None:
"""Run the simulation in MPI mode.
Diagnostic information about the solver is available in the
:attr:`~Controller.diagnostics` property after this function has been called.
Args:
state (:class:`~pde.fields.base.FieldBase`):
The initial state of the simulation.
dt (float):
Time step of the chosen stepping scheme. If `None`, a default value
based on the stepper will be chosen.
Returns:
The state at the final time point. If multiprocessing is used, only the main
node will return the state. All other nodes return None.
"""
from mpi4py import MPI

self.info["mpi_run"] = True
self.info["mpi_count"] = mpi.size
self.info["mpi_rank"] = mpi.rank

if mpi.is_main:
# this node is the primary one and must thus run the main process
try:
self._run_main_process(state, dt)
except Exception as err:
print(err) # simply print the exception to show some info
MPI.COMM_WORLD.Abort() # abort all other nodes
raise
else:
return state

else:
# this node is a secondary node and must thus run the client process
try:
self._run_client_process(state, dt)
except Exception as err:
print(err) # simply print the exception to show some info
MPI.COMM_WORLD.Abort() # abort all other (and main) nodes
raise
else:
return None # do not return anything in client processes

def run(self, initial_state: TState, dt: float | None = None) -> TState | None:
"""Run the simulation.
@@ -326,15 +393,7 @@ def run(self, initial_state: TState, dt: float | None = None) -> TState | None:
else:
state = initial_state.copy()

# decide whether to call the main routine or whether this is an MPI client
if mpi.is_main:
# this node is the primary one
self._run_single(state, dt)
self.info["process_count"] = mpi.size
if mpi.size > 1: # run the simulation on multiple nodes
return self._run_parallel(state, dt)
else:
# multiple processes are used and this is one of the secondaries
self._run_mpi_client(state, dt)
self.info["process_rank"] = mpi.rank
return None # do not return anything in client processes

return state
return self._run_serial(state, dt)
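
The user-facing consequence of the new dispatch is that only the main MPI rank receives the final state; a sketch of how a calling script might guard its post-processing (`controller` stands for a previously configured `pde.solvers.Controller`):

```python
# Sketch: under `mpirun -n 4 python script.py` only the main rank gets a result.
result = controller.run(initial_state, dt=1e-3)
if result is not None:
    # we are on the main node (or in a serial run); client nodes returned None
    print(result)  # hypothetical post-processing on the main node only
```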
21 changes: 3 additions & 18 deletions pde/solvers/explicit_mpi.py
@@ -8,7 +8,6 @@
from typing import Callable, Literal

import numpy as np
from numba.extending import register_jitable

from ..fields.base import FieldBase
from ..grids._mesh import GridMesh
@@ -76,6 +75,8 @@ class ExplicitMPISolver(ExplicitSolver):

name = "explicit_mpi"

_mpi_synchronization = mpi.parallel_run

def __init__(
self,
pde: PDEBase,
@@ -111,28 +112,12 @@ def __init__(
adaptive time stepping to choose a time step which is small enough so
the truncation error of a single step is below `tolerance`.
"""
pde._mpi_synchronization = self._mpi_synchronization
super().__init__(
pde, scheme=scheme, backend=backend, adaptive=adaptive, tolerance=tolerance
)
self.decomposition = decomposition

def _make_error_synchronizer(self) -> Callable[[float], float]:
"""Return function that synchronizes errors between multiple processes."""
# if mpi.parallel_run:
# in a parallel run, we need to return the maximal error
from ..tools.mpi import Operator, mpi_allreduce

operator_max_id = Operator.MAX

@register_jitable
def synchronize_errors(error: float) -> float:
"""Return maximal error accross all cores."""
return mpi_allreduce(error, operator_max_id) # type: ignore

return synchronize_errors # type: ignore
# else:
# return super()._make_error_synchronizer()

def make_stepper(
self, state: FieldBase, dt=None
) -> Callable[[FieldBase, float, float], float]:
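
For reference, a typical setup in which the class-level flag takes effect (a sketch following the py-pde documentation; exact keyword defaults may differ):

```python
from pde import DiffusionPDE, ScalarField, UnitGrid
from pde.solvers import Controller, ExplicitMPISolver

grid = UnitGrid([64, 64])
state = ScalarField.random_uniform(grid)

# instantiating the MPI solver copies _mpi_synchronization onto the PDE
solver = ExplicitMPISolver(DiffusionPDE(), adaptive=True)
controller = Controller(solver, t_range=1.0)
final_state = controller.run(state)  # non-None only on the main MPI node
```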
5 changes: 4 additions & 1 deletion pde/tools/expressions.py
@@ -131,7 +131,10 @@ def np_heaviside(x1, x2):

# special functions that we want to support in expressions but that are not defined by
# sympy version 1.6 or have a different signature than expected by numba/numpy
SPECIAL_FUNCTIONS = {"Heaviside": _heaviside_implemention, "hypot": np.hypot}
SPECIAL_FUNCTIONS: dict[str, Callable] = {
"Heaviside": _heaviside_implemention,
"hypot": np.hypot,
}


class ListArrayPrinter(PythonCodePrinter):
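
The functions registered in SPECIAL_FUNCTIONS then become available inside expression strings; a small usage sketch, assuming the usual ScalarExpression call signature:

```python
from pde.tools.expressions import ScalarExpression

expr = ScalarExpression("Heaviside(x - 0.5) + hypot(x, y)")
value = expr(x=1.0, y=2.0)  # Heaviside(0.5) is 1, hypot(1, 2) is about 2.236
```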
61 changes: 41 additions & 20 deletions pde/tools/mpi.py
@@ -24,7 +24,7 @@
import numba as nb
import numpy as np
from numba import types
from numba.extending import overload, register_jitable
from numba.extending import SentryLiteralArgs, overload, register_jitable

try:
from numba.types import Literal
@@ -158,10 +158,11 @@ def impl(data, source: int, tag: int) -> None:
return impl


def mpi_allreduce(data, operator: int | str | None = None):
def mpi_allreduce(data, operator):
"""Combines data from all MPI nodes.
Note that complex datatypes and user-defined functions are not properly supported.
Note that complex datatypes and user-defined reduction operators are not properly
supported in numba-compiled cases.
Args:
data:
@@ -173,21 +174,37 @@
Returns:
The accumulated data
"""
if operator:
return MPI.COMM_WORLD.allreduce(data, op=Operator.operator(operator))
if not parallel_run:
# in a serial run, we can always return the value as is
return data

if isinstance(data, np.ndarray):
# synchronize an array
out = np.empty_like(data)
MPI.COMM_WORLD.Allreduce(data, out, op=Operator.operator(operator))
return out

else:
return MPI.COMM_WORLD.allreduce(data)
# synchronize a single value
return MPI.COMM_WORLD.allreduce(data, op=Operator.operator(operator))


@overload(mpi_allreduce)
def ol_mpi_allreduce(data, operator: int | str | None = None):
def ol_mpi_allreduce(data, operator):
"""Overload the `mpi_allreduce` function."""
import numba_mpi
if size == 1:
# We can simply return the value in a serial run

if operator is None or isinstance(operator, nb.types.NoneType):
op_id = -1 # value will not be used
elif isinstance(operator, Literal):
# an operator is specified (using a literal value
def impl(data, operator):
return data

return impl

# Conversely, in a parallel run, we need to use the correct reduction. Let's first
# determine the operator, which must be given as a literal type
SentryLiteralArgs(["operator"]).for_function(ol_mpi_allreduce).bind(data, operator)
if isinstance(operator, Literal):
# an operator is specified (using a literal value)
if isinstance(operator.literal_value, str):
# an operator is specified by it's name
op_id = Operator.id(operator.literal_value)
@@ -199,33 +216,37 @@ def ol_mpi_allreduce(data, operator: int | str | None = None):
else:
raise RuntimeError(f"`operator` must be a literal type, not {operator}")

import numba_mpi

@register_jitable
def _allreduce(sendobj, recvobj, operator: int | str | None = None) -> int:
def _allreduce(sendobj, recvobj, operator) -> int:
"""Helper function that calls `numba_mpi.allreduce`"""
if operator is None:
return numba_mpi.allreduce(sendobj, recvobj) # type: ignore
elif op_id is None:
if op_id is None:
return numba_mpi.allreduce(sendobj, recvobj, operator) # type: ignore
else:
return numba_mpi.allreduce(sendobj, recvobj, op_id) # type: ignore

if isinstance(data, types.Number):
# implementation of the reduction for a single number

def impl(data, operator: int | str | None = None):
def impl(data, operator):
"""Reduce a single number across all cores."""
sendobj = np.array([data])
recvobj = np.empty((1,), sendobj.dtype)
status = _allreduce(sendobj, recvobj, operator)
assert status == 0
if status != 0:
raise RuntimeError
return recvobj[0]

elif isinstance(data, types.Array):
# implementation of the reduction for a numpy array

def impl(data, operator: int | str | None = None):
def impl(data, operator):
"""Reduce an array across all cores."""
recvobj = np.empty(data.shape, data.dtype)
status = _allreduce(data, recvobj, operator)
assert status == 0
if status != 0:
raise RuntimeError
return recvobj

else:
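
A sketch of the two supported call styles after this change (based on the diff above): outside numba the operator is an ordinary string, while inside a compiled function it must be a compile-time literal so the overload can resolve it via Operator.id:

```python
import numba as nb
from pde.tools.mpi import mpi_allreduce

total = mpi_allreduce(3.0, operator="SUM")  # plain Python call

@nb.njit
def total_compiled(x):
    # the literal "SUM" is resolved at compile time by ol_mpi_allreduce
    return mpi_allreduce(x, operator="SUM")
```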
2 changes: 1 addition & 1 deletion pde/tools/output.py
@@ -41,7 +41,7 @@ def get_progress_bar_class(fancy: bool = True):
progress_bar_class = tqdm.tqdm
else:
# use the fancier version of the progress bar in jupyter
from tqdm.auto import tqdm as progress_bar_class
from tqdm.auto import tqdm as progress_bar_class # type: ignore
else:
# only import text progress bar
progress_bar_class = tqdm.tqdm