Skip to content

Commit

Permalink
chore: update pyplot typing and allowed styles
Browse files Browse the repository at this point in the history
  • Loading branch information
elmomoilanen committed Feb 24, 2024
1 parent 7659274 commit ad37c7c
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions sampdist/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Dict, Any

import numpy as np
import matplotlib.pyplot as plt # type: ignore[import]
import matplotlib.pyplot as plt
from matplotlib.axes import Axes


class Plotting:
Expand All @@ -14,21 +15,17 @@ class Plotting:
Pyplot style sheet to be used in histogram plots.
"""

allowed_styles = (
"default",
"seaborn",
"Solarize_Light2",
"ggplot",
allowed_styles = ["default"] + sorted(
style for style in plt.style.available if not style.startswith("_")
)

def __init__(self, plot_style_sheet: str = "default") -> None:
self.style_sheet = plot_style_sheet
self._check_style_validity()
if plot_style_sheet not in self.allowed_styles:
raise ValueError(
f"Style {plot_style_sheet} not in allowed styles list: {', '.join(self.allowed_styles)}"
)

def _check_style_validity(self) -> None:
if self.style_sheet not in self.allowed_styles:
allowed = ", ".join(self.allowed_styles)
raise ValueError(f"Style {self.style_sheet} not in allowed styles list: {allowed}")
self.style_sheet = plot_style_sheet

@staticmethod
def _compute_percentile_ci(data: np.ndarray, alpha: float) -> Any:
Expand Down Expand Up @@ -65,7 +62,7 @@ def _generate_font_family() -> Dict[str, Dict[str, Any]]:
}

@staticmethod
def _set_text_field(ax: plt.Axes, config: Dict[str, Any]) -> None:
def _set_text_field(ax: Axes, config: Dict[str, Any]) -> None:
ax.text(
config["x"],
config["y"],
Expand All @@ -78,7 +75,7 @@ def _set_text_field(ax: plt.Axes, config: Dict[str, Any]) -> None:
)

@staticmethod
def _set_annotation(ax: plt.Axes, config: Dict[str, Any]) -> None:
def _set_annotation(ax: Axes, config: Dict[str, Any]) -> None:
arrow_points = 1 if config["arrow_direction"] == "down" else -1

ax.annotate(
Expand Down

0 comments on commit ad37c7c

Please sign in to comment.