Skip to content

Commit

Permalink
Prepare the package for more difficult curvilinear grids (#487)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zwicker authored Nov 24, 2023
1 parent cafb881 commit aae2299
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions pde/grids/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class GridBase(metaclass=ABCMeta):

boundary_names: Dict[str, Tuple[int, bool]] = {}
"""dict: Names of boundaries to select them conveniently"""
cell_volume_data: Sequence[FloatNumerical]
cell_volume_data: Optional[Sequence[FloatNumerical]]
"""list: Information about the size of discretization cells"""
coordinate_constraints: List[int] = []
"""list: axes that not described explicitly"""
Expand Down Expand Up @@ -445,13 +445,20 @@ def cell_coords(self) -> np.ndarray:
@cached_property()
def cell_volumes(self) -> np.ndarray:
""":class:`~numpy.ndarray`: volume of each cell"""
if self.cell_volume_data is None:
raise RuntimeError(
"`cell_volumes` needs to be implemented if `cell_volume_data` is `None`"
)
vols = functools.reduce(np.outer, self.cell_volume_data)
return np.broadcast_to(vols, self.shape)

@cached_property()
def uniform_cell_volumes(self) -> bool:
"""bool: returns True if all cell volumes are the same"""
return all(np.asarray(vols).ndim == 0 for vols in self.cell_volume_data)
if self.cell_volume_data is None:
return False
else:
return all(np.asarray(vols).ndim == 0 for vols in self.cell_volume_data)

def difference_vector_real(self, p1: np.ndarray, p2: np.ndarray) -> np.ndarray:
"""return vector(s) pointing from p1 to p2
Expand Down Expand Up @@ -1254,19 +1261,25 @@ def integrate(
:class:`~numpy.ndarray`: The values integrated over the entire grid
"""
# determine the volumes of the individual cells
if axes is None:
volume_list = self.cell_volume_data
if self.cell_volume_data is None:
if axes is None:
cell_volumes = self.cell_volumes
else:
raise NotImplementedError
else:
# use stored value for the default case of integrating over all axes
if isinstance(axes, int):
axes = (axes,)
if axes is None:
volume_list = self.cell_volume_data
else:
axes = tuple(axes) # required for numpy.sum
volume_list = [
cell_vol if ax in axes else 1
for ax, cell_vol in enumerate(self.cell_volume_data)
]
cell_volumes = functools.reduce(np.outer, volume_list)
# use stored value for the default case of integrating over all axes
if isinstance(axes, int):
axes = (axes,)
else:
axes = tuple(axes) # required for numpy.sum
volume_list = [
cell_vol if ax in axes else 1
for ax, cell_vol in enumerate(self.cell_volume_data)
]
cell_volumes = functools.reduce(np.outer, volume_list)

# determine the axes over which we will integrate
if not isinstance(data, np.ndarray) or data.ndim < self.num_axes:
Expand Down Expand Up @@ -1355,7 +1368,9 @@ def make_cell_volume_compiled(self, flat_index: bool = False) -> CellVolume:
Returns:
function: returning the volume of the chosen cell
"""
if all(np.isscalar(d) for d in self.cell_volume_data):
if self.cell_volume_data is not None and all(
np.isscalar(d) for d in self.cell_volume_data
):
# all cells have the same volume
cell_volume = np.prod(self.cell_volume_data) # type: ignore

Expand Down

0 comments on commit aae2299

Please sign in to comment.