diff --git a/pde/fields/base.py b/pde/fields/base.py index 076411e7..08b7d7f8 100644 --- a/pde/fields/base.py +++ b/pde/fields/base.py @@ -2064,8 +2064,8 @@ def _update_image_plot(self, reference: PlotReference) -> None: def _plot_vector( self, ax, + *, method: Literal["quiver", "streamplot"] = "quiver", - transpose: bool = False, max_points: int = 16, **kwargs, ) -> PlotReference: @@ -2077,13 +2077,12 @@ def _plot_vector( method (str): Plot type that is used. This can be either `quiver` or `streamplot`. - transpose (bool): - Determines whether the transpose of the data should be plotted. max_points (int): The maximal number of points that is used along each axis. This argument is only used for quiver plots. \**kwargs: Additional keyword arguments are passed to + :meth:`~pde.field.base.DataFieldBase.get_vector_data` and :func:`matplotlib.pyplot.quiver` or :func:`matplotlib.pyplot.streamplot`. @@ -2092,29 +2091,32 @@ def _plot_vector( the plot with new data later. """ # store the parameters of this plot for later updating - parameters: dict[str, Any] = { - "method": method, - "transpose": transpose, - "kwargs": kwargs, - } + parameters: dict[str, Any] = {"method": method, "kwargs": kwargs} + + # obtain parameter used to extract vector data + data_kws = {} + for arg in ["performance_goal", "transpose"]: + if arg in kwargs: + data_kws[arg] = kwargs.pop(arg) if method == "quiver": # plot vector field using a quiver plot - data = self.get_vector_data(transpose=transpose, max_points=max_points) - parameters["max_points"] = max_points # only save for quiver plot + data_kws["max_points"] = max_points + data = self.get_vector_data(**data_kws) element = ax.quiver( data["x"], data["y"], data["data_x"].T, data["data_y"].T, **kwargs ) elif method == "streamplot": # plot vector field using a streamplot - data = self.get_vector_data(transpose=transpose) + data = self.get_vector_data(**data_kws) element = ax.streamplot( data["x"], data["y"], data["data_x"].T, data["data_y"].T, **kwargs ) else: raise ValueError(f"Vector plot `{method}` is not supported.") + parameters["data_kws"] = data_kws # save data parameters # set some default properties of the plot ax.set_aspect("equal") @@ -2133,19 +2135,18 @@ def _update_vector_plot(self, reference: PlotReference) -> None: """ # extract general parameters method = reference.parameters.get("method", "quiver") - transpose = reference.parameters.get("transpose", False) + data_kws = reference.parameters.get("data_kws", {}) if method == "quiver": # update the data of a quiver plot - max_points = reference.parameters.get("max_points") - data = self.get_vector_data(transpose=transpose, max_points=max_points) + data = self.get_vector_data(**data_kws) reference.element.set_UVC(data["data_x"], data["data_y"]) elif method == "streamplot": # update a streamplot by redrawing it completely ax = reference.ax kwargs = reference.parameters.get("kwargs", {}) - data = self.get_vector_data(transpose=transpose) + data = self.get_vector_data(**data_kws) # remove old streamplot ax.cla() # update with new streamplot