-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
More precise typing of arguments by using
Literal
- Loading branch information
1 parent
f71e6c3
commit 09ee3e9
Showing
22 changed files
with
190 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
.. codeauthor:: David Zwicker <[email protected]> | ||
""" | ||
|
||
from typing import Callable, Tuple | ||
from typing import Callable, Literal, Tuple | ||
|
||
import numba as nb | ||
import numpy as np | ||
|
@@ -165,7 +165,9 @@ def _get_laplace_matrix(bcs: Boundaries) -> Tuple[np.ndarray, np.ndarray]: | |
|
||
|
||
def _make_derivative( | ||
grid: CartesianGrid, axis: int = 0, method: str = "central" | ||
grid: CartesianGrid, | ||
axis: int = 0, | ||
method: Literal["central", "forward", "backward"] = "central", | ||
) -> OperatorType: | ||
"""make a derivative operator along a single axis using numba compilation | ||
|
@@ -520,7 +522,10 @@ def laplace(arr: np.ndarray, out: np.ndarray) -> None: | |
|
||
|
||
@CartesianGrid.register_operator("laplace", rank_in=0, rank_out=0) | ||
def make_laplace(grid: CartesianGrid, backend: str = "auto") -> OperatorType: | ||
def make_laplace( | ||
grid: CartesianGrid, | ||
backend: Literal["auto", "numba", "numba-spectral", "scipy"] = "auto", | ||
) -> OperatorType: | ||
"""make a Laplace operator on a Cartesian grid | ||
Args: | ||
|
@@ -691,7 +696,9 @@ def gradient(arr: np.ndarray, out: np.ndarray) -> None: | |
|
||
|
||
@CartesianGrid.register_operator("gradient", rank_in=0, rank_out=1) | ||
def make_gradient(grid: CartesianGrid, backend: str = "auto") -> OperatorType: | ||
def make_gradient( | ||
grid: CartesianGrid, backend: Literal["auto", "numba", "scipy"] = "auto" | ||
) -> OperatorType: | ||
"""make a gradient operator on a Cartesian grid | ||
Args: | ||
|
@@ -1046,7 +1053,9 @@ def divergence(arr: np.ndarray, out: np.ndarray) -> None: | |
|
||
|
||
@CartesianGrid.register_operator("divergence", rank_in=1, rank_out=0) | ||
def make_divergence(grid: CartesianGrid, backend: str = "auto") -> OperatorType: | ||
def make_divergence( | ||
grid: CartesianGrid, backend: Literal["auto", "numba", "scipy"] = "auto" | ||
) -> OperatorType: | ||
"""make a divergence operator on a Cartesian grid | ||
Args: | ||
|
@@ -1090,7 +1099,10 @@ def make_divergence(grid: CartesianGrid, backend: str = "auto") -> OperatorType: | |
|
||
|
||
def _vectorize_operator( | ||
make_operator: Callable, grid: CartesianGrid, *, backend: str = "numba" | ||
make_operator: Callable, | ||
grid: CartesianGrid, | ||
*, | ||
backend: Literal["auto", "numba", "scipy"] = "numba", | ||
) -> OperatorType: | ||
"""apply an operator to on all dimensions of a vector | ||
|
@@ -1120,7 +1132,9 @@ def vectorized_operator(arr: np.ndarray, out: np.ndarray) -> None: | |
|
||
|
||
@CartesianGrid.register_operator("vector_gradient", rank_in=1, rank_out=2) | ||
def make_vector_gradient(grid: CartesianGrid, backend: str = "numba") -> OperatorType: | ||
def make_vector_gradient( | ||
grid: CartesianGrid, backend: Literal["auto", "numba", "scipy"] = "numba" | ||
) -> OperatorType: | ||
"""make a vector gradient operator on a Cartesian grid | ||
Args: | ||
|
@@ -1136,7 +1150,9 @@ def make_vector_gradient(grid: CartesianGrid, backend: str = "numba") -> Operato | |
|
||
|
||
@CartesianGrid.register_operator("vector_laplace", rank_in=1, rank_out=1) | ||
def make_vector_laplace(grid: CartesianGrid, backend: str = "numba") -> OperatorType: | ||
def make_vector_laplace( | ||
grid: CartesianGrid, backend: Literal["auto", "numba", "scipy"] = "numba" | ||
) -> OperatorType: | ||
"""make a vector Laplacian on a Cartesian grid | ||
Args: | ||
|
@@ -1152,7 +1168,9 @@ def make_vector_laplace(grid: CartesianGrid, backend: str = "numba") -> Operator | |
|
||
|
||
@CartesianGrid.register_operator("tensor_divergence", rank_in=2, rank_out=1) | ||
def make_tensor_divergence(grid: CartesianGrid, backend: str = "numba") -> OperatorType: | ||
def make_tensor_divergence( | ||
grid: CartesianGrid, backend: Literal["auto", "numba", "scipy"] = "numba" | ||
) -> OperatorType: | ||
"""make a tensor divergence operator on a Cartesian grid | ||
Args: | ||
|
@@ -1168,7 +1186,9 @@ def make_tensor_divergence(grid: CartesianGrid, backend: str = "numba") -> Opera | |
|
||
|
||
@CartesianGrid.register_operator("poisson_solver", rank_in=0, rank_out=0) | ||
def make_poisson_solver(bcs: Boundaries, method: str = "auto") -> OperatorType: | ||
def make_poisson_solver( | ||
bcs: Boundaries, method: Literal["auto", "scipy"] = "auto" | ||
) -> OperatorType: | ||
"""make a operator that solves Poisson's equation | ||
Args: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
.. codeauthor:: David Zwicker <[email protected]> | ||
""" | ||
|
||
from typing import Tuple | ||
from typing import Literal, Tuple | ||
|
||
import numba as nb | ||
import numpy as np | ||
|
@@ -430,7 +430,9 @@ def tensor_divergence(arr: np.ndarray, out: np.ndarray) -> None: | |
|
||
@CylindricalSymGrid.register_operator("poisson_solver", rank_in=0, rank_out=0) | ||
@fill_in_docstring | ||
def make_poisson_solver(bcs: Boundaries, method: str = "auto") -> OperatorType: | ||
def make_poisson_solver( | ||
bcs: Boundaries, method: Literal["auto", "scipy"] = "auto" | ||
) -> OperatorType: | ||
"""make a operator that solves Poisson's equation | ||
{DESCR_CYLINDRICAL_GRID} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,7 @@ | |
.. codeauthor:: David Zwicker <[email protected]> | ||
""" | ||
|
||
from typing import Tuple | ||
from typing import Literal, Tuple | ||
|
||
import numpy as np | ||
|
||
|
@@ -310,7 +310,9 @@ def _get_laplace_matrix(bcs: Boundaries) -> Tuple[np.ndarray, np.ndarray]: | |
|
||
@PolarSymGrid.register_operator("poisson_solver", rank_in=0, rank_out=0) | ||
@fill_in_docstring | ||
def make_poisson_solver(bcs: Boundaries, method: str = "auto") -> OperatorType: | ||
def make_poisson_solver( | ||
bcs: Boundaries, method: Literal["auto", "scipy"] = "auto" | ||
) -> OperatorType: | ||
"""make a operator that solves Poisson's equation | ||
{DESCR_POLAR_GRID} | ||
|
Oops, something went wrong.