Skip to content

Commit

Permalink
Improved some types
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zwicker committed Sep 23, 2023
1 parent 45730dc commit e2e4aad
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pde/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,7 +1684,7 @@ def _apply_operator(
)

def make_dot_operator(
self, backend: str = "numba", *, conjugate: bool = True
self, backend: Literal["numpy", "numba"] = "numba", *, conjugate: bool = True
) -> Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], np.ndarray]:
"""return operator calculating the dot product between two fields
Expand Down
14 changes: 12 additions & 2 deletions pde/fields/vectorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,17 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)

import numba as nb
import numpy as np
Expand Down Expand Up @@ -247,7 +257,7 @@ def outer_product(
return out

def make_outer_prod_operator(
self, backend: str = "numba"
self, backend: Literal["numpy", "numba"] = "numba"
) -> Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], np.ndarray]:
"""return operator calculating the outer product of two vector fields
Expand Down
7 changes: 4 additions & 3 deletions pde/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ..tools.math import OnlineStatistics
from ..tools.misc import classproperty
from ..tools.numba import is_jitted, jit
from ..tools.typing import BackendType


class SolverBase(metaclass=ABCMeta):
Expand All @@ -38,7 +39,7 @@ class SolverBase(metaclass=ABCMeta):
_subclasses: Dict[str, Type[SolverBase]] = {}
"""dict: dictionary of all inheriting classes"""

def __init__(self, pde: PDEBase, *, backend: str = "auto"):
def __init__(self, pde: PDEBase, *, backend: BackendType = "auto"):
"""
Args:
pde (:class:`~pde.pdes.base.PDEBase`):
Expand Down Expand Up @@ -140,7 +141,7 @@ def modify_after_step(state_data: np.ndarray) -> float:
return modify_after_step # type: ignore

def _make_pde_rhs(
self, state: FieldBase, backend: str = "auto"
self, state: FieldBase, backend: BackendType = "auto"
) -> Callable[[np.ndarray, float], np.ndarray]:
"""obtain a function for evaluating the right hand side
Expand Down Expand Up @@ -328,7 +329,7 @@ def __init__(
self,
pde: PDEBase,
*,
backend: str = "auto",
backend: BackendType = "auto",
adaptive: bool = True,
tolerance: float = 1e-4,
):
Expand Down
3 changes: 2 additions & 1 deletion pde/solvers/explicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ..pdes.base import PDEBase
from ..tools.math import OnlineStatistics
from ..tools.numba import jit
from ..tools.typing import BackendType
from .base import AdaptiveSolverBase


Expand All @@ -26,7 +27,7 @@ def __init__(
pde: PDEBase,
scheme: Literal["euler", "runge-kutta", "rk", "rk45"] = "euler",
*,
backend: str = "auto",
backend: BackendType = "auto",
adaptive: bool = False,
tolerance: float = 1e-4,
):
Expand Down
3 changes: 2 additions & 1 deletion pde/solvers/explicit_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..pdes.base import PDEBase
from ..tools import mpi
from ..tools.math import OnlineStatistics
from ..tools.typing import BackendType
from .explicit import ExplicitSolver


Expand Down Expand Up @@ -80,7 +81,7 @@ def __init__(
scheme: Literal["euler", "runge-kutta", "rk", "rk45"] = "euler",
decomposition: Union[int, List[int]] = -1,
*,
backend: str = "auto",
backend: BackendType = "auto",
adaptive: bool = False,
tolerance: float = 1e-4,
):
Expand Down
3 changes: 2 additions & 1 deletion pde/solvers/implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from ..fields.base import FieldBase
from ..pdes.base import PDEBase
from ..tools.typing import BackendType
from .base import SolverBase


Expand All @@ -28,7 +29,7 @@ def __init__(
pde: PDEBase,
maxiter: int = 100,
maxerror: float = 1e-4,
backend: str = "auto",
backend: BackendType = "auto",
):
"""
Args:
Expand Down
3 changes: 2 additions & 1 deletion pde/solvers/scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..fields.base import FieldBase
from ..pdes.base import PDEBase
from ..tools.typing import BackendType
from .base import SolverBase


Expand All @@ -22,7 +23,7 @@ class ScipySolver(SolverBase):

name = "scipy"

def __init__(self, pde: PDEBase, backend: str = "auto", **kwargs):
def __init__(self, pde: PDEBase, backend: BackendType = "auto", **kwargs):
r"""
Args:
pde (:class:`~pde.pdes.base.PDEBase`):
Expand Down
2 changes: 1 addition & 1 deletion pde/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class StorageBase(metaclass=ABCMeta):

times: Sequence[float] # :class:`~numpy.ndarray`): stored time points
data: Any # actual data for all the stored times
write_mode: str # mode determining how the storage behaves
write_mode: WriteModeType # mode determining how the storage behaves

def __init__(
self,
Expand Down
2 changes: 0 additions & 2 deletions pde/tools/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ def noise_normal():
scaling = 2 * np.pi * scale * k2s ** (exponent / 4)
scaling.flat[0] = 0

# TODO: accelerate the FFT using the pyfftw package

def noise_colored() -> np.ndarray:
"""return array of colored noise"""
# random field
Expand Down
3 changes: 2 additions & 1 deletion pde/tools/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
.. codeauthor:: David Zwicker <[email protected]>
"""

from typing import TYPE_CHECKING, Protocol, Tuple, Union
from typing import TYPE_CHECKING, Literal, Protocol, Tuple, Union

import numpy as np
from numpy.typing import ArrayLike # @UnusedImport
Expand All @@ -16,6 +16,7 @@
Number = Union[Real, complex]
NumberOrArray = Union[Number, np.ndarray]
FloatNumerical = Union[float, np.ndarray]
BackendType = Literal["auto", "numpy", "numba"]


class OperatorType(Protocol):
Expand Down

0 comments on commit e2e4aad

Please sign in to comment.