diff --git a/pde/fields/base.py b/pde/fields/base.py index 47dd0aff..2a8c4472 100644 --- a/pde/fields/base.py +++ b/pde/fields/base.py @@ -1684,7 +1684,7 @@ def _apply_operator( ) def make_dot_operator( - self, backend: str = "numba", *, conjugate: bool = True + self, backend: Literal["numpy", "numba"] = "numba", *, conjugate: bool = True ) -> Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], np.ndarray]: """return operator calculating the dot product between two fields diff --git a/pde/fields/vectorial.py b/pde/fields/vectorial.py index b61b6af7..4f4f1982 100644 --- a/pde/fields/vectorial.py +++ b/pde/fields/vectorial.py @@ -6,7 +6,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Union, +) import numba as nb import numpy as np @@ -247,7 +257,7 @@ def outer_product( return out def make_outer_prod_operator( - self, backend: str = "numba" + self, backend: Literal["numpy", "numba"] = "numba" ) -> Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], np.ndarray]: """return operator calculating the outer product of two vector fields diff --git a/pde/solvers/base.py b/pde/solvers/base.py index beedffb0..75f38c7e 100644 --- a/pde/solvers/base.py +++ b/pde/solvers/base.py @@ -24,6 +24,7 @@ from ..tools.math import OnlineStatistics from ..tools.misc import classproperty from ..tools.numba import is_jitted, jit +from ..tools.typing import BackendType class SolverBase(metaclass=ABCMeta): @@ -38,7 +39,7 @@ class SolverBase(metaclass=ABCMeta): _subclasses: Dict[str, Type[SolverBase]] = {} """dict: dictionary of all inheriting classes""" - def __init__(self, pde: PDEBase, *, backend: str = "auto"): + def __init__(self, pde: PDEBase, *, backend: BackendType = "auto"): """ Args: pde (:class:`~pde.pdes.base.PDEBase`): @@ -140,7 +141,7 @@ def modify_after_step(state_data: np.ndarray) -> float: return modify_after_step # type: ignore def _make_pde_rhs( - self, state: FieldBase, backend: str = "auto" + self, state: FieldBase, backend: BackendType = "auto" ) -> Callable[[np.ndarray, float], np.ndarray]: """obtain a function for evaluating the right hand side @@ -328,7 +329,7 @@ def __init__( self, pde: PDEBase, *, - backend: str = "auto", + backend: BackendType = "auto", adaptive: bool = True, tolerance: float = 1e-4, ): diff --git a/pde/solvers/explicit.py b/pde/solvers/explicit.py index fd4b6ca8..7d963de0 100644 --- a/pde/solvers/explicit.py +++ b/pde/solvers/explicit.py @@ -13,6 +13,7 @@ from ..pdes.base import PDEBase from ..tools.math import OnlineStatistics from ..tools.numba import jit +from ..tools.typing import BackendType from .base import AdaptiveSolverBase @@ -26,7 +27,7 @@ def __init__( pde: PDEBase, scheme: Literal["euler", "runge-kutta", "rk", "rk45"] = "euler", *, - backend: str = "auto", + backend: BackendType = "auto", adaptive: bool = False, tolerance: float = 1e-4, ): diff --git a/pde/solvers/explicit_mpi.py b/pde/solvers/explicit_mpi.py index da6eba75..25e56865 100644 --- a/pde/solvers/explicit_mpi.py +++ b/pde/solvers/explicit_mpi.py @@ -14,6 +14,7 @@ from ..pdes.base import PDEBase from ..tools import mpi from ..tools.math import OnlineStatistics +from ..tools.typing import BackendType from .explicit import ExplicitSolver @@ -80,7 +81,7 @@ def __init__( scheme: Literal["euler", "runge-kutta", "rk", "rk45"] = "euler", decomposition: Union[int, List[int]] = -1, *, - backend: str = "auto", + backend: BackendType = "auto", adaptive: bool = False, tolerance: float = 1e-4, ): diff --git a/pde/solvers/implicit.py b/pde/solvers/implicit.py index d58f4558..37473c8f 100644 --- a/pde/solvers/implicit.py +++ b/pde/solvers/implicit.py @@ -11,6 +11,7 @@ from ..fields.base import FieldBase from ..pdes.base import PDEBase +from ..tools.typing import BackendType from .base import SolverBase @@ -28,7 +29,7 @@ def __init__( pde: PDEBase, maxiter: int = 100, maxerror: float = 1e-4, - backend: str = "auto", + backend: BackendType = "auto", ): """ Args: diff --git a/pde/solvers/scipy.py b/pde/solvers/scipy.py index 8d56b639..5fc20c4a 100644 --- a/pde/solvers/scipy.py +++ b/pde/solvers/scipy.py @@ -10,6 +10,7 @@ from ..fields.base import FieldBase from ..pdes.base import PDEBase +from ..tools.typing import BackendType from .base import SolverBase @@ -22,7 +23,7 @@ class ScipySolver(SolverBase): name = "scipy" - def __init__(self, pde: PDEBase, backend: str = "auto", **kwargs): + def __init__(self, pde: PDEBase, backend: BackendType = "auto", **kwargs): r""" Args: pde (:class:`~pde.pdes.base.PDEBase`): diff --git a/pde/storage/base.py b/pde/storage/base.py index 5a14dde8..6483fa50 100644 --- a/pde/storage/base.py +++ b/pde/storage/base.py @@ -54,7 +54,7 @@ class StorageBase(metaclass=ABCMeta): times: Sequence[float] # :class:`~numpy.ndarray`): stored time points data: Any # actual data for all the stored times - write_mode: str # mode determining how the storage behaves + write_mode: WriteModeType # mode determining how the storage behaves def __init__( self, diff --git a/pde/tools/spectral.py b/pde/tools/spectral.py index edf89be8..6d1c7dd5 100644 --- a/pde/tools/spectral.py +++ b/pde/tools/spectral.py @@ -89,8 +89,6 @@ def noise_normal(): scaling = 2 * np.pi * scale * k2s ** (exponent / 4) scaling.flat[0] = 0 - # TODO: accelerate the FFT using the pyfftw package - def noise_colored() -> np.ndarray: """return array of colored noise""" # random field diff --git a/pde/tools/typing.py b/pde/tools/typing.py index 9b45b48d..1ebe044d 100644 --- a/pde/tools/typing.py +++ b/pde/tools/typing.py @@ -4,7 +4,7 @@ .. codeauthor:: David Zwicker """ -from typing import TYPE_CHECKING, Protocol, Tuple, Union +from typing import TYPE_CHECKING, Literal, Protocol, Tuple, Union import numpy as np from numpy.typing import ArrayLike # @UnusedImport @@ -16,6 +16,7 @@ Number = Union[Real, complex] NumberOrArray = Union[Number, np.ndarray] FloatNumerical = Union[float, np.ndarray] +BackendType = Literal["auto", "numpy", "numba"] class OperatorType(Protocol):