Skip to content

Commit

Permalink
Renamed intervals to interrupts (#507)
Browse files Browse the repository at this point in the history
This addresses an old inconsistency described in #459. For now the old option should still work but raise a DeprecationWarning. We will probably remove the `interval` argument in about 6 months or so.

Closes #459
  • Loading branch information
david-zwicker authored Dec 23, 2023
1 parent df40854 commit 58c2625
Show file tree
Hide file tree
Showing 21 changed files with 187 additions and 140 deletions.
2 changes: 1 addition & 1 deletion examples/make_movie_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition

storage = MemoryStorage() # create storage
tracker = storage.tracker(interval=1) # create associated tracker
tracker = storage.tracker(interrupts=1) # create associated tracker

eq = DiffusionPDE() # define the physics
eq.solve(state, t_range=2, dt=0.005, tracker=tracker)
Expand Down
2 changes: 1 addition & 1 deletion examples/pde_brusselator_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,5 @@ def pde_rhs(state_data, t):
state = eq.get_initial_state(grid)

# simulate the pde
tracker = PlotTracker(interval=1, plot_args={"kind": "rgb", "vmin": 0, "vmax": 5})
tracker = PlotTracker(interrupts=1, plot_args={"kind": "rgb", "vmin": 0, "vmax": 5})
sol = eq.solve(state, t_range=20, dt=1e-3, tracker=tracker)
2 changes: 1 addition & 1 deletion examples/pde_brusselator_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@
state = FieldCollection([u, v])

# simulate the pde
tracker = PlotTracker(interval=1, plot_args={"vmin": 0, "vmax": 5})
tracker = PlotTracker(interrupts=1, plot_args={"vmin": 0, "vmax": 5})
sol = eq.solve(state, t_range=20, dt=1e-3, tracker=tracker)
2 changes: 1 addition & 1 deletion examples/pde_sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,5 @@ def evolution_rate(self, state, t=0):
state = eq.get_state(s, i)

# simulate the pde
tracker = PlotTracker(interval=10, plot_args={"vmin": 0, "vmax": 1})
tracker = PlotTracker(interrupts=10, plot_args={"vmin": 0, "vmax": 1})
sol = eq.solve(state, t_range=50, dt=1e-2, tracker=["progress", tracker])
4 changes: 2 additions & 2 deletions examples/trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
trackers = [
"progress", # show progress bar during simulation
"steady_state", # abort when steady state is reached
storage.tracker(interval=1), # store data every simulation time unit
storage.tracker(interrupts=1), # store data every simulation time unit
pde.PlotTracker(show=True), # show images during simulation
# print some output every 5 real seconds:
pde.PrintTracker(interval=pde.RealtimeInterrupts(duration=5)),
pde.PrintTracker(interrupts=pde.RealtimeInterrupts(duration=5)),
]

eq = pde.DiffusionPDE(0.1) # define the PDE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
"outputs": [],
"source": [
"# Show the evolution while computing it\n",
"eq.solve(state, t_range=1e3, dt=0.01, tracker=pde.PlotTracker(interval=100));"
"eq.solve(state, t_range=1e3, dt=0.01, tracker=pde.PlotTracker(interrupts=100));"
]
},
{
Expand All @@ -170,7 +170,7 @@
"# reduced output\n",
"trackers = [\n",
" 'progress',\n",
" pde.PrintTracker(interval='0:01') # print output roughly every real second\n",
" pde.PrintTracker(interrupts='0:01') # print output roughly every real second\n",
"]\n",
"\n",
"eq.solve(state, t_range=1e3, dt=0.01, tracker=trackers);"
Expand All @@ -190,7 +190,7 @@
"outputs": [],
"source": [
"storage = pde.MemoryStorage()\n",
"eq.solve(state, 100, dt=0.01, tracker=storage.tracker(interval=10))\n",
"eq.solve(state, 100, dt=0.01, tracker=storage.tracker(interrupts=10))\n",
"\n",
"for field in storage:\n",
" print(f\"{field.integral:.3g}, {field.fluctuations:.3g}\")"
Expand All @@ -203,7 +203,7 @@
"outputs": [],
"source": [
"storage_write = pde.FileStorage('simulation.hdf')\n",
"eq.solve(state, 100, dt=0.01, tracker=storage_write.tracker(interval=10));"
"eq.solve(state, 100, dt=0.01, tracker=storage_write.tracker(interrupts=10));"
]
},
{
Expand Down
18 changes: 9 additions & 9 deletions pde/pdes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,17 +548,17 @@ def solve(
(supported by :class:`~pde.solvers.ScipySolver` and
:class:`~pde.solvers.ExplicitSolver`), `dt` sets the initial time step.
tracker:
Defines a tracker that processes the state of the simulation at
specified times. A tracker is either an instance of
:class:`~pde.trackers.base.TrackerBase` or a string, which identifies a
tracker. All possible identifiers can be obtained by calling
:func:`~pde.trackers.base.get_named_trackers`. Multiple trackers can be
Defines trackers that process the state of the simulation at specified
times. A tracker is either an instance of
:class:`~pde.trackers.base.TrackerBase` or a string identifying a
tracker (possible identifiers can be obtained by calling
:func:`~pde.trackers.base.get_named_trackers`). Multiple trackers can be
specified as a list. The default value `auto` checks the state for
consistency (tracker 'consistency') and displays a progress bar (tracker
'progress'). More general trackers are defined in :mod:`~pde.trackers`,
where all options are explained in detail. In particular, the interval
at which the tracker is evaluated can be chosen when creating a tracker
object explicitly.
'progress') when :mod:`tqdm` is installed. More general trackers are
defined in :mod:`~pde.trackers`, where all options are explained in
detail. In particular, the time points where the tracker analyzes data
can be chosen when creating a tracker object explicitly.
solver (:class:`~pde.solvers.base.SolverBase` or str):
Specifies the method for solving the differential equation. This can
either be an instance of :class:`~pde.solvers.base.SolverBase` or a
Expand Down
4 changes: 1 addition & 3 deletions pde/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ def registered_solvers(cls) -> List[str]: # @NoSelf
@property
def _compiled(self) -> bool:
"""bool: indicates whether functions need to be compiled"""
return (
self.backend == "numba" and not nb.config.DISABLE_JIT
) # @UndefinedVariable
return self.backend == "numba" and not nb.config.DISABLE_JIT

def _make_modify_after_step(
self, state: FieldBase
Expand Down
10 changes: 5 additions & 5 deletions pde/solvers/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ def __init__(
Sets the time range for which the simulation is run. If only a single
value `t_end` is given, the time range is assumed to be `[0, t_end]`.
tracker:
Defines a tracker that process the state of the simulation at specified
Defines trackers that process the state of the simulation at specified
times. A tracker is either an instance of
:class:`~pde.trackers.base.TrackerBase` or a string, which identifies a
tracker. All possible identifiers can be obtained by calling
:func:`~pde.trackers.base.get_named_trackers`. Multiple trackers can be
:class:`~pde.trackers.base.TrackerBase` or a string identifying a
tracker (possible identifiers can be obtained by calling
:func:`~pde.trackers.base.get_named_trackers`). Multiple trackers can be
specified as a list. The default value `auto` checks the state for
consistency (tracker 'consistency') and displays a progress bar (tracker
'progress') when :mod:`tqdm` is installed. More general trackers are
defined in :mod:`~pde.trackers`, where all options are explained in
detail. In particular, the interval at which the tracker is evaluated
detail. In particular, the time points where the tracker analyzes data
can be chosen when creating a tracker object explicitly.
"""
self.solver = solver
Expand Down
23 changes: 14 additions & 9 deletions pde/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ..tools.docstrings import fill_in_docstring
from ..tools.output import display_progress
from ..trackers.base import InfoDict, TrackerBase
from ..trackers.interrupts import InterruptsBase, IntervalData
from ..trackers.interrupts import InterruptData, InterruptsBase

if TYPE_CHECKING:
from .memory import MemoryStorage
Expand Down Expand Up @@ -281,15 +281,16 @@ def items(self) -> Iterator[Tuple[float, FieldBase]]:
@fill_in_docstring
def tracker(
self,
interval: int | float | InterruptsBase = 1,
interrupts: InterruptData = 1,
*,
transformation: Optional[Callable[[FieldBase, float], FieldBase]] = None,
interval=None,
) -> "StorageTracker":
"""create object that can be used as a tracker to fill this storage
Args:
interval:
{ARG_TRACKER_INTERVAL}
interrupts:
{ARG_TRACKER_INTERRUPT}
transformation (callable, optional):
A function that transforms the current state into a new field or field
collection, which is then stored. This allows to store derived
Expand Down Expand Up @@ -319,7 +320,10 @@ def add_to_state(state):
possible by defining appropriate :func:`add_to_state`
"""
return StorageTracker(
storage=self, interval=interval, transformation=transformation
storage=self,
interrupts=interrupts,
transformation=transformation,
interval=interval,
)

def start_writing(self, field: FieldBase, info: Optional[InfoDict] = None) -> None:
Expand Down Expand Up @@ -539,16 +543,17 @@ class StorageTracker(TrackerBase):
def __init__(
self,
storage,
interval: IntervalData = 1,
interrupts: InterruptData = 1,
*,
transformation: Optional[Callable[[FieldBase, float], FieldBase]] = None,
interval=None,
):
"""
Args:
storage (:class:`~pde.storage.base.StorageBase`):
Storage instance to which the data is written
interval:
{ARG_TRACKER_INTERVAL}
interrupts:
{ARG_TRACKER_INTERRUPT}
transformation (callable, optional):
A function that transforms the current state into a new field or field
collection, which is then stored. This allows to store derived
Expand All @@ -557,7 +562,7 @@ def __init__(
the current field, while the optional second argument is the associated
time.
"""
super().__init__(interval=interval)
super().__init__(interrupts=interrupts, interval=interval)
self.storage = storage
if transformation is not None and not callable(transformation):
raise TypeError("`transformation` must be callable")
Expand Down
13 changes: 7 additions & 6 deletions pde/tools/docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@
More information can be found in the
:ref:`boundaries documentation <documentation-boundaries>`.
""",
"ARG_TRACKER_INTERVAL": """
Determines how often the tracker interrupts the simulation. Simple
numbers are interpreted as durations measured in the simulation time
variable. Alternatively, a string using the format 'hh:mm:ss' can be
used to give durations in real time. Finally, instances of the classes
defined in :mod:`~pde.trackers.interrupts` can be given for more control.
"ARG_TRACKER_INTERRUPT": """
Determines when the tracker interrupts the simulation. A single numbers
determines an interval (measured in the simulation time unit) of regular
interruption. A string is interpreted as a duration in real time assuming the
format 'hh:mm:ss'. A list of numbers is taken as explicit simulation time points.
More fine-grained contol is possible by passing an instance of classes defined
in :mod:`~pde.trackers.interrupts`.
""",
"ARG_PLOT_QUANTITIES": """
A 2d list of quantities that are shown in a rectangular arrangement.
Expand Down
22 changes: 11 additions & 11 deletions pde/trackers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,33 @@
~trackers.RuntimeTracker
~trackers.ConsistencyTracker
~interactive.InteractivePlotTracker
Some trackers can also be referenced by name for convenience when using them in
simulations. The lit of supported names is returned by
:func:`~pde.trackers.base.get_named_trackers`.
Multiple trackers can be collected in a :class:`~base.TrackerCollection`, which provides
methods for handling them efficiently. Moreover, custom trackers can be implemented by
deriving from :class:`~.trackers.base.TrackerBase`. Note that trackers generally receive
a view into the current state, implying that they can adjust the state by modifying it
in-place. Moreover, trackers can interrupt the simulation by raising the special
exception :class:`StopIteration`.
in-place. Moreover, trackers can abort the simulation by raising the special exception
:class:`StopIteration`.
For each tracker, the time intervals at which it is called can be decided using one
of the following classes, which determine when the simulation will be interrupted:
For each tracker, the time at which the simulation is interrupted can be decided using
one of the following classes:
.. autosummary::
:nosignatures:
~interrupts.FixedInterrupts
~interrupts.ConstantInterrupts
~interrupts.LogarithmicInterrupts
~interrupts.RealtimeInterrupts
In particular, interrupts can be specified conveniently using
:func:`~interrupts.interval_to_interrupts`.
:func:`~interrupts.parse_interrupt`.
.. codeauthor:: David Zwicker <[email protected]>
"""

Expand All @@ -55,6 +55,6 @@
FixedInterrupts,
LogarithmicInterrupts,
RealtimeInterrupts,
interval_to_interrupts,
parse_interrupt,
)
from .trackers import *
18 changes: 13 additions & 5 deletions pde/trackers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
import math
import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Type, Union

Expand All @@ -16,7 +17,7 @@
from ..fields.base import FieldBase
from ..tools.docstrings import fill_in_docstring
from ..tools.misc import module_available
from .interrupts import IntervalData, interval_to_interrupts
from .interrupts import InterruptData, parse_interrupt

InfoDict = Optional[Dict[str, Any]]
TrackerDataType = Union["TrackerBase", str]
Expand All @@ -32,13 +33,20 @@ class TrackerBase(metaclass=ABCMeta):
_subclasses: Dict[str, Type[TrackerBase]] = {} # all inheriting classes

@fill_in_docstring
def __init__(self, interval: IntervalData = 1):
def __init__(self, interrupts: InterruptData = 1, *, interval=None):
"""
Args:
interval:
{ARG_TRACKER_INTERVAL}
interrupts:
{ARG_TRACKER_INTERRUPT}
"""
self.interrupt = interval_to_interrupts(interval)
if interval is not None:
# deprecated on 2023-12-23
warnings.warn(
"Argument `interval` has been renamed to `interrupts`",
DeprecationWarning,
)
interrupts = interval
self.interrupt = parse_interrupt(interrupts)
self._logger = logging.getLogger(self.__class__.__name__)

def __init_subclass__(cls, **kwargs): # @NoSelf
Expand Down
13 changes: 7 additions & 6 deletions pde/trackers/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..tools.docstrings import fill_in_docstring
from ..tools.plotting import napari_add_layers
from .base import InfoDict, TrackerBase
from .interrupts import IntervalData
from .interrupts import InterruptData


def napari_process(
Expand Down Expand Up @@ -248,23 +248,24 @@ def main():
@fill_in_docstring
def __init__(
self,
interval: IntervalData = "0:01",
interrupts: InterruptData = "0:01",
*,
close: bool = True,
show_time: bool = False,
interval=None,
):
"""
Args:
interval:
{ARG_TRACKER_INTERVAL}
interrupts:
{ARG_TRACKER_INTERRUPT}
close (bool):
Flag indicating whether the napari window is closed automatically at the
end of the simulation. If `False`, the tracker blocks when `finalize` is
called until the user closes napari manually.
show_time (bool):
Whether to indicate the time
"""
# initialize the tracker
super().__init__(interval=interval)
super().__init__(interrupts=interrupts, interval=interval)
self.close = close
self.show_time = show_time

Expand Down
8 changes: 4 additions & 4 deletions pde/trackers/interrupts.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,11 @@ def next(self, t: float) -> float:
return super().next(t)


IntervalData = Union[InterruptsBase, float, str, Sequence[float], np.ndarray]
InterruptData = Union[InterruptsBase, float, str, Sequence[float], np.ndarray]


def interval_to_interrupts(data: IntervalData) -> InterruptsBase:
"""create interrupt class from various data formats specifying time intervals
def parse_interrupt(data: InterruptData) -> InterruptsBase:
"""create interrupt class from various data formats
Args:
data (str or number or :class:`InterruptsBase`):
Expand All @@ -256,7 +256,7 @@ def interval_to_interrupts(data: IntervalData) -> InterruptsBase:
interpreted as :class:`FixedInterrupts`.
Returns:
:class:`InterruptsBase`: An instance that represents the time intervals
:class:`InterruptsBase`: An instance that represents the interrupt
"""
if isinstance(data, InterruptsBase):
return data
Expand Down
Loading

0 comments on commit 58c2625

Please sign in to comment.