Skip to content

Commit

Permalink
More precise typing of arguments by using Literal (#469)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zwicker authored Sep 23, 2023
1 parent f71e6c3 commit 45730dc
Show file tree
Hide file tree
Showing 22 changed files with 190 additions and 72 deletions.
7 changes: 4 additions & 3 deletions pde/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Type,
Expand Down Expand Up @@ -924,7 +925,7 @@ def random_normal(
mean: float = 0,
std: float = 1,
*,
scaling: str = "none",
scaling: Literal["none", "physical"] = "none",
label: Optional[str] = None,
dtype: Optional[DTypeLike] = None,
rng: Optional[np.random.Generator] = None,
Expand Down Expand Up @@ -2135,7 +2136,7 @@ def _update_image_plot(self, reference: PlotReference) -> None:
def _plot_vector(
self,
ax,
method: str = "quiver",
method: Literal["quiver", "streamplot"] = "quiver",
transpose: bool = False,
max_points: int = 16,
**kwargs,
Expand Down Expand Up @@ -2163,7 +2164,7 @@ def _plot_vector(
the plot with new data later.
"""
# store the parameters of this plot for later updating
parameters = {
parameters: Dict[str, Any] = {
"method": method,
"transpose": transpose,
"kwargs": kwargs,
Expand Down
16 changes: 13 additions & 3 deletions pde/fields/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,17 @@

import numbers
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)

import numpy as np
from numpy.typing import DTypeLike
Expand Down Expand Up @@ -267,7 +277,7 @@ def integral(self) -> Number:
def project(
self,
axes: Union[str, Sequence[str]],
method: str = "integral",
method: Literal["integral", "average", "mean"] = "integral",
label: Optional[str] = None,
) -> ScalarField:
"""project scalar field along given axes
Expand Down Expand Up @@ -323,7 +333,7 @@ def slice(
self,
position: Dict[str, float],
*,
method: str = "nearest",
method: Literal["nearest"] = "nearest",
label: Optional[str] = None,
) -> ScalarField:
"""slice data at a given position
Expand Down
3 changes: 2 additions & 1 deletion pde/grids/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

PI_4 = 4 * np.pi
PI_43 = 4 / 3 * np.pi
CoordsType = Literal["cartesian", "grid", "cells"]


class OperatorInfo(NamedTuple):
Expand Down Expand Up @@ -835,7 +836,7 @@ def get_image_data(self, data: np.ndarray) -> Dict[str, Any]:

@abstractmethod
def get_random_point(
self, *, boundary_distance: float = 0, coords: str = "cartesian"
self, *, boundary_distance: float = 0, coords: CoordsType = "cartesian"
) -> np.ndarray:
...

Expand Down
12 changes: 8 additions & 4 deletions pde/grids/boundaries/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Type,
Expand Down Expand Up @@ -1097,6 +1098,9 @@ def ghost_cell_setter(data_full: np.ndarray, args=None) -> None:
return ghost_cell_setter # type: ignore


ExpressionBCTargetType = Literal["value", "derivative", "mixed", "virtual_point"]


class ExpressionBC(BCBase):
"""represents a boundary whose virtual point is calculated from an expression
Expand All @@ -1118,7 +1122,7 @@ def __init__(
rank: int = 0,
value: Union[float, str, Callable] = 0,
const: Union[float, str, Callable] = 0,
target: str = "virtual_point",
target: ExpressionBCTargetType = "virtual_point",
):
r"""
Warning:
Expand Down Expand Up @@ -1485,7 +1489,7 @@ def __init__(
*,
rank: int = 0,
value: Union[float, str, Callable] = 0,
target: str = "value",
target: ExpressionBCTargetType = "value",
):
super().__init__(grid, axis, upper, rank=rank, value=value, target=target)

Expand All @@ -1511,7 +1515,7 @@ def __init__(
*,
rank: int = 0,
value: Union[float, str, Callable] = 0,
target: str = "derivative",
target: ExpressionBCTargetType = "derivative",
):
super().__init__(grid, axis, upper, rank=rank, value=value, target=target)

Expand All @@ -1538,7 +1542,7 @@ def __init__(
rank: int = 0,
value: Union[float, str, Callable] = 0,
const: Union[float, str, Callable] = 0,
target: str = "mixed",
target: ExpressionBCTargetType = "mixed",
):
super().__init__(
grid, axis, upper, rank=rank, value=value, const=const, target=target
Expand Down
4 changes: 2 additions & 2 deletions pde/grids/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ..tools.cuboid import Cuboid
from ..tools.plotting import plot_on_axes
from .base import DimensionError, GridBase, _check_shape
from .base import CoordsType, DimensionError, GridBase, _check_shape

if TYPE_CHECKING:
from .boundaries.axes import Boundaries, BoundariesData # @UnusedImport
Expand Down Expand Up @@ -254,7 +254,7 @@ def get_random_point(
self,
*,
boundary_distance: float = 0,
coords: str = "cartesian",
coords: CoordsType = "cartesian",
rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
"""return a random point within the grid
Expand Down
26 changes: 22 additions & 4 deletions pde/grids/cylindrical.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,28 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
Literal,
Optional,
Sequence,
Tuple,
Union,
)

import numpy as np

from ..tools.cache import cached_property
from .base import DimensionError, GridBase, _check_shape, discretize_interval
from .base import (
CoordsType,
DimensionError,
GridBase,
_check_shape,
discretize_interval,
)
from .cartesian import CartesianGrid

if TYPE_CHECKING:
Expand Down Expand Up @@ -213,7 +229,7 @@ def get_random_point(
*,
boundary_distance: float = 0,
avoid_center: bool = False,
coords: str = "cartesian",
coords: CoordsType = "cartesian",
rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
"""return a random point within the grid
Expand Down Expand Up @@ -461,7 +477,9 @@ def polar_coordinates_real(
else:
return dist

def get_cartesian_grid(self, mode: str = "valid") -> CartesianGrid:
def get_cartesian_grid(
self, mode: Literal["valid", "full"] = "valid"
) -> CartesianGrid:
"""return a Cartesian grid for this Cylindrical one
Args:
Expand Down
40 changes: 30 additions & 10 deletions pde/grids/operators/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
.. codeauthor:: David Zwicker <[email protected]>
"""

from typing import Callable, Tuple
from typing import Callable, Literal, Tuple

import numba as nb
import numpy as np
Expand Down Expand Up @@ -165,7 +165,9 @@ def _get_laplace_matrix(bcs: Boundaries) -> Tuple[np.ndarray, np.ndarray]:


def _make_derivative(
grid: CartesianGrid, axis: int = 0, method: str = "central"
grid: CartesianGrid,
axis: int = 0,
method: Literal["central", "forward", "backward"] = "central",
) -> OperatorType:
"""make a derivative operator along a single axis using numba compilation
Expand Down Expand Up @@ -520,7 +522,10 @@ def laplace(arr: np.ndarray, out: np.ndarray) -> None:


@CartesianGrid.register_operator("laplace", rank_in=0, rank_out=0)
def make_laplace(grid: CartesianGrid, backend: str = "auto") -> OperatorType:
def make_laplace(
grid: CartesianGrid,
backend: Literal["auto", "numba", "numba-spectral", "scipy"] = "auto",
) -> OperatorType:
"""make a Laplace operator on a Cartesian grid
Args:
Expand Down Expand Up @@ -691,7 +696,9 @@ def gradient(arr: np.ndarray, out: np.ndarray) -> None:


@CartesianGrid.register_operator("gradient", rank_in=0, rank_out=1)
def make_gradient(grid: CartesianGrid, backend: str = "auto") -> OperatorType:
def make_gradient(
grid: CartesianGrid, backend: Literal["auto", "numba", "scipy"] = "auto"
) -> OperatorType:
"""make a gradient operator on a Cartesian grid
Args:
Expand Down Expand Up @@ -1046,7 +1053,9 @@ def divergence(arr: np.ndarray, out: np.ndarray) -> None:


@CartesianGrid.register_operator("divergence", rank_in=1, rank_out=0)
def make_divergence(grid: CartesianGrid, backend: str = "auto") -> OperatorType:
def make_divergence(
grid: CartesianGrid, backend: Literal["auto", "numba", "scipy"] = "auto"
) -> OperatorType:
"""make a divergence operator on a Cartesian grid
Args:
Expand Down Expand Up @@ -1090,7 +1099,10 @@ def make_divergence(grid: CartesianGrid, backend: str = "auto") -> OperatorType:


def _vectorize_operator(
make_operator: Callable, grid: CartesianGrid, *, backend: str = "numba"
make_operator: Callable,
grid: CartesianGrid,
*,
backend: Literal["auto", "numba", "scipy"] = "numba",
) -> OperatorType:
"""apply an operator to on all dimensions of a vector
Expand Down Expand Up @@ -1120,7 +1132,9 @@ def vectorized_operator(arr: np.ndarray, out: np.ndarray) -> None:


@CartesianGrid.register_operator("vector_gradient", rank_in=1, rank_out=2)
def make_vector_gradient(grid: CartesianGrid, backend: str = "numba") -> OperatorType:
def make_vector_gradient(
grid: CartesianGrid, backend: Literal["auto", "numba", "scipy"] = "numba"
) -> OperatorType:
"""make a vector gradient operator on a Cartesian grid
Args:
Expand All @@ -1136,7 +1150,9 @@ def make_vector_gradient(grid: CartesianGrid, backend: str = "numba") -> Operato


@CartesianGrid.register_operator("vector_laplace", rank_in=1, rank_out=1)
def make_vector_laplace(grid: CartesianGrid, backend: str = "numba") -> OperatorType:
def make_vector_laplace(
grid: CartesianGrid, backend: Literal["auto", "numba", "scipy"] = "numba"
) -> OperatorType:
"""make a vector Laplacian on a Cartesian grid
Args:
Expand All @@ -1152,7 +1168,9 @@ def make_vector_laplace(grid: CartesianGrid, backend: str = "numba") -> Operator


@CartesianGrid.register_operator("tensor_divergence", rank_in=2, rank_out=1)
def make_tensor_divergence(grid: CartesianGrid, backend: str = "numba") -> OperatorType:
def make_tensor_divergence(
grid: CartesianGrid, backend: Literal["auto", "numba", "scipy"] = "numba"
) -> OperatorType:
"""make a tensor divergence operator on a Cartesian grid
Args:
Expand All @@ -1168,7 +1186,9 @@ def make_tensor_divergence(grid: CartesianGrid, backend: str = "numba") -> Opera


@CartesianGrid.register_operator("poisson_solver", rank_in=0, rank_out=0)
def make_poisson_solver(bcs: Boundaries, method: str = "auto") -> OperatorType:
def make_poisson_solver(
bcs: Boundaries, method: Literal["auto", "scipy"] = "auto"
) -> OperatorType:
"""make a operator that solves Poisson's equation
Args:
Expand Down
6 changes: 4 additions & 2 deletions pde/grids/operators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import logging
import warnings
from typing import Callable, Optional
from typing import Callable, Literal, Optional

import numpy as np

Expand Down Expand Up @@ -56,7 +56,9 @@ def laplace(arr: np.ndarray, out: Optional[np.ndarray] = None) -> np.ndarray:
return laplace


def make_general_poisson_solver(matrix, vector, method: str = "auto") -> OperatorType:
def make_general_poisson_solver(
matrix, vector, method: Literal["auto", "scipy"] = "auto"
) -> OperatorType:
"""make an operator that solves Poisson's problem
Args:
Expand Down
6 changes: 4 additions & 2 deletions pde/grids/operators/cylindrical_sym.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
.. codeauthor:: David Zwicker <[email protected]>
"""

from typing import Tuple
from typing import Literal, Tuple

import numba as nb
import numpy as np
Expand Down Expand Up @@ -430,7 +430,9 @@ def tensor_divergence(arr: np.ndarray, out: np.ndarray) -> None:

@CylindricalSymGrid.register_operator("poisson_solver", rank_in=0, rank_out=0)
@fill_in_docstring
def make_poisson_solver(bcs: Boundaries, method: str = "auto") -> OperatorType:
def make_poisson_solver(
bcs: Boundaries, method: Literal["auto", "scipy"] = "auto"
) -> OperatorType:
"""make a operator that solves Poisson's equation
{DESCR_CYLINDRICAL_GRID}
Expand Down
6 changes: 4 additions & 2 deletions pde/grids/operators/polar_sym.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
.. codeauthor:: David Zwicker <[email protected]>
"""

from typing import Tuple
from typing import Literal, Tuple

import numpy as np

Expand Down Expand Up @@ -310,7 +310,9 @@ def _get_laplace_matrix(bcs: Boundaries) -> Tuple[np.ndarray, np.ndarray]:

@PolarSymGrid.register_operator("poisson_solver", rank_in=0, rank_out=0)
@fill_in_docstring
def make_poisson_solver(bcs: Boundaries, method: str = "auto") -> OperatorType:
def make_poisson_solver(
bcs: Boundaries, method: Literal["auto", "scipy"] = "auto"
) -> OperatorType:
"""make a operator that solves Poisson's equation
{DESCR_POLAR_GRID}
Expand Down
Loading

0 comments on commit 45730dc

Please sign in to comment.