Skip to content

Commit

Permalink
Improved some types and docstrings (#466)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zwicker authored Sep 21, 2023
1 parent 09feb4b commit 2b83ea6
Show file tree
Hide file tree
Showing 11 changed files with 269 additions and 158 deletions.
97 changes: 63 additions & 34 deletions pde/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ def from_state(
The attributes that describe the current instance
data (:class:`~numpy.ndarray`, optional):
Data values at the support points of the grid defining the field
Returns:
:class:`FieldBase`: The field created from the state
"""
# base class was chosen => select correct class from attributes
class_name = attributes.pop("class")
Expand Down Expand Up @@ -237,7 +240,8 @@ def from_file(cls, filename: str) -> FieldBase:
field_copy = pde.FieldBase.from_file("test.hdf5")
Args:
filename (str): Path to the file being read
filename (str):
Path to the file being read
Returns:
:class:`FieldBase`: The field with the appropriate sub-class
Expand Down Expand Up @@ -280,10 +284,10 @@ def _from_hdf_dataset(cls, dataset) -> FieldBase:

@property
def grid(self) -> GridBase:
"""GridBase: The grid on which the field is defined"""
""":class:`~pde.grids.base,GridBase`: The grid on which the field is defined"""
return self._grid

def to_file(self, filename: str, **kwargs):
def to_file(self, filename: str, **kwargs) -> None:
r"""store field in a file
The extension of the filename determines what format is being used. If it ends
Expand Down Expand Up @@ -324,7 +328,7 @@ def to_file(self, filename: str, **kwargs):
else:
raise ValueError(f"Do not know how to save data to `*{extension}`")

def _write_hdf_dataset(self, hdf_path, key: str = "data"):
def _write_hdf_dataset(self, hdf_path, key: str = "data") -> None:
"""write data to a given hdf5 path `hdf_path`"""
# write the data
dataset = hdf_path.create_dataset(key, data=self.data)
Expand All @@ -347,7 +351,9 @@ def copy(
) -> TField:
...

def assert_field_compatible(self, other: FieldBase, accept_scalar: bool = False):
def assert_field_compatible(
self, other: FieldBase, accept_scalar: bool = False
) -> None:
"""checks whether `other` is compatible with the current field
Args:
Expand Down Expand Up @@ -437,7 +443,7 @@ def _unary_operation(self: TField, op: Callable) -> TField:
A function calculating the result
Returns:
FieldBase: An field that contains the result of the operation.
:class:`FieldBase`: An field that contains the result of the operation.
"""
return self.__class__(grid=self.grid, data=op(self.data), label=self.label)

Expand All @@ -452,11 +458,18 @@ def imag(self: TField) -> TField:
return self._unary_operation(np.imag)

def conjugate(self: TField) -> TField:
"""returns complex conjugate of the field"""
"""returns complex conjugate of the field
Returns:
:class:`FieldBase`: the complex conjugated field
"""
return self._unary_operation(np.conjugate)

def __neg__(self):
"""return the negative of the current field"""
"""return the negative of the current field
:class:`FieldBase`: The negative of the current field
"""
return self._unary_operation(np.negative)

def _binary_operation(
Expand All @@ -473,7 +486,7 @@ def _binary_operation(
Flag determining whether the second operator must be a scalar
Returns:
FieldBase: An field that contains the result of the operation. If
:class:`FieldBase`: An field that contains the result of the operation. If
`scalar_second == True`, the type of FieldBase is the same as `self`
"""
# determine the dtype of the output
Expand Down Expand Up @@ -526,7 +539,7 @@ def _binary_operation_inplace(
Flag determining whether the second operator must be a scalar.
Returns:
FieldBase: The field `self` with updated data
:class:`FieldBase`: The field `self` with updated data
"""
if isinstance(other, FieldBase):
# right operator is a field
Expand Down Expand Up @@ -638,7 +651,7 @@ def apply(
Only used when `func` is a string.
Returns:
Field with new data. This is identical to `out` if it was given.
:class:`FieldBase`: Field with new data. Identical to `out` if given.
"""
if isinstance(func, str):
# function is given as an expression that will be evaluated
Expand Down Expand Up @@ -694,7 +707,9 @@ def plot(self, *args, **kwargs):
def _get_napari_data(self, **kwargs) -> Dict[str, Dict[str, Any]]:
...

def plot_interactive(self, viewer_args: Optional[Dict[str, Any]] = None, **kwargs):
def plot_interactive(
self, viewer_args: Optional[Dict[str, Any]] = None, **kwargs
) -> None:
"""create an interactive plot of the field using :mod:`napari`
For a detailed description of the launched program, see the
Expand Down Expand Up @@ -878,7 +893,7 @@ def random_uniform(
vmax (float):
Largest random value
label (str, optional):
Name of the field
Name of the returned field
dtype (numpy dtype):
The data type of the field. If omitted, it defaults to `double` if both
`vmin` and `vmax` are real, otherwise it is `complex`.
Expand Down Expand Up @@ -937,7 +952,7 @@ def random_normal(
scaled by the inverse volume of the grid cell; this is for instance
useful for concentration fields, which vary less in larger volumes).
label (str, optional):
Name of the field
Name of the returned field
dtype (numpy dtype):
The data type of the field. If omitted, it defaults to `double` if both
`mean` and `std` are real, otherwise it is `complex`.
Expand Down Expand Up @@ -1016,7 +1031,7 @@ def random_harmonic(
resulting in products and sums of the values along axes,
respectively.
label (str, optional):
Name of the field
Name of the returned field
dtype (numpy dtype):
The data type of the field. If omitted, it defaults to `double`.
rng (:class:`~numpy.random.Generator`):
Expand Down Expand Up @@ -1077,7 +1092,7 @@ def random_colored(
scale (float):
Scaling factor :math:`\Gamma` determining noise strength
label (str, optional):
Name of the field
Name of the returned field
dtype (numpy dtype):
The data type of the field. If omitted, it defaults to `double`.
rng (:class:`~numpy.random.Generator`):
Expand Down Expand Up @@ -1105,6 +1120,9 @@ def get_class_by_rank(cls, rank: int) -> Type[DataFieldBase]:
Args:
rank (int): The rank of the tensor field
Returns:
The DataField class that corresponds to the rank
"""
for field_cls in cls._subclasses.values():
if (
Expand All @@ -1128,6 +1146,9 @@ def from_state(
The attributes that describe the current instance
data (:class:`~numpy.ndarray`, optional):
Data values at the support points of the grid defining the field
Returns:
:class:`DataFieldBase`: The instance created from the stored state
"""
if "class" in attributes:
class_name = attributes.pop("class")
Expand All @@ -1142,14 +1163,17 @@ def copy(
label: Optional[str] = None,
dtype: Optional[DTypeLike] = None,
) -> TDataField:
"""return a copy of the data, but not of the grid
"""return a new field with the data (but not the grid) copied
Args:
label (str, optional):
Name of the returned field
dtype (numpy dtype):
The data type of the field. If omitted, it will be determined from
`data` automatically or the dtype of the current field is used.
Returns:
:class:`DataFieldBase`: A copy of the current field
"""
if label is None:
label = self.label
Expand Down Expand Up @@ -1188,7 +1212,7 @@ def unserialize_attributes(cls, attributes: Dict[str, str]) -> Dict[str, Any]:
results[key] = json.loads(value)
return results

def _write_to_image(self, filename: str, **kwargs):
def _write_to_image(self, filename: str, **kwargs) -> None:
r"""write data to image
Args:
Expand Down Expand Up @@ -1519,7 +1543,7 @@ def to_scalar(

@property
def average(self) -> NumberOrArray:
"""determine the average of data
"""float or :class:`~numpy.ndarray`: the average of data
This is calculated by integrating each component of the field over space
and dividing by the grid volume
Expand All @@ -1528,7 +1552,7 @@ def average(self) -> NumberOrArray:

@property
def fluctuations(self) -> NumberOrArray:
""":class:`~numpy.ndarray`: fluctuations over the entire space.
"""float or :class:`~numpy.ndarray`: quantification of the average fluctuations
The fluctuations are defined as the standard deviation of the data scaled by the
cell volume. This definition makes the fluctuations independent of the
Expand All @@ -1547,7 +1571,7 @@ def fluctuations(self) -> NumberOrArray:

@property
def magnitude(self) -> float:
"""float: determine the magnitude of the field.
"""float: determine the (scalar) magnitude of the field
This is calculated by getting a scalar field using the default arguments of the
:func:`to_scalar` method, averaging the result over the whole grid, and taking
Expand Down Expand Up @@ -1584,12 +1608,14 @@ def apply_operator(
Optional field to which the result is written.
label (str, optional):
Name of the returned field
args (dict):
Additional arguments for the boundary conditions
**kwargs:
Additional arguments affecting how the operator behaves.
Returns:
Field data after applying the operator. This field is identical to `out` if
this argument was specified.
:class:`DataFieldBase`: Field data after applying the operator. This field
is identical to `out` if this argument was specified.
"""
# get information about the operator
operator_info = self.grid._get_operator_info(operator)
Expand Down Expand Up @@ -1637,6 +1663,8 @@ def _apply_operator(
Optional field to which the result is written.
label (str, optional):
Name of the returned field
args (dict):
Additional arguments for the boundary conditions
**kwargs:
Additional arguments affecting how the operator behaves.
Expand Down Expand Up @@ -1829,17 +1857,18 @@ def smooth(
) -> TDataField:
"""applies Gaussian smoothing with the given standard deviation
This function respects periodic boundary conditions of the underlying
grid, using reflection when no periodicity is specified.
sigma (float):
Gives the standard deviation of the smoothing in real length units
(default: 1)
out (FieldBase, optional):
Optional field into which the smoothed data is stored. Setting this
to the input field enables in-place smoothing.
label (str, optional):
Name of the returned field
This function respects periodic boundary conditions of the underlying grid,
using reflection when no periodicity is specified.
Args:
sigma (float):
Gives the standard deviation of the smoothing in real length units
(default: 1)
out (FieldBase, optional):
Optional field into which the smoothed data is stored. Setting this
to the input field enables in-place smoothing.
label (str, optional):
Name of the returned field
Returns:
Field with smoothed data. This is stored at `out` if given.
Expand Down
11 changes: 6 additions & 5 deletions pde/fields/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)

import numpy as np
from numpy.typing import DTypeLike

from ..grids.base import GridBase
from ..tools.docstrings import fill_in_docstring
Expand All @@ -49,7 +50,7 @@ def __init__(
copy_fields: bool = False,
label: Optional[str] = None,
labels: Union[List[Optional[str]], _FieldLabels, None] = None,
dtype=None,
dtype: DTypeLike = None,
):
"""
Args:
Expand Down Expand Up @@ -263,7 +264,7 @@ def from_dict(
*,
copy_fields: bool = False,
label: Optional[str] = None,
dtype=None,
dtype: DTypeLike = None,
) -> FieldCollection:
"""create a field collection from a dictionary of fields
Expand Down Expand Up @@ -326,7 +327,7 @@ def from_data(
with_ghost_cells: bool = True,
label: Optional[str] = None,
labels: Union[List[Optional[str]], _FieldLabels, None] = None,
dtype=None,
dtype: DTypeLike = None,
):
"""create a field collection from classes and data
Expand Down Expand Up @@ -423,7 +424,7 @@ def from_scalar_expressions(
consts: Optional[Dict[str, NumberOrArray]] = None,
label: Optional[str] = None,
labels: Optional[Sequence[str]] = None,
dtype=None,
dtype: DTypeLike = None,
) -> FieldCollection:
"""create a field collection on a grid from given expressions
Expand Down Expand Up @@ -570,7 +571,7 @@ def copy(
self: FieldCollection,
*,
label: Optional[str] = None,
dtype=None,
dtype: DTypeLike = None,
) -> FieldCollection:
"""return a copy of the data, but not of the grid
Expand Down
2 changes: 1 addition & 1 deletion pde/fields/vectorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def dot(
other (VectorField or Tensor2Field):
the second field
out (ScalarField or VectorField, optional):
Optional field to which the result is written.
Optional field to which the result is written.
conjugate (bool):
Whether to use the complex conjugate for the second operand
label (str, optional):
Expand Down
Loading

0 comments on commit 2b83ea6

Please sign in to comment.