Skip to content

Commit

Permalink
Added boundaries iterator to Boundaries class (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zwicker authored Feb 16, 2024
1 parent b0807a0 commit c427579
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pde/grids/_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ def gather(self, data: TData) -> list[TData] | None:
"""
from mpi4py.MPI import COMM_WORLD # @UnresolvedImport

return COMM_WORLD.gather(data, root=0) # type: ignore
return COMM_WORLD.gather(data, root=0)

def allgather(self, data: TData) -> list[TData]:
"""gather a value from reach node and sends them to all nodes
Expand All @@ -720,7 +720,7 @@ def allgather(self, data: TData) -> list[TData]:
"""
from mpi4py.MPI import COMM_WORLD # @UnresolvedImport

return COMM_WORLD.allgather(data) # type: ignore
return COMM_WORLD.allgather(data)

@plot_on_axes()
def plot(self, ax, **kwargs) -> None:
Expand Down
11 changes: 9 additions & 2 deletions pde/grids/boundaries/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from __future__ import annotations

from collections.abc import Sequence
from collections.abc import Iterator, Sequence
from typing import Union

import numpy as np
Expand All @@ -17,7 +17,7 @@
from ...tools.typing import GhostCellSetter
from ..base import GridBase, PeriodicityError
from .axis import BoundaryPair, BoundaryPairData, get_boundary_axis
from .local import BCDataError
from .local import BCBase, BCDataError

BoundariesData = Union[BoundaryPairData, Sequence[BoundaryPairData]]

Expand Down Expand Up @@ -170,6 +170,13 @@ def __eq__(self, other):
return NotImplemented
return super().__eq__(other) and self.grid == other.grid

@property
def boundaries(self) -> Iterator[BCBase]:
"""iterator over all non-periodic boundaries"""
for boundary_axis in self: # iterate all axes
if not boundary_axis.periodic: # skip periodic axes
yield from boundary_axis

def check_value_rank(self, rank: int) -> None:
"""check whether the values at the boundaries have the correct rank
Expand Down
16 changes: 15 additions & 1 deletion tests/grids/boundaries/test_axes_boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pde.grids.base import PeriodicityError
from pde.grids.boundaries.axes import BCDataError, Boundaries
from pde.grids.boundaries.axis import BoundaryPair, BoundaryPeriodic, get_boundary_axis
from pde.grids.boundaries.local import NeumannBC


def test_boundaries():
Expand Down Expand Up @@ -181,4 +182,17 @@ def test_setting_specific_bcs():
with pytest.raises(KeyError):
bcs["nonsense"] = None

# test different ranks

def test_boundaries_property():
"""test boundaries property"""
g = UnitGrid([2, 2])
bc = Boundaries.from_data(g, ["neumann", "dirichlet"])
assert len(list(bc.boundaries)) == 4

bc = Boundaries.from_data(g, "neumann")
for b in bc.boundaries:
assert isinstance(b, NeumannBC)

g = UnitGrid([2, 2], periodic=[True, False])
bc = Boundaries.from_data(g, "auto_periodic_neumann")
assert len(list(bc.boundaries)) == 2

0 comments on commit c427579

Please sign in to comment.