Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved some types #470

Merged
merged 1 commit into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pde/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,7 +1684,7 @@ def _apply_operator(
)

def make_dot_operator(
self, backend: str = "numba", *, conjugate: bool = True
self, backend: Literal["numpy", "numba"] = "numba", *, conjugate: bool = True
) -> Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], np.ndarray]:
"""return operator calculating the dot product between two fields

Expand Down
14 changes: 12 additions & 2 deletions pde/fields/vectorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,17 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)

import numba as nb
import numpy as np
Expand Down Expand Up @@ -247,7 +257,7 @@ def outer_product(
return out

def make_outer_prod_operator(
self, backend: str = "numba"
self, backend: Literal["numpy", "numba"] = "numba"
) -> Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], np.ndarray]:
"""return operator calculating the outer product of two vector fields

Expand Down
7 changes: 4 additions & 3 deletions pde/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ..tools.math import OnlineStatistics
from ..tools.misc import classproperty
from ..tools.numba import is_jitted, jit
from ..tools.typing import BackendType


class SolverBase(metaclass=ABCMeta):
Expand All @@ -38,7 +39,7 @@ class SolverBase(metaclass=ABCMeta):
_subclasses: Dict[str, Type[SolverBase]] = {}
"""dict: dictionary of all inheriting classes"""

def __init__(self, pde: PDEBase, *, backend: str = "auto"):
def __init__(self, pde: PDEBase, *, backend: BackendType = "auto"):
"""
Args:
pde (:class:`~pde.pdes.base.PDEBase`):
Expand Down Expand Up @@ -140,7 +141,7 @@ def modify_after_step(state_data: np.ndarray) -> float:
return modify_after_step # type: ignore

def _make_pde_rhs(
self, state: FieldBase, backend: str = "auto"
self, state: FieldBase, backend: BackendType = "auto"
) -> Callable[[np.ndarray, float], np.ndarray]:
"""obtain a function for evaluating the right hand side

Expand Down Expand Up @@ -328,7 +329,7 @@ def __init__(
self,
pde: PDEBase,
*,
backend: str = "auto",
backend: BackendType = "auto",
adaptive: bool = True,
tolerance: float = 1e-4,
):
Expand Down
3 changes: 2 additions & 1 deletion pde/solvers/explicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ..pdes.base import PDEBase
from ..tools.math import OnlineStatistics
from ..tools.numba import jit
from ..tools.typing import BackendType
from .base import AdaptiveSolverBase


Expand All @@ -26,7 +27,7 @@ def __init__(
pde: PDEBase,
scheme: Literal["euler", "runge-kutta", "rk", "rk45"] = "euler",
*,
backend: str = "auto",
backend: BackendType = "auto",
adaptive: bool = False,
tolerance: float = 1e-4,
):
Expand Down
3 changes: 2 additions & 1 deletion pde/solvers/explicit_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..pdes.base import PDEBase
from ..tools import mpi
from ..tools.math import OnlineStatistics
from ..tools.typing import BackendType
from .explicit import ExplicitSolver


Expand Down Expand Up @@ -80,7 +81,7 @@ def __init__(
scheme: Literal["euler", "runge-kutta", "rk", "rk45"] = "euler",
decomposition: Union[int, List[int]] = -1,
*,
backend: str = "auto",
backend: BackendType = "auto",
adaptive: bool = False,
tolerance: float = 1e-4,
):
Expand Down
3 changes: 2 additions & 1 deletion pde/solvers/implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from ..fields.base import FieldBase
from ..pdes.base import PDEBase
from ..tools.typing import BackendType
from .base import SolverBase


Expand All @@ -28,7 +29,7 @@ def __init__(
pde: PDEBase,
maxiter: int = 100,
maxerror: float = 1e-4,
backend: str = "auto",
backend: BackendType = "auto",
):
"""
Args:
Expand Down
3 changes: 2 additions & 1 deletion pde/solvers/scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..fields.base import FieldBase
from ..pdes.base import PDEBase
from ..tools.typing import BackendType
from .base import SolverBase


Expand All @@ -22,7 +23,7 @@ class ScipySolver(SolverBase):

name = "scipy"

def __init__(self, pde: PDEBase, backend: str = "auto", **kwargs):
def __init__(self, pde: PDEBase, backend: BackendType = "auto", **kwargs):
r"""
Args:
pde (:class:`~pde.pdes.base.PDEBase`):
Expand Down
2 changes: 1 addition & 1 deletion pde/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class StorageBase(metaclass=ABCMeta):

times: Sequence[float] # :class:`~numpy.ndarray`): stored time points
data: Any # actual data for all the stored times
write_mode: str # mode determining how the storage behaves
write_mode: WriteModeType # mode determining how the storage behaves

def __init__(
self,
Expand Down
2 changes: 0 additions & 2 deletions pde/tools/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ def noise_normal():
scaling = 2 * np.pi * scale * k2s ** (exponent / 4)
scaling.flat[0] = 0

# TODO: accelerate the FFT using the pyfftw package

def noise_colored() -> np.ndarray:
"""return array of colored noise"""
# random field
Expand Down
3 changes: 2 additions & 1 deletion pde/tools/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
.. codeauthor:: David Zwicker <[email protected]>
"""

from typing import TYPE_CHECKING, Protocol, Tuple, Union
from typing import TYPE_CHECKING, Literal, Protocol, Tuple, Union

import numpy as np
from numpy.typing import ArrayLike # @UnusedImport
Expand All @@ -16,6 +16,7 @@
Number = Union[Real, complex]
NumberOrArray = Union[Number, np.ndarray]
FloatNumerical = Union[float, np.ndarray]
BackendType = Literal["auto", "numpy", "numba"]


class OperatorType(Protocol):
Expand Down