Skip to content

Commit

Permalink
Improve argument handling of vector plots
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zwicker committed Jan 5, 2024
1 parent 517b8f4 commit 22823a4
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions pde/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,8 +2064,8 @@ def _update_image_plot(self, reference: PlotReference) -> None:
def _plot_vector(
self,
ax,
*,
method: Literal["quiver", "streamplot"] = "quiver",
transpose: bool = False,
max_points: int = 16,
**kwargs,
) -> PlotReference:
Expand All @@ -2077,13 +2077,12 @@ def _plot_vector(
method (str):
Plot type that is used. This can be either `quiver` or
`streamplot`.
transpose (bool):
Determines whether the transpose of the data should be plotted.
max_points (int):
The maximal number of points that is used along each axis. This
argument is only used for quiver plots.
\**kwargs:
Additional keyword arguments are passed to
:meth:`~pde.field.base.DataFieldBase.get_vector_data` and
:func:`matplotlib.pyplot.quiver` or
:func:`matplotlib.pyplot.streamplot`.
Expand All @@ -2092,29 +2091,32 @@ def _plot_vector(
the plot with new data later.
"""
# store the parameters of this plot for later updating
parameters: dict[str, Any] = {
"method": method,
"transpose": transpose,
"kwargs": kwargs,
}
parameters: dict[str, Any] = {"method": method, "kwargs": kwargs}

# obtain parameter used to extract vector data
data_kws = {}
for arg in ["performance_goal", "transpose"]:
if arg in kwargs:
data_kws[arg] = kwargs.pop(arg)

if method == "quiver":
# plot vector field using a quiver plot
data = self.get_vector_data(transpose=transpose, max_points=max_points)
parameters["max_points"] = max_points # only save for quiver plot
data_kws["max_points"] = max_points
data = self.get_vector_data(**data_kws)
element = ax.quiver(
data["x"], data["y"], data["data_x"].T, data["data_y"].T, **kwargs
)

elif method == "streamplot":
# plot vector field using a streamplot
data = self.get_vector_data(transpose=transpose)
data = self.get_vector_data(**data_kws)
element = ax.streamplot(
data["x"], data["y"], data["data_x"].T, data["data_y"].T, **kwargs
)

else:
raise ValueError(f"Vector plot `{method}` is not supported.")
parameters["data_kws"] = data_kws # save data parameters

# set some default properties of the plot
ax.set_aspect("equal")
Expand All @@ -2133,19 +2135,18 @@ def _update_vector_plot(self, reference: PlotReference) -> None:
"""
# extract general parameters
method = reference.parameters.get("method", "quiver")
transpose = reference.parameters.get("transpose", False)
data_kws = reference.parameters.get("data_kws", {})

if method == "quiver":
# update the data of a quiver plot
max_points = reference.parameters.get("max_points")
data = self.get_vector_data(transpose=transpose, max_points=max_points)
data = self.get_vector_data(**data_kws)
reference.element.set_UVC(data["data_x"], data["data_y"])

elif method == "streamplot":
# update a streamplot by redrawing it completely
ax = reference.ax
kwargs = reference.parameters.get("kwargs", {})
data = self.get_vector_data(transpose=transpose)
data = self.get_vector_data(**data_kws)
# remove old streamplot
ax.cla()
# update with new streamplot
Expand Down

0 comments on commit 22823a4

Please sign in to comment.