From 64d433c7d5c3699de2a0aee919f8b74add24867b Mon Sep 17 00:00:00 2001 From: David Zwicker Date: Fri, 9 Aug 2024 13:03:32 +0200 Subject: [PATCH] Refactored synchronization of values in MPI runs * Replace error_synchronization by generic mpi_allreduce * Improved mpi_allreduce to deal with string Operators in numba * Added additional tests for mpi_allreduce * Improved exception handling in MPI simulations * Multiple smaller improvements --- pde/grids/base.py | 14 +--- pde/solvers/base.py | 19 +++-- pde/solvers/controller.py | 32 ++++++-- pde/solvers/explicit.py | 4 +- pde/solvers/explicit_mpi.py | 28 +++---- pde/tools/expressions.py | 5 +- pde/tools/mpi.py | 93 +++++++++++++++++----- pde/tools/output.py | 2 +- tests/grids/test_grid_mesh.py | 2 + tests/solvers/test_explicit_mpi_solvers.py | 4 +- tests/tools/test_mpi.py | 63 ++++++++++++--- 11 files changed, 188 insertions(+), 78 deletions(-) diff --git a/pde/grids/base.py b/pde/grids/base.py index 3da7949f..139bea0d 100644 --- a/pde/grids/base.py +++ b/pde/grids/base.py @@ -26,6 +26,7 @@ from ..tools.cache import cached_method, cached_property from ..tools.docstrings import fill_in_docstring from ..tools.misc import Number, hybridmethod +from ..tools.mpi import mpi_allreduce from ..tools.numba import jit from ..tools.typing import ( CellVolume, @@ -1572,14 +1573,9 @@ def integrate( if self._mesh is None or len(self._mesh) == 1: # standard case of a single integral return integral # type: ignore - else: - # we are in a parallel run, so we need to gather the sub-integrals from all - from mpi4py.MPI import COMM_WORLD # @UnresolvedImport - - integral_full = np.empty_like(integral) - COMM_WORLD.Allreduce(integral, integral_full) - return integral_full # type: ignore + # accumulate integrals from all subprocesse (necessary if MPI is used) + return mpi_allreduce(integral, operator="SUM") # type: ignore @cached_method() def make_normalize_point_compiled( @@ -2135,8 +2131,6 @@ def integrate_global(arr: np.ndarray) -> NumberOrArray: else: # we are in a parallel run, so we need to gather the sub-integrals from all # subgrids in the grid mesh - from ..tools.mpi import mpi_allreduce - @jit def integrate_global(arr: np.ndarray) -> NumberOrArray: """Integrate data over MPI parallelized grid. @@ -2145,7 +2139,7 @@ def integrate_global(arr: np.ndarray) -> NumberOrArray: arr (:class:`~numpy.ndarray`): discretized data on grid """ integral = integrate_local(arr) - return mpi_allreduce(integral) # type: ignore + return mpi_allreduce(integral, operator="SUM") # type: ignore return integrate_global # type: ignore diff --git a/pde/solvers/base.py b/pde/solvers/base.py index 5731542a..095bc0a7 100644 --- a/pde/solvers/base.py +++ b/pde/solvers/base.py @@ -22,6 +22,7 @@ from ..pdes.base import PDEBase from ..tools.math import OnlineStatistics from ..tools.misc import classproperty +from ..tools.mpi import mpi_allreduce from ..tools.numba import is_jitted, jit from ..tools.typing import BackendType, StepperHook @@ -410,14 +411,19 @@ def __init__( self.adaptive = adaptive self.tolerance = tolerance - def _make_error_synchronizer(self) -> Callable[[float], float]: - """Return function that synchronizes errors between multiple processes.""" + def _make_error_synchronizer(self): + # Deprecated on 2024-08-09 + warnings.warn( + "`_make_error_synchronizer` has been replaced by " + "`pde.tools.mpi.mpi_allreduce`", + DeprecationWarning, + ) @register_jitable - def synchronize_errors(error: float) -> float: - return error + def error_synchronizer(value): + return mpi_allreduce(value, operator="MAX") - return synchronize_errors # type: ignore + return error_synchronizer def _make_dt_adjuster(self) -> Callable[[float, float], float]: """Return a function that can be used to adjust time steps.""" @@ -545,7 +551,6 @@ def _make_adaptive_stepper(self, state: FieldBase) -> Callable[ # obtain functions determining how the PDE is evolved single_step_error = self._make_single_step_error_estimate(state) post_step_hook = self._make_post_step_hook(state) - sync_errors = self._make_error_synchronizer() # obtain auxiliary functions adjust_dt = self._make_dt_adjuster() @@ -578,7 +583,7 @@ def adaptive_stepper( error_rel = error / tolerance # normalize error to given tolerance # synchronize the error between all processes (necessary for MPI) - error_rel = sync_errors(error_rel) + error_rel = mpi_allreduce(error_rel, operator="MAX") # do the step if the error is sufficiently small if error_rel <= 1: diff --git a/pde/solvers/controller.py b/pde/solvers/controller.py index 2ef4ca6a..6f9c7f7c 100644 --- a/pde/solvers/controller.py +++ b/pde/solvers/controller.py @@ -327,14 +327,32 @@ def run(self, initial_state: TState, dt: float | None = None) -> TState | None: state = initial_state.copy() # decide whether to call the main routine or whether this is an MPI client - if mpi.is_main: - # this node is the primary one - self._run_single(state, dt) + if mpi.size > 1: + self.info["parallel_run"] = True self.info["process_count"] = mpi.size - else: - # multiple processes are used and this is one of the secondaries - self._run_mpi_client(state, dt) self.info["process_rank"] = mpi.rank - return None # do not return anything in client processes + from mpi4py import MPI + + if mpi.is_main: # this node is the primary one + try: + self._run_single(state, dt) + except Exception: + # found exception on the main node + MPI.COMM_WORLD.Abort() # abort all other nodes + raise + else: # this node is a secondary (client) node + try: + self._run_mpi_client(state, dt) + except Exception as exception: + print(exception) # simply print the exception to show some info + MPI.COMM_WORLD.Abort() # abort all other (and main) nodes + raise + else: + return None # do not return anything in client processes + + else: + # serial run without MPI + self.info["parallel_run"] = False + self._run_single(state, dt) return state diff --git a/pde/solvers/explicit.py b/pde/solvers/explicit.py index a2471e0c..aef78d90 100644 --- a/pde/solvers/explicit.py +++ b/pde/solvers/explicit.py @@ -13,6 +13,7 @@ from ..fields.base import FieldBase from ..pdes.base import PDEBase from ..tools.math import OnlineStatistics +from ..tools.mpi import mpi_allreduce from ..tools.numba import jit from ..tools.typing import BackendType from .base import AdaptiveSolverBase @@ -182,7 +183,6 @@ def _make_adaptive_euler_stepper(self, state: FieldBase) -> Callable[ post_step_hook = self._make_post_step_hook(state) # obtain auxiliary functions - sync_errors = self._make_error_synchronizer() adjust_dt = self._make_dt_adjuster() tolerance = self.tolerance dt_min = self.dt_min @@ -226,7 +226,7 @@ def adaptive_stepper( error_rel = error / tolerance # normalize error to given tolerance # synchronize the error between all processes (necessary for MPI) - error_rel = sync_errors(error_rel) + error_rel = mpi_allreduce(error_rel, operator="MAX") if error_rel <= 1: # error is sufficiently small try: diff --git a/pde/solvers/explicit_mpi.py b/pde/solvers/explicit_mpi.py index ac792be9..3861f3e1 100644 --- a/pde/solvers/explicit_mpi.py +++ b/pde/solvers/explicit_mpi.py @@ -8,7 +8,6 @@ from typing import Callable, Literal import numpy as np -from numba.extending import register_jitable from ..fields.base import FieldBase from ..grids._mesh import GridMesh @@ -116,23 +115,6 @@ def __init__( ) self.decomposition = decomposition - def _make_error_synchronizer(self) -> Callable[[float], float]: - """Return function that synchronizes errors between multiple processes.""" - # if mpi.parallel_run: - # in a parallel run, we need to return the maximal error - from ..tools.mpi import Operator, mpi_allreduce - - operator_max_id = Operator.MAX - - @register_jitable - def synchronize_errors(error: float) -> float: - """Return maximal error accross all cores.""" - return mpi_allreduce(error, operator_max_id) # type: ignore - - return synchronize_errors # type: ignore - # else: - # return super()._make_error_synchronizer() - def make_stepper( self, state: FieldBase, dt=None ) -> Callable[[FieldBase, float, float], float]: @@ -178,12 +160,16 @@ def make_stepper( # decompose the state into multiple cells self.mesh = GridMesh.from_grid(state.grid, self.decomposition) sub_state = self.mesh.extract_subfield(state) + print("GOT sub_state for node", mpi.rank) self.info["grid_decomposition"] = self.mesh.shape if self.adaptive: # create stepper with adaptive steps self.info["dt_statistics"] = OnlineStatistics() + + print("MAKE adaptive stepper for rank", mpi.rank) adaptive_stepper = self._make_adaptive_stepper(sub_state) + print("GOT adaptive stepper for rank", mpi.rank) self.info["post_step_data"] = self._post_step_data_init def wrapped_stepper( @@ -196,6 +182,7 @@ def wrapped_stepper( post_step_data = self.info["post_step_data"] # distribute the end time and the field to all nodes + print("BROADCAST DATA") t_end = self.mesh.broadcast(t_end) substate_data = self.mesh.split_field_data_mpi(state.data) @@ -203,6 +190,7 @@ def wrapped_stepper( # field data via special boundary conditions and they synchronize the # maximal error via the error synchronizer. Apart from that, all nodes # work independently. + print("START ADAPTIVE STEPPER") t_last, dt, steps = adaptive_stepper( substate_data, t_start, @@ -211,16 +199,20 @@ def wrapped_stepper( self.info["dt_statistics"], post_step_data, ) + print("t_last", t_last) # check whether dt is the same for all processes dt_list = self.mesh.allgather(dt) if not np.isclose(min(dt_list), max(dt_list)): # abort simulations in all nodes when they went out of sync raise RuntimeError(f"Processes went out of sync: dt={dt_list}") + print("dt_list", dt_list) # collect the data from all nodes post_step_data_list = self.mesh.gather(post_step_data) + print("post_step_data_list", post_step_data_list) self.mesh.combine_field_data_mpi(substate_data, out=state.data) + print("substate_data", state.data) if mpi.is_main: self.info["steps"] += steps diff --git a/pde/tools/expressions.py b/pde/tools/expressions.py index 411c9242..14471b0b 100644 --- a/pde/tools/expressions.py +++ b/pde/tools/expressions.py @@ -131,7 +131,10 @@ def np_heaviside(x1, x2): # special functions that we want to support in expressions but that are not defined by # sympy version 1.6 or have a different signature than expected by numba/numpy -SPECIAL_FUNCTIONS = {"Heaviside": _heaviside_implemention, "hypot": np.hypot} +SPECIAL_FUNCTIONS: dict[str, Callable] = { + "Heaviside": _heaviside_implemention, + "hypot": np.hypot, +} class ListArrayPrinter(PythonCodePrinter): diff --git a/pde/tools/mpi.py b/pde/tools/mpi.py index dc9bb509..25fff010 100644 --- a/pde/tools/mpi.py +++ b/pde/tools/mpi.py @@ -19,12 +19,12 @@ import os import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable import numba as nb import numpy as np from numba import types -from numba.extending import overload, register_jitable +from numba.extending import SentryLiteralArgs, overload, register_jitable try: from numba.types import Literal @@ -36,7 +36,7 @@ # Initialize assuming that we run serial code if `numba_mpi` is not available initialized: bool = False -"""bool: Flag determining whether mpi was initialized (and is available)""" +"""bool: Flag determining whether MPI was initialized (and is available)""" size: int = 1 """int: Total process count""" @@ -158,10 +158,11 @@ def impl(data, source: int, tag: int) -> None: return impl -def mpi_allreduce(data, operator: int | str | None = None): +def mpi_allreduce(data, operator): """Combines data from all MPI nodes. - Note that complex datatypes and user-defined functions are not properly supported. + Note that complex datatypes and user-defined reduction operators are not properly + supported in numba-compiled cases. Args: data: @@ -173,21 +174,39 @@ def mpi_allreduce(data, operator: int | str | None = None): Returns: The accumulated data """ - if operator: - return MPI.COMM_WORLD.allreduce(data, op=Operator.operator(operator)) + if not parallel_run: + # in a serial run, we can always return the value as is + return data + + if isinstance(data, np.ndarray): + # synchronize an array + out = np.empty_like(data) + MPI.COMM_WORLD.Allreduce(data, out, op=Operator.operator(operator)) + return out + else: - return MPI.COMM_WORLD.allreduce(data) + # synchronize a single value + return MPI.COMM_WORLD.allreduce(data, op=Operator.operator(operator)) @overload(mpi_allreduce) -def ol_mpi_allreduce(data, operator: int | str | None = None): +def ol_mpi_allreduce(data, operator): """Overload the `mpi_allreduce` function.""" import numba_mpi - if operator is None or isinstance(operator, nb.types.NoneType): - op_id = -1 # value will not be used - elif isinstance(operator, Literal): - # an operator is specified (using a literal value + if not parallel_run: + # in a serial run, we can always return the value as is + + def impl(data, operator): + return data + + return impl + + # Conversely, in a parallel run, we need to use the correct reduction. Let's first + # determine the operator, which must be given as a literal type + SentryLiteralArgs(["operator"]).for_function(ol_mpi_allreduce).bind(data, operator) + if isinstance(operator, Literal): + # an operator is specified (using a literal value) if isinstance(operator.literal_value, str): # an operator is specified by it's name op_id = Operator.id(operator.literal_value) @@ -200,35 +219,67 @@ def ol_mpi_allreduce(data, operator: int | str | None = None): raise RuntimeError(f"`operator` must be a literal type, not {operator}") @register_jitable - def _allreduce(sendobj, recvobj, operator: int | str | None = None) -> int: + def _allreduce(sendobj, recvobj, operator) -> int: """Helper function that calls `numba_mpi.allreduce`""" - if operator is None: - return numba_mpi.allreduce(sendobj, recvobj) # type: ignore - elif op_id is None: + if op_id is None: return numba_mpi.allreduce(sendobj, recvobj, operator) # type: ignore else: return numba_mpi.allreduce(sendobj, recvobj, op_id) # type: ignore if isinstance(data, types.Number): + # implementation of the reduction for a single number - def impl(data, operator: int | str | None = None): + def impl(data, operator): """Reduce a single number across all cores.""" sendobj = np.array([data]) recvobj = np.empty((1,), sendobj.dtype) status = _allreduce(sendobj, recvobj, operator) - assert status == 0 + if status != 0: + raise RuntimeError return recvobj[0] elif isinstance(data, types.Array): + # implementation of the reduction for a numpy array - def impl(data, operator: int | str | None = None): + def impl(data, operator): """Reduce an array across all cores.""" recvobj = np.empty(data.shape, data.dtype) status = _allreduce(data, recvobj, operator) - assert status == 0 + if status != 0: + raise RuntimeError return recvobj else: raise TypeError(f"Unsupported type {data.__class__.__name__}") return impl + + +def make_float_synchronizer(operator: int | str = "MAX") -> Callable[[float], float]: + """Return function that synchronizes number between multiple processes. + + Args: + operator (int or str): + MPI operator to synchronize the numbers. Typical values are "MAX", "MIN", + or "SUM" + + Returns: + Function that can be used to synchronize float values across cores + """ + if parallel_run: + # in a parallel run, we need to synchronize the float values + operator_id = Operator.id(operator) + + @register_jitable + def synchronize_floats(value: float) -> float: + """Synchronize values accross all cores.""" + return mpi_allreduce(value, operator_id) # type: ignore + + else: + # we don't need to do anything in a serial run + + @register_jitable + def synchronize_floats(value: float) -> float: + return value + + return synchronize_floats # type: ignore diff --git a/pde/tools/output.py b/pde/tools/output.py index 110c4c12..af58b618 100644 --- a/pde/tools/output.py +++ b/pde/tools/output.py @@ -41,7 +41,7 @@ def get_progress_bar_class(fancy: bool = True): progress_bar_class = tqdm.tqdm else: # use the fancier version of the progress bar in jupyter - from tqdm.auto import tqdm as progress_bar_class + from tqdm.auto import tqdm as progress_bar_class # type: ignore else: # only import text progress bar progress_bar_class = tqdm.tqdm diff --git a/tests/grids/test_grid_mesh.py b/tests/grids/test_grid_mesh.py index a3c028c7..9a893f3f 100644 --- a/tests/grids/test_grid_mesh.py +++ b/tests/grids/test_grid_mesh.py @@ -254,7 +254,9 @@ def test_integration_parallel(grid, decomposition, rank): subfield = mesh.extract_subfield(field) # numpy version + # integrate full field on each node np.testing.assert_allclose(field.integral, expected) + # integrate subfields per node and accumulate result np.testing.assert_allclose(subfield.integral, expected) # numba version diff --git a/tests/solvers/test_explicit_mpi_solvers.py b/tests/solvers/test_explicit_mpi_solvers.py index 24bd6126..54f326d6 100644 --- a/tests/solvers/test_explicit_mpi_solvers.py +++ b/tests/solvers/test_explicit_mpi_solvers.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from pde import DiffusionPDE, ScalarField, UnitGrid, PDE, FieldCollection +from pde import PDE, DiffusionPDE, FieldCollection, ScalarField, UnitGrid from pde.solvers import Controller, ExplicitMPISolver from pde.tools import mpi @@ -17,7 +17,7 @@ [ ("euler", False, "auto"), ("euler", True, [1, -1]), - ("runge-kutta", False, [-1, 1]), + ("runge-kutta", True, [-1, 1]), ], ) def test_simple_pde_mpi(backend, scheme, adaptive, decomposition, rng): diff --git a/tests/tools/test_mpi.py b/tests/tools/test_mpi.py index 402bc6aa..0006174a 100644 --- a/tests/tools/test_mpi.py +++ b/tests/tools/test_mpi.py @@ -5,7 +5,14 @@ import numpy as np import pytest -from pde.tools.mpi import mpi_allreduce, mpi_recv, mpi_send, rank, size +from pde.tools.mpi import ( + make_float_synchronizer, + mpi_allreduce, + mpi_recv, + mpi_send, + rank, + size, +) @pytest.mark.multiprocessing @@ -24,13 +31,51 @@ def test_send_recv(): @pytest.mark.multiprocessing -def test_allreduce(): - """Test basic send and receive.""" - from numba_mpi import Operator +@pytest.mark.parametrize("operator", ["MAX", "MIN", "SUM"]) +def test_allreduce(operator, rng): + """Test MPI allreduce function.""" + data = rng.uniform(size=size) + result = mpi_allreduce(data[rank], operator=operator) + + if operator == "MAX": + assert result == data.max() + elif operator == "MIN": + assert result == data.min() + elif operator == "SUM": + assert result == data.sum() + else: + raise NotImplementedError + + +@pytest.mark.multiprocessing +@pytest.mark.parametrize("operator", ["MAX", "MIN", "SUM"]) +def test_allreduce_array(operator, rng): + """Test MPI allreduce function.""" + data = rng.uniform(size=(size, 3)) + result = mpi_allreduce(data[rank], operator=operator) + + if operator == "MAX": + np.testing.assert_allclose(result, data.max(axis=0)) + elif operator == "MIN": + np.testing.assert_allclose(result, data.min(axis=0)) + elif operator == "SUM": + np.testing.assert_allclose(result, data.sum(axis=0)) + else: + raise NotImplementedError - data = np.arange(size) - total = mpi_allreduce(data[rank]) - assert total == data.sum() - total = mpi_allreduce(data[rank], int(Operator.MAX)) - assert total == data.max() +@pytest.mark.multiprocessing +@pytest.mark.parametrize("operator", ["MAX", "MIN", "SUM"]) +def test_float_synchronizers(operator): + """Test make_float_synchronizer function.""" + sync = make_float_synchronizer(operator=operator) + result = sync(rank) + + if operator == "MAX": + assert result == size - 1 + elif operator == "MIN": + assert result == 0 + elif operator == "SUM": + assert result == sum(range(size)) + else: + raise NotImplementedError