Skip to content

Commit

Permalink
Added option to plot RGB images of field collections (#485)
Browse files Browse the repository at this point in the history
Closes issue #484 

* Added option to plot RGB images of field collections
* Several other smaller improvements
  • Loading branch information
david-zwicker authored Nov 18, 2023
1 parent 4f2632b commit f975750
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 33 deletions.
2 changes: 1 addition & 1 deletion examples/pde_brusselator_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,5 @@ def pde_rhs(state_data, t):
state = eq.get_initial_state(grid)

# simulate the pde
tracker = PlotTracker(interval=1, plot_args={"vmin": 0, "vmax": 5})
tracker = PlotTracker(interval=1, plot_args={"kind": "rgb", "vmin": 0, "vmax": 5})
sol = eq.solve(state, t_range=20, dt=1e-3, tracker=tracker)
6 changes: 0 additions & 6 deletions pde/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2079,12 +2079,6 @@ def _plot_image(
data_kws[arg] = kwargs.pop(arg)
data = self.get_image_data(scalar, transpose, **data_kws)

if ax is None:
import matplotlib.pyplot as plt

# create new figure
ax = plt.subplots()[1]

# plot the image
kwargs.setdefault("origin", "lower")
kwargs.setdefault("interpolation", "none")
Expand Down
136 changes: 118 additions & 18 deletions pde/fields/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@
Mapping,
Optional,
Sequence,
Tuple,
overload,
)

import numpy as np
from matplotlib.colors import Normalize
from numpy.typing import DTypeLike

from ..grids.base import GridBase
from ..tools.docstrings import fill_in_docstring
from ..tools.misc import Number, number_array
from ..tools.plotting import PlotReference, plot_on_figure
from ..tools.plotting import PlotReference, plot_on_axes, plot_on_figure
from ..tools.typing import NumberOrArray
from .base import DataFieldBase, FieldBase
from .scalar import ScalarField
Expand Down Expand Up @@ -750,6 +752,90 @@ def get_image_data(self, index: int = 0, **kwargs) -> Dict[str, Any]:
"""
return self[index].get_image_data(**kwargs)

def _get_rgb_data(
self,
transpose: bool = False,
vmin: float | List[float | None] | None = None,
vmax: float | List[float | None] | None = None,
) -> Tuple[np.ndarray, Dict[str, Any]]:
"""obtain data required for RGB plot"""
num_fields = len(self)
if num_fields > 3:
raise ValueError("Can only plot RGB image for three or fewer fields")
if not hasattr(vmin, "__iter__"):
vmin = [vmin] * num_fields
if not hasattr(vmax, "__iter__"):
vmax = [vmax] * num_fields

# obtain image data with appropriate parameters
data = [f.get_image_data(transpose=transpose) for f in self]
# turn data into array of RGB values (shape nxmx3)
data_list = []
for i, d in enumerate(data):
norm = Normalize(vmin=vmin[i], vmax=vmax[i], clip=True) # type: ignore
data_list.append(norm(d["data"].T))
while len(data_list) < 3:
data_list.append(np.zeros_like(data_list[0]))
rgb_arr = np.dstack(data_list)
return rgb_arr, data[0]

def _update_rgb_image_plot(self, reference: PlotReference) -> None:
"""update an RGB image plot with the current field values
Args:
reference (:class:`PlotReference`):
The reference to the plot that is updated
"""
# obtain image data
rgb_arr, _ = self._get_rgb_data(**reference.parameters)
# update the axes image
reference.element.set_data(rgb_arr)

@plot_on_axes(update_method="_update_rgb_image_plot")
def _plot_rgb_image(
self,
ax,
transpose: bool = False,
vmin: float | List[float | None] | None = None,
vmax: float | List[float | None] | None = None,
**kwargs,
) -> PlotReference:
r"""visualize fields by mapping to different color chanels in a 2d density plot
Args:
ax (:class:`matplotlib.axes.Axes`):
Figure axes to be used for plotting.
transpose (bool):
Determines whether the transpose of the data is plotted
vmin, vmax (float, list of float):
Define the data range that the color chanels cover. By default, they
cover the complete value range of the supplied data.
\**kwargs:
Additional keyword arguments that affect the image. Non-Cartesian grids
might support `performance_goal` to influence how an image is created
from raw data. Finally, remaining arguments are passed to
:func:`matplotlib.pyplot.imshow` to affect the appearance.
Returns:
:class:`PlotReference`: Instance that contains information to update the
plot with new data later.
"""
rgb_arr, data = self._get_rgb_data(transpose, vmin, vmax)

# plot the image
kwargs.setdefault("origin", "lower")
kwargs.setdefault("interpolation", "none")
axes_image = ax.imshow(rgb_arr, extent=data["extent"], **kwargs)

# set some default properties
ax.set_xlabel(data["label_x"])
ax.set_ylabel(data["label_y"])
ax.set_title(self.label)

# store parameters of the plot that are necessary for updating
parameters = {"transpose": transpose, "vmin": vmin, "vmax": vmax}
return PlotReference(ax, axes_image, parameters)

def _update_plot(self, reference: List[PlotReference]) -> None:
"""update a plot collection with the current field values
Expand All @@ -775,9 +861,10 @@ def plot(
Args:
kind (str or list of str):
Determines the kind of the visualizations. Supported values are `image`,
`line`, `vector`, or `interactive`. Alternatively, `auto` determines the
best visualization based on each field itself. Instead of a single value
for all fields, a list with individual values can be given.
`line`, `vector`, `interactive`, or `rgb`. Alternatively, `auto`
determines the best visualization based on each field itself. Instead of
a single value for all fields, a list with individual values can be
given, unless `rgb` is chosen.
figsize (str or tuple of numbers):
Determines the figure size. The figure size is unchanged if the string
`default` is passed. Conversely, the size is adjusted automatically when
Expand All @@ -800,43 +887,56 @@ def plot(
List of :class:`PlotReference`: Instances that contain information
to update all the plots with new data later.
"""
if kind in {"rgb", "rgb_image", "rgb-image"}:
num_panels = 1
else:
num_panels = len(self)

# set the size of the figure
if figsize == "default":
pass # just leave the figure size at its default value

elif figsize == "auto":
# adjust the size of the figure
if arrangement == "horizontal":
fig.set_size_inches((4 * len(self), 3), forward=True)
fig.set_size_inches((4 * num_panels, 3), forward=True)
elif arrangement == "vertical":
fig.set_size_inches((4, 3 * len(self)), forward=True)
fig.set_size_inches((4, 3 * num_panels), forward=True)

else:
# assume that an actual tuple is given
fig.set_size_inches(figsize, forward=True)

# create all the subpanels
if arrangement == "horizontal":
(axs,) = fig.subplots(1, len(self), squeeze=False)
(axs,) = fig.subplots(1, num_panels, squeeze=False)
elif arrangement == "vertical":
axs = fig.subplots(len(self), 1, squeeze=False)
axs = fig.subplots(num_panels, 1, squeeze=False)
axs = [a[0] for a in axs] # transpose
else:
raise ValueError(f"Unknown arrangement `{arrangement}`")

if subplot_args is None:
subplot_args = [{}] * len(self)
subplot_args = [{}] * num_panels

if isinstance(kind, str):
kind = [kind] * len(self.fields)
if kind in {"rgb", "rgb_image", "rgb-image"}:
# plot a single RGB representation
reference = [
self._plot_rgb_image(
ax=axs[0], action="none", **kwargs, **subplot_args[0]
)
]

# plot all the elements onto the respective axes
reference = [
field.plot(kind=knd, ax=ax, action="none", **kwargs, **sp_args)
for field, knd, ax, sp_args in zip( # @UnusedVariable
self.fields, kind, axs, subplot_args
)
]
else:
# plot all the elements onto the respective axes
if isinstance(kind, str):
kind = [kind] * num_panels
reference = [
field.plot(kind=knd, ax=ax, action="none", **kwargs, **sp_args)
for field, knd, ax, sp_args in zip( # @UnusedVariable
self.fields, kind, axs, subplot_args
)
]

# return the references for all subplots
return reference
Expand Down
2 changes: 1 addition & 1 deletion pde/grids/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _check_shape(shape: int | Sequence[int]) -> Tuple[int, ...]:
def discretize_interval(
x_min: float, x_max: float, num: int
) -> Tuple[np.ndarray, float]:
r""" construct a list of equidistantly placed intervals
r"""construct a list of equidistantly placed intervals
The discretization is defined as
Expand Down
15 changes: 8 additions & 7 deletions pde/grids/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@

from ..tools.cuboid import Cuboid
from ..tools.plotting import plot_on_axes
from .base import CoordsType, DimensionError, GridBase, _check_shape
from .base import (
CoordsType,
DimensionError,
GridBase,
_check_shape,
discretize_interval,
)

if TYPE_CHECKING:
from .boundaries.axes import Boundaries, BoundariesData
Expand Down Expand Up @@ -130,12 +136,7 @@ def __init__(
p1, p2 = self.cuboid.corners
axes_coords, discretization = [], []
for d in range(self.dim):
num = self.shape[d]
c, dc = np.linspace(p1[d], p2[d], num, endpoint=False, retstep=True)
if self.shape[d] == 1:
# correct for singular dimension
dc = p2[d] - p1[d]
c += dc / 2
c, dc = discretize_interval(p1[d], p2[d], self.shape[d])
axes_coords.append(c)
discretization.append(dc)
self._discretization = np.array(discretization)
Expand Down
13 changes: 13 additions & 0 deletions tests/fields/test_field_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,16 @@ def test_collection_apply(rng):
f1 = FieldCollection([s, v])

np.testing.assert_allclose(f1.apply("s1 * v2").data, v.data * 2)


@pytest.mark.parametrize("num", [1, 2, 3])
def test_rgb_image_plotting(num):
"""test plotting of collections as rgb fields"""
grid = UnitGrid([16, 8])
fc = FieldCollection([ScalarField.random_uniform(grid) for _ in range(num)])

ref = fc._plot_rgb_image()
fc._update_rgb_image_plot(ref)

refs = fc.plot("rgb_image")
fc._update_plot(refs)

0 comments on commit f975750

Please sign in to comment.