Skip to content

Commit

Permalink
Fixed problem with updating RGB images (#486)
Browse files Browse the repository at this point in the history
* Fixed problem with updating RGB images
* Improved coverage
  • Loading branch information
david-zwicker authored Nov 18, 2023
1 parent f975750 commit cafb881
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
18 changes: 14 additions & 4 deletions pde/fields/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,9 @@ def _update_rgb_image_plot(self, reference: PlotReference) -> None:
The reference to the plot that is updated
"""
# obtain image data
rgb_arr, _ = self._get_rgb_data(**reference.parameters)
data_args = reference.parameters.copy()
data_args.pop("kind")
rgb_arr, _ = self._get_rgb_data(**data_args)
# update the axes image
reference.element.set_data(rgb_arr)

Expand Down Expand Up @@ -833,7 +835,12 @@ def _plot_rgb_image(
ax.set_title(self.label)

# store parameters of the plot that are necessary for updating
parameters = {"transpose": transpose, "vmin": vmin, "vmax": vmax}
parameters = {
"kind": "rgb_image",
"transpose": transpose,
"vmin": vmin,
"vmax": vmax,
}
return PlotReference(ax, axes_image, parameters)

def _update_plot(self, reference: List[PlotReference]) -> None:
Expand All @@ -843,8 +850,11 @@ def _update_plot(self, reference: List[PlotReference]) -> None:
reference (list of :class:`PlotReference`):
All references of the plot to update
"""
for field, ref in zip(self.fields, reference):
field._update_plot(ref)
if reference[0].parameters.get("kind", None) == "rgb_image":
self._update_rgb_image_plot(reference[0])
else:
for field, ref in zip(self.fields, reference):
field._update_plot(ref)

@plot_on_figure(update_method="_update_plot")
def plot(
Expand Down
3 changes: 2 additions & 1 deletion tests/fields/test_field_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def test_collections(rng):
with pytest.raises(KeyError):
fields["42"] = 0

fields.plot(subplot_args=[{}, {"scale": 1}, {"colorbar": False}])
refs = fields.plot(subplot_args=[{}, {"scale": 1}, {"colorbar": False}])
fields._update_plot(refs)


def test_collections_copy():
Expand Down

0 comments on commit cafb881

Please sign in to comment.