From 21323efbc38cb818bb4dc9018938fccc201d4a22 Mon Sep 17 00:00:00 2001 From: David Zwicker Date: Fri, 9 Aug 2024 16:24:36 +0200 Subject: [PATCH] Improve mpi.allreduce and error synchronizers (#590) * Improve mpi.allreduce to allow specifying operator as strings * Made the error_synchronizer available to all solvers * Introduced `_mpi_synchronization` flags for PDEs and solvers The flag allows us to control explicitely whether synchronization is required. This is particularly important in situations where MPI is used, but no synchronization is required (e.g., because PDEs are solved independently on each node). Setting the flag explicetly allows us to distinguish such situations. --- pde/grids/base.py | 2 +- pde/pdes/base.py | 6 ++ pde/solvers/base.py | 44 +++++++--- pde/solvers/controller.py | 93 ++++++++++++++++++---- pde/solvers/explicit_mpi.py | 21 +---- pde/tools/expressions.py | 5 +- pde/tools/mpi.py | 61 +++++++++----- pde/tools/output.py | 2 +- tests/solvers/test_explicit_mpi_solvers.py | 8 +- tests/tools/test_mpi.py | 37 +++++++-- 10 files changed, 198 insertions(+), 81 deletions(-) diff --git a/pde/grids/base.py b/pde/grids/base.py index 3da7949f..c716d4ee 100644 --- a/pde/grids/base.py +++ b/pde/grids/base.py @@ -2145,7 +2145,7 @@ def integrate_global(arr: np.ndarray) -> NumberOrArray: arr (:class:`~numpy.ndarray`): discretized data on grid """ integral = integrate_local(arr) - return mpi_allreduce(integral) # type: ignore + return mpi_allreduce(integral, operator="SUM") # type: ignore return integrate_global # type: ignore diff --git a/pde/pdes/base.py b/pde/pdes/base.py index c63c477c..c6a01b77 100644 --- a/pde/pdes/base.py +++ b/pde/pdes/base.py @@ -60,6 +60,12 @@ class PDEBase(metaclass=ABCMeta): """bool: Flag indicating whether the right hand side is a complex-valued PDE, which requires all involved variables to have complex data type.""" + _mpi_synchronization: bool = False + """bool: Flag indicating whether the PDE will be solved on multiple nodes using MPI. + This flag will be set by the solver. If it is true and the PDE requires global + values in its evaluation, the synchronization between nodes needs to be handled. In + many cases, PDEs are defined locally and no such synchronization is necessary.""" + def __init__(self, *, noise: ArrayLike = 0, rng: np.random.Generator | None = None): """ Args: diff --git a/pde/solvers/base.py b/pde/solvers/base.py index 5731542a..24542b5d 100644 --- a/pde/solvers/base.py +++ b/pde/solvers/base.py @@ -39,6 +39,11 @@ class SolverBase(metaclass=ABCMeta): _use_post_step_hook: bool = True """bool: flag choosing whether the post-step hook of the PDE is called""" + _mpi_synchronization: bool = False + """bool: Flag indicating whether MPI synchronization is required. This is never the + case for serial solvers and even parallelized solvers might set this flag to False + if no synchronization between nodes is required""" + _subclasses: dict[str, type[SolverBase]] = {} """dict: dictionary of all inheriting classes""" @@ -117,6 +122,36 @@ def _compiled(self) -> bool: """bool: indicates whether functions need to be compiled""" return self.backend == "numba" and not nb.config.DISABLE_JIT + def _make_error_synchronizer( + self, operator: int | str = "MAX" + ) -> Callable[[float], float]: + """Return function that synchronizes errors between multiple processes. + + Args: + operator (str or int): + Flag determining how the value from multiple nodes is combined. + Possible values include "MAX", "MIN", and "SUM". + + Returns: + Function that can be used to synchronize errors across nodes + """ + if self._mpi_synchronization: # mpi.parallel_run: + # in a parallel run, we need to synchronize values + from ..tools.mpi import mpi_allreduce + + @register_jitable + def synchronize_errors(error: float) -> float: + """Return error synchronized accross all cores.""" + return mpi_allreduce(error, operator=operator) # type: ignore + + else: + + @register_jitable + def synchronize_errors(value: float) -> float: + return value + + return synchronize_errors # type: ignore + def _make_post_step_hook(self, state: FieldBase) -> StepperHook: """Create a function that calls the post-step hook of the PDE. @@ -410,15 +445,6 @@ def __init__( self.adaptive = adaptive self.tolerance = tolerance - def _make_error_synchronizer(self) -> Callable[[float], float]: - """Return function that synchronizes errors between multiple processes.""" - - @register_jitable - def synchronize_errors(error: float) -> float: - return error - - return synchronize_errors # type: ignore - def _make_dt_adjuster(self) -> Callable[[float, float], float]: """Return a function that can be used to adjust time steps.""" dt_min = self.dt_min diff --git a/pde/solvers/controller.py b/pde/solvers/controller.py index 2ef4ca6a..06d74a48 100644 --- a/pde/solvers/controller.py +++ b/pde/solvers/controller.py @@ -143,11 +143,12 @@ def _handle_stop_iteration(err: Exception, t: float) -> tuple[int, str]: return _handle_stop_iteration - def _run_single(self, state: TState, dt: float | None = None) -> None: - """Run the simulation. + def _run_main_process(self, state: TState, dt: float | None = None) -> None: + """Run the main part of the simulation. - Diagnostic information about the solver procedure are available in the - `diagnostics` property of the instance after this function has been called. + This is either a serial run or the main node of an MPI run. Diagnostic + information about the solver procedure are available in the `diagnostics` + property of the instance after this function has been called. Args: state: @@ -269,8 +270,8 @@ def _run_single(self, state: TState, dt: float | None = None) -> None: f"than on the actual simulation ({profiler['solver']:.3g})" ) - def _run_mpi_client(self, state: TState, dt: float | None = None) -> None: - """Loop for run the simulation on client nodes during an MPI run. + def _run_client_process(self, state: TState, dt: float | None = None) -> None: + """Run the simulation on client nodes during an MPI run. This function just loops the stepper advancing the sub field of the current node in time. All other logic, including trackers, are done in the main node. @@ -300,6 +301,72 @@ def _run_mpi_client(self, state: TState, dt: float | None = None) -> None: while t < t_end: t = stepper(state, t, t_end) + def _run_serial(self, state: TState, dt: float | None = None) -> TState: + """Run the simulation in serial mode. + + Diagnostic information about the solver are available in the + :attr:`~Controller.diagnostics` property after this function has been called. + + Args: + state (:class:`~pde.fields.base.FieldBase`): + The initial state of the simulation. + dt (float): + Time step of the chosen stepping scheme. If `None`, a default value + based on the stepper will be chosen. + + Returns: + The state at the final time point. If multiprocessing is used, only the main + node will return the state. All other nodes return None. + """ + self.info["mpi_run"] = False + self._run_main_process(state, dt) + return state + + def _run_parallel(self, state: TState, dt: float | None = None) -> TState | None: + """Run the simulation in MPI mode. + + Diagnostic information about the solver are available in the + :attr:`~Controller.diagnostics` property after this function has been called. + + Args: + state (:class:`~pde.fields.base.FieldBase`): + The initial state of the simulation. + dt (float): + Time step of the chosen stepping scheme. If `None`, a default value + based on the stepper will be chosen. + + Returns: + The state at the final time point. If multiprocessing is used, only the main + node will return the state. All other nodes return None. + """ + from mpi4py import MPI + + self.info["mpi_run"] = True + self.info["mpi_count"] = mpi.size + self.info["mpi_rank"] = mpi.rank + + if mpi.is_main: + # this node is the primary one and must thus run the main process + try: + self._run_main_process(state, dt) + except Exception as err: + print(err) # simply print the exception to show some info + MPI.COMM_WORLD.Abort() # abort all other nodes + raise + else: + return state + + else: + # this node is a secondary node and must thus run the client process + try: + self._run_client_process(state, dt) + except Exception as err: + print(err) # simply print the exception to show some info + MPI.COMM_WORLD.Abort() # abort all other (and main) nodes + raise + else: + return None # do not return anything in client processes + def run(self, initial_state: TState, dt: float | None = None) -> TState | None: """Run the simulation. @@ -326,15 +393,7 @@ def run(self, initial_state: TState, dt: float | None = None) -> TState | None: else: state = initial_state.copy() - # decide whether to call the main routine or whether this is an MPI client - if mpi.is_main: - # this node is the primary one - self._run_single(state, dt) - self.info["process_count"] = mpi.size + if mpi.size > 1: # run the simulation on multiple nodes + return self._run_parallel(state, dt) else: - # multiple processes are used and this is one of the secondaries - self._run_mpi_client(state, dt) - self.info["process_rank"] = mpi.rank - return None # do not return anything in client processes - - return state + return self._run_serial(state, dt) diff --git a/pde/solvers/explicit_mpi.py b/pde/solvers/explicit_mpi.py index ac792be9..26bdb2cc 100644 --- a/pde/solvers/explicit_mpi.py +++ b/pde/solvers/explicit_mpi.py @@ -8,7 +8,6 @@ from typing import Callable, Literal import numpy as np -from numba.extending import register_jitable from ..fields.base import FieldBase from ..grids._mesh import GridMesh @@ -76,6 +75,8 @@ class ExplicitMPISolver(ExplicitSolver): name = "explicit_mpi" + _mpi_synchronization = mpi.parallel_run + def __init__( self, pde: PDEBase, @@ -111,28 +112,12 @@ def __init__( adaptive time stepping to choose a time step which is small enough so the truncation error of a single step is below `tolerance`. """ + pde._mpi_synchronization = self._mpi_synchronization super().__init__( pde, scheme=scheme, backend=backend, adaptive=adaptive, tolerance=tolerance ) self.decomposition = decomposition - def _make_error_synchronizer(self) -> Callable[[float], float]: - """Return function that synchronizes errors between multiple processes.""" - # if mpi.parallel_run: - # in a parallel run, we need to return the maximal error - from ..tools.mpi import Operator, mpi_allreduce - - operator_max_id = Operator.MAX - - @register_jitable - def synchronize_errors(error: float) -> float: - """Return maximal error accross all cores.""" - return mpi_allreduce(error, operator_max_id) # type: ignore - - return synchronize_errors # type: ignore - # else: - # return super()._make_error_synchronizer() - def make_stepper( self, state: FieldBase, dt=None ) -> Callable[[FieldBase, float, float], float]: diff --git a/pde/tools/expressions.py b/pde/tools/expressions.py index 411c9242..14471b0b 100644 --- a/pde/tools/expressions.py +++ b/pde/tools/expressions.py @@ -131,7 +131,10 @@ def np_heaviside(x1, x2): # special functions that we want to support in expressions but that are not defined by # sympy version 1.6 or have a different signature than expected by numba/numpy -SPECIAL_FUNCTIONS = {"Heaviside": _heaviside_implemention, "hypot": np.hypot} +SPECIAL_FUNCTIONS: dict[str, Callable] = { + "Heaviside": _heaviside_implemention, + "hypot": np.hypot, +} class ListArrayPrinter(PythonCodePrinter): diff --git a/pde/tools/mpi.py b/pde/tools/mpi.py index dc9bb509..7d3f36ed 100644 --- a/pde/tools/mpi.py +++ b/pde/tools/mpi.py @@ -24,7 +24,7 @@ import numba as nb import numpy as np from numba import types -from numba.extending import overload, register_jitable +from numba.extending import SentryLiteralArgs, overload, register_jitable try: from numba.types import Literal @@ -158,10 +158,11 @@ def impl(data, source: int, tag: int) -> None: return impl -def mpi_allreduce(data, operator: int | str | None = None): +def mpi_allreduce(data, operator): """Combines data from all MPI nodes. - Note that complex datatypes and user-defined functions are not properly supported. + Note that complex datatypes and user-defined reduction operators are not properly + supported in numba-compiled cases. Args: data: @@ -173,21 +174,37 @@ def mpi_allreduce(data, operator: int | str | None = None): Returns: The accumulated data """ - if operator: - return MPI.COMM_WORLD.allreduce(data, op=Operator.operator(operator)) + if not parallel_run: + # in a serial run, we can always return the value as is + return data + + if isinstance(data, np.ndarray): + # synchronize an array + out = np.empty_like(data) + MPI.COMM_WORLD.Allreduce(data, out, op=Operator.operator(operator)) + return out + else: - return MPI.COMM_WORLD.allreduce(data) + # synchronize a single value + return MPI.COMM_WORLD.allreduce(data, op=Operator.operator(operator)) @overload(mpi_allreduce) -def ol_mpi_allreduce(data, operator: int | str | None = None): +def ol_mpi_allreduce(data, operator): """Overload the `mpi_allreduce` function.""" - import numba_mpi + if size == 1: + # We can simply return the value in a serial run - if operator is None or isinstance(operator, nb.types.NoneType): - op_id = -1 # value will not be used - elif isinstance(operator, Literal): - # an operator is specified (using a literal value + def impl(data, operator): + return data + + return impl + + # Conversely, in a parallel run, we need to use the correct reduction. Let's first + # determine the operator, which must be given as a literal type + SentryLiteralArgs(["operator"]).for_function(ol_mpi_allreduce).bind(data, operator) + if isinstance(operator, Literal): + # an operator is specified (using a literal value) if isinstance(operator.literal_value, str): # an operator is specified by it's name op_id = Operator.id(operator.literal_value) @@ -199,33 +216,37 @@ def ol_mpi_allreduce(data, operator: int | str | None = None): else: raise RuntimeError(f"`operator` must be a literal type, not {operator}") + import numba_mpi + @register_jitable - def _allreduce(sendobj, recvobj, operator: int | str | None = None) -> int: + def _allreduce(sendobj, recvobj, operator) -> int: """Helper function that calls `numba_mpi.allreduce`""" - if operator is None: - return numba_mpi.allreduce(sendobj, recvobj) # type: ignore - elif op_id is None: + if op_id is None: return numba_mpi.allreduce(sendobj, recvobj, operator) # type: ignore else: return numba_mpi.allreduce(sendobj, recvobj, op_id) # type: ignore if isinstance(data, types.Number): + # implementation of the reduction for a single number - def impl(data, operator: int | str | None = None): + def impl(data, operator): """Reduce a single number across all cores.""" sendobj = np.array([data]) recvobj = np.empty((1,), sendobj.dtype) status = _allreduce(sendobj, recvobj, operator) - assert status == 0 + if status != 0: + raise RuntimeError return recvobj[0] elif isinstance(data, types.Array): + # implementation of the reduction for a numpy array - def impl(data, operator: int | str | None = None): + def impl(data, operator): """Reduce an array across all cores.""" recvobj = np.empty(data.shape, data.dtype) status = _allreduce(data, recvobj, operator) - assert status == 0 + if status != 0: + raise RuntimeError return recvobj else: diff --git a/pde/tools/output.py b/pde/tools/output.py index 110c4c12..af58b618 100644 --- a/pde/tools/output.py +++ b/pde/tools/output.py @@ -41,7 +41,7 @@ def get_progress_bar_class(fancy: bool = True): progress_bar_class = tqdm.tqdm else: # use the fancier version of the progress bar in jupyter - from tqdm.auto import tqdm as progress_bar_class + from tqdm.auto import tqdm as progress_bar_class # type: ignore else: # only import text progress bar progress_bar_class = tqdm.tqdm diff --git a/tests/solvers/test_explicit_mpi_solvers.py b/tests/solvers/test_explicit_mpi_solvers.py index 24bd6126..13e0f6ef 100644 --- a/tests/solvers/test_explicit_mpi_solvers.py +++ b/tests/solvers/test_explicit_mpi_solvers.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from pde import DiffusionPDE, ScalarField, UnitGrid, PDE, FieldCollection +from pde import PDE, DiffusionPDE, FieldCollection, ScalarField, UnitGrid from pde.solvers import Controller, ExplicitMPISolver from pde.tools import mpi @@ -17,7 +17,7 @@ [ ("euler", False, "auto"), ("euler", True, [1, -1]), - ("runge-kutta", False, [-1, 1]), + ("runge-kutta", True, [-1, 1]), ], ) def test_simple_pde_mpi(backend, scheme, adaptive, decomposition, rng): @@ -107,7 +107,3 @@ def test_multiple_pdes_mpi(backend, rng): assert info_mpi["solver"]["steps"] == info2["solver"]["steps"] assert info_mpi["solver"]["use_mpi"] - from pprint import pprint - - pprint(info_mpi) - print(info_mpi) diff --git a/tests/tools/test_mpi.py b/tests/tools/test_mpi.py index 402bc6aa..d3f55e07 100644 --- a/tests/tools/test_mpi.py +++ b/tests/tools/test_mpi.py @@ -24,13 +24,34 @@ def test_send_recv(): @pytest.mark.multiprocessing -def test_allreduce(): - """Test basic send and receive.""" - from numba_mpi import Operator +@pytest.mark.parametrize("operator", ["MAX", "MIN", "SUM"]) +def test_allreduce(operator, rng): + """Test MPI allreduce function.""" + data = rng.uniform(size=size) + result = mpi_allreduce(data[rank], operator=operator) + + if operator == "MAX": + assert result == data.max() + elif operator == "MIN": + assert result == data.min() + elif operator == "SUM": + assert result == data.sum() + else: + raise NotImplementedError - data = np.arange(size) - total = mpi_allreduce(data[rank]) - assert total == data.sum() - total = mpi_allreduce(data[rank], int(Operator.MAX)) - assert total == data.max() +@pytest.mark.multiprocessing +@pytest.mark.parametrize("operator", ["MAX", "MIN", "SUM"]) +def test_allreduce_array(operator, rng): + """Test MPI allreduce function.""" + data = rng.uniform(size=(size, 3)) + result = mpi_allreduce(data[rank], operator=operator) + + if operator == "MAX": + np.testing.assert_allclose(result, data.max(axis=0)) + elif operator == "MIN": + np.testing.assert_allclose(result, data.min(axis=0)) + elif operator == "SUM": + np.testing.assert_allclose(result, data.sum(axis=0)) + else: + raise NotImplementedError