From cafb881b2c97fd714c689d6b517a026f29f10cf5 Mon Sep 17 00:00:00 2001 From: David Zwicker Date: Sat, 18 Nov 2023 15:58:14 +0100 Subject: [PATCH] Fixed problem with updating RGB images (#486) * Fixed problem with updating RGB images * Improved coverage --- pde/fields/collection.py | 18 ++++++++++++++---- tests/fields/test_field_collections.py | 3 ++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/pde/fields/collection.py b/pde/fields/collection.py index 6abe9017..cb2b0204 100644 --- a/pde/fields/collection.py +++ b/pde/fields/collection.py @@ -787,7 +787,9 @@ def _update_rgb_image_plot(self, reference: PlotReference) -> None: The reference to the plot that is updated """ # obtain image data - rgb_arr, _ = self._get_rgb_data(**reference.parameters) + data_args = reference.parameters.copy() + data_args.pop("kind") + rgb_arr, _ = self._get_rgb_data(**data_args) # update the axes image reference.element.set_data(rgb_arr) @@ -833,7 +835,12 @@ def _plot_rgb_image( ax.set_title(self.label) # store parameters of the plot that are necessary for updating - parameters = {"transpose": transpose, "vmin": vmin, "vmax": vmax} + parameters = { + "kind": "rgb_image", + "transpose": transpose, + "vmin": vmin, + "vmax": vmax, + } return PlotReference(ax, axes_image, parameters) def _update_plot(self, reference: List[PlotReference]) -> None: @@ -843,8 +850,11 @@ def _update_plot(self, reference: List[PlotReference]) -> None: reference (list of :class:`PlotReference`): All references of the plot to update """ - for field, ref in zip(self.fields, reference): - field._update_plot(ref) + if reference[0].parameters.get("kind", None) == "rgb_image": + self._update_rgb_image_plot(reference[0]) + else: + for field, ref in zip(self.fields, reference): + field._update_plot(ref) @plot_on_figure(update_method="_update_plot") def plot( diff --git a/tests/fields/test_field_collections.py b/tests/fields/test_field_collections.py index fe32e991..e941d259 100644 --- a/tests/fields/test_field_collections.py +++ b/tests/fields/test_field_collections.py @@ -93,7 +93,8 @@ def test_collections(rng): with pytest.raises(KeyError): fields["42"] = 0 - fields.plot(subplot_args=[{}, {"scale": 1}, {"colorbar": False}]) + refs = fields.plot(subplot_args=[{}, {"scale": 1}, {"colorbar": False}]) + fields._update_plot(refs) def test_collections_copy():