diff --git a/examples/make_movie_storage.py b/examples/make_movie_storage.py index a612e6b4..f0fbd268 100644 --- a/examples/make_movie_storage.py +++ b/examples/make_movie_storage.py @@ -12,7 +12,7 @@ state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition storage = MemoryStorage() # create storage -tracker = storage.tracker(interval=1) # create associated tracker +tracker = storage.tracker(interrupts=1) # create associated tracker eq = DiffusionPDE() # define the physics eq.solve(state, t_range=2, dt=0.005, tracker=tracker) diff --git a/examples/pde_brusselator_class.py b/examples/pde_brusselator_class.py index 1dd96b47..8df5d4d4 100644 --- a/examples/pde_brusselator_class.py +++ b/examples/pde_brusselator_class.py @@ -74,5 +74,5 @@ def pde_rhs(state_data, t): state = eq.get_initial_state(grid) # simulate the pde -tracker = PlotTracker(interval=1, plot_args={"kind": "rgb", "vmin": 0, "vmax": 5}) +tracker = PlotTracker(interrupts=1, plot_args={"kind": "rgb", "vmin": 0, "vmax": 5}) sol = eq.solve(state, t_range=20, dt=1e-3, tracker=tracker) diff --git a/examples/pde_brusselator_expression.py b/examples/pde_brusselator_expression.py index d7b2a2c4..e3a3daed 100644 --- a/examples/pde_brusselator_expression.py +++ b/examples/pde_brusselator_expression.py @@ -38,5 +38,5 @@ state = FieldCollection([u, v]) # simulate the pde -tracker = PlotTracker(interval=1, plot_args={"vmin": 0, "vmax": 5}) +tracker = PlotTracker(interrupts=1, plot_args={"vmin": 0, "vmax": 5}) sol = eq.solve(state, t_range=20, dt=1e-3, tracker=tracker) diff --git a/examples/pde_sir.py b/examples/pde_sir.py index 61b7d4a6..a851c7cf 100644 --- a/examples/pde_sir.py +++ b/examples/pde_sir.py @@ -64,5 +64,5 @@ def evolution_rate(self, state, t=0): state = eq.get_state(s, i) # simulate the pde -tracker = PlotTracker(interval=10, plot_args={"vmin": 0, "vmax": 1}) +tracker = PlotTracker(interrupts=10, plot_args={"vmin": 0, "vmax": 1}) sol = eq.solve(state, t_range=50, dt=1e-2, tracker=["progress", tracker]) diff --git a/examples/trackers.py b/examples/trackers.py index da7bb392..dcdfbae4 100644 --- a/examples/trackers.py +++ b/examples/trackers.py @@ -15,10 +15,10 @@ trackers = [ "progress", # show progress bar during simulation "steady_state", # abort when steady state is reached - storage.tracker(interval=1), # store data every simulation time unit + storage.tracker(interrupts=1), # store data every simulation time unit pde.PlotTracker(show=True), # show images during simulation # print some output every 5 real seconds: - pde.PrintTracker(interval=pde.RealtimeInterrupts(duration=5)), + pde.PrintTracker(interrupts=pde.RealtimeInterrupts(duration=5)), ] eq = pde.DiffusionPDE(0.1) # define the PDE diff --git a/examples/tutorial/Tutorial 2 - Solving pre-defined partial differential equations.ipynb b/examples/tutorial/Tutorial 2 - Solving pre-defined partial differential equations.ipynb index 73277c7b..a8b175ac 100644 --- a/examples/tutorial/Tutorial 2 - Solving pre-defined partial differential equations.ipynb +++ b/examples/tutorial/Tutorial 2 - Solving pre-defined partial differential equations.ipynb @@ -154,7 +154,7 @@ "outputs": [], "source": [ "# Show the evolution while computing it\n", - "eq.solve(state, t_range=1e3, dt=0.01, tracker=pde.PlotTracker(interval=100));" + "eq.solve(state, t_range=1e3, dt=0.01, tracker=pde.PlotTracker(interrupts=100));" ] }, { @@ -170,7 +170,7 @@ "# reduced output\n", "trackers = [\n", " 'progress',\n", - " pde.PrintTracker(interval='0:01') # print output roughly every real second\n", + " pde.PrintTracker(interrupts='0:01') # print output roughly every real second\n", "]\n", "\n", "eq.solve(state, t_range=1e3, dt=0.01, tracker=trackers);" @@ -190,7 +190,7 @@ "outputs": [], "source": [ "storage = pde.MemoryStorage()\n", - "eq.solve(state, 100, dt=0.01, tracker=storage.tracker(interval=10))\n", + "eq.solve(state, 100, dt=0.01, tracker=storage.tracker(interrupts=10))\n", "\n", "for field in storage:\n", " print(f\"{field.integral:.3g}, {field.fluctuations:.3g}\")" @@ -203,7 +203,7 @@ "outputs": [], "source": [ "storage_write = pde.FileStorage('simulation.hdf')\n", - "eq.solve(state, 100, dt=0.01, tracker=storage_write.tracker(interval=10));" + "eq.solve(state, 100, dt=0.01, tracker=storage_write.tracker(interrupts=10));" ] }, { diff --git a/pde/pdes/base.py b/pde/pdes/base.py index 84a15d3a..1006361b 100644 --- a/pde/pdes/base.py +++ b/pde/pdes/base.py @@ -548,17 +548,17 @@ def solve( (supported by :class:`~pde.solvers.ScipySolver` and :class:`~pde.solvers.ExplicitSolver`), `dt` sets the initial time step. tracker: - Defines a tracker that processes the state of the simulation at - specified times. A tracker is either an instance of - :class:`~pde.trackers.base.TrackerBase` or a string, which identifies a - tracker. All possible identifiers can be obtained by calling - :func:`~pde.trackers.base.get_named_trackers`. Multiple trackers can be + Defines trackers that process the state of the simulation at specified + times. A tracker is either an instance of + :class:`~pde.trackers.base.TrackerBase` or a string identifying a + tracker (possible identifiers can be obtained by calling + :func:`~pde.trackers.base.get_named_trackers`). Multiple trackers can be specified as a list. The default value `auto` checks the state for consistency (tracker 'consistency') and displays a progress bar (tracker - 'progress'). More general trackers are defined in :mod:`~pde.trackers`, - where all options are explained in detail. In particular, the interval - at which the tracker is evaluated can be chosen when creating a tracker - object explicitly. + 'progress') when :mod:`tqdm` is installed. More general trackers are + defined in :mod:`~pde.trackers`, where all options are explained in + detail. In particular, the time points where the tracker analyzes data + can be chosen when creating a tracker object explicitly. solver (:class:`~pde.solvers.base.SolverBase` or str): Specifies the method for solving the differential equation. This can either be an instance of :class:`~pde.solvers.base.SolverBase` or a diff --git a/pde/solvers/base.py b/pde/solvers/base.py index 523a9aea..00944417 100644 --- a/pde/solvers/base.py +++ b/pde/solvers/base.py @@ -112,9 +112,7 @@ def registered_solvers(cls) -> List[str]: # @NoSelf @property def _compiled(self) -> bool: """bool: indicates whether functions need to be compiled""" - return ( - self.backend == "numba" and not nb.config.DISABLE_JIT - ) # @UndefinedVariable + return self.backend == "numba" and not nb.config.DISABLE_JIT def _make_modify_after_step( self, state: FieldBase diff --git a/pde/solvers/controller.py b/pde/solvers/controller.py index 840706a1..cf8bace7 100644 --- a/pde/solvers/controller.py +++ b/pde/solvers/controller.py @@ -57,16 +57,16 @@ def __init__( Sets the time range for which the simulation is run. If only a single value `t_end` is given, the time range is assumed to be `[0, t_end]`. tracker: - Defines a tracker that process the state of the simulation at specified + Defines trackers that process the state of the simulation at specified times. A tracker is either an instance of - :class:`~pde.trackers.base.TrackerBase` or a string, which identifies a - tracker. All possible identifiers can be obtained by calling - :func:`~pde.trackers.base.get_named_trackers`. Multiple trackers can be + :class:`~pde.trackers.base.TrackerBase` or a string identifying a + tracker (possible identifiers can be obtained by calling + :func:`~pde.trackers.base.get_named_trackers`). Multiple trackers can be specified as a list. The default value `auto` checks the state for consistency (tracker 'consistency') and displays a progress bar (tracker 'progress') when :mod:`tqdm` is installed. More general trackers are defined in :mod:`~pde.trackers`, where all options are explained in - detail. In particular, the interval at which the tracker is evaluated + detail. In particular, the time points where the tracker analyzes data can be chosen when creating a tracker object explicitly. """ self.solver = solver diff --git a/pde/storage/base.py b/pde/storage/base.py index 25f26c41..d36ca7e6 100644 --- a/pde/storage/base.py +++ b/pde/storage/base.py @@ -30,7 +30,7 @@ from ..tools.docstrings import fill_in_docstring from ..tools.output import display_progress from ..trackers.base import InfoDict, TrackerBase -from ..trackers.interrupts import InterruptsBase, IntervalData +from ..trackers.interrupts import InterruptData, InterruptsBase if TYPE_CHECKING: from .memory import MemoryStorage @@ -281,15 +281,16 @@ def items(self) -> Iterator[Tuple[float, FieldBase]]: @fill_in_docstring def tracker( self, - interval: int | float | InterruptsBase = 1, + interrupts: InterruptData = 1, *, transformation: Optional[Callable[[FieldBase, float], FieldBase]] = None, + interval=None, ) -> "StorageTracker": """create object that can be used as a tracker to fill this storage Args: - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} transformation (callable, optional): A function that transforms the current state into a new field or field collection, which is then stored. This allows to store derived @@ -319,7 +320,10 @@ def add_to_state(state): possible by defining appropriate :func:`add_to_state` """ return StorageTracker( - storage=self, interval=interval, transformation=transformation + storage=self, + interrupts=interrupts, + transformation=transformation, + interval=interval, ) def start_writing(self, field: FieldBase, info: Optional[InfoDict] = None) -> None: @@ -539,16 +543,17 @@ class StorageTracker(TrackerBase): def __init__( self, storage, - interval: IntervalData = 1, + interrupts: InterruptData = 1, *, transformation: Optional[Callable[[FieldBase, float], FieldBase]] = None, + interval=None, ): """ Args: storage (:class:`~pde.storage.base.StorageBase`): Storage instance to which the data is written - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} transformation (callable, optional): A function that transforms the current state into a new field or field collection, which is then stored. This allows to store derived @@ -557,7 +562,7 @@ def __init__( the current field, while the optional second argument is the associated time. """ - super().__init__(interval=interval) + super().__init__(interrupts=interrupts, interval=interval) self.storage = storage if transformation is not None and not callable(transformation): raise TypeError("`transformation` must be callable") diff --git a/pde/tools/docstrings.py b/pde/tools/docstrings.py index b489650e..46bcd5ab 100644 --- a/pde/tools/docstrings.py +++ b/pde/tools/docstrings.py @@ -36,12 +36,13 @@ More information can be found in the :ref:`boundaries documentation `. """, - "ARG_TRACKER_INTERVAL": """ - Determines how often the tracker interrupts the simulation. Simple - numbers are interpreted as durations measured in the simulation time - variable. Alternatively, a string using the format 'hh:mm:ss' can be - used to give durations in real time. Finally, instances of the classes - defined in :mod:`~pde.trackers.interrupts` can be given for more control. + "ARG_TRACKER_INTERRUPT": """ + Determines when the tracker interrupts the simulation. A single numbers + determines an interval (measured in the simulation time unit) of regular + interruption. A string is interpreted as a duration in real time assuming the + format 'hh:mm:ss'. A list of numbers is taken as explicit simulation time points. + More fine-grained contol is possible by passing an instance of classes defined + in :mod:`~pde.trackers.interrupts`. """, "ARG_PLOT_QUANTITIES": """ A 2d list of quantities that are shown in a rectangular arrangement. diff --git a/pde/trackers/__init__.py b/pde/trackers/__init__.py index 45d17c02..7d481a92 100644 --- a/pde/trackers/__init__.py +++ b/pde/trackers/__init__.py @@ -17,33 +17,33 @@ ~trackers.RuntimeTracker ~trackers.ConsistencyTracker ~interactive.InteractivePlotTracker - + Some trackers can also be referenced by name for convenience when using them in simulations. The lit of supported names is returned by :func:`~pde.trackers.base.get_named_trackers`. - + Multiple trackers can be collected in a :class:`~base.TrackerCollection`, which provides methods for handling them efficiently. Moreover, custom trackers can be implemented by deriving from :class:`~.trackers.base.TrackerBase`. Note that trackers generally receive a view into the current state, implying that they can adjust the state by modifying it -in-place. Moreover, trackers can interrupt the simulation by raising the special -exception :class:`StopIteration`. +in-place. Moreover, trackers can abort the simulation by raising the special exception +:class:`StopIteration`. -For each tracker, the time intervals at which it is called can be decided using one -of the following classes, which determine when the simulation will be interrupted: +For each tracker, the time at which the simulation is interrupted can be decided using +one of the following classes: .. autosummary:: :nosignatures: - + ~interrupts.FixedInterrupts ~interrupts.ConstantInterrupts ~interrupts.LogarithmicInterrupts ~interrupts.RealtimeInterrupts - + In particular, interrupts can be specified conveniently using -:func:`~interrupts.interval_to_interrupts`. - +:func:`~interrupts.parse_interrupt`. + .. codeauthor:: David Zwicker """ @@ -55,6 +55,6 @@ FixedInterrupts, LogarithmicInterrupts, RealtimeInterrupts, - interval_to_interrupts, + parse_interrupt, ) from .trackers import * diff --git a/pde/trackers/base.py b/pde/trackers/base.py index 7e434765..896ee5ce 100644 --- a/pde/trackers/base.py +++ b/pde/trackers/base.py @@ -8,6 +8,7 @@ import logging import math +import warnings from abc import ABCMeta, abstractmethod from typing import Any, Dict, List, Optional, Sequence, Type, Union @@ -16,7 +17,7 @@ from ..fields.base import FieldBase from ..tools.docstrings import fill_in_docstring from ..tools.misc import module_available -from .interrupts import IntervalData, interval_to_interrupts +from .interrupts import InterruptData, parse_interrupt InfoDict = Optional[Dict[str, Any]] TrackerDataType = Union["TrackerBase", str] @@ -32,13 +33,20 @@ class TrackerBase(metaclass=ABCMeta): _subclasses: Dict[str, Type[TrackerBase]] = {} # all inheriting classes @fill_in_docstring - def __init__(self, interval: IntervalData = 1): + def __init__(self, interrupts: InterruptData = 1, *, interval=None): """ Args: - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} """ - self.interrupt = interval_to_interrupts(interval) + if interval is not None: + # deprecated on 2023-12-23 + warnings.warn( + "Argument `interval` has been renamed to `interrupts`", + DeprecationWarning, + ) + interrupts = interval + self.interrupt = parse_interrupt(interrupts) self._logger = logging.getLogger(self.__class__.__name__) def __init_subclass__(cls, **kwargs): # @NoSelf diff --git a/pde/trackers/interactive.py b/pde/trackers/interactive.py index 18516b50..7c748c82 100644 --- a/pde/trackers/interactive.py +++ b/pde/trackers/interactive.py @@ -16,7 +16,7 @@ from ..tools.docstrings import fill_in_docstring from ..tools.plotting import napari_add_layers from .base import InfoDict, TrackerBase -from .interrupts import IntervalData +from .interrupts import InterruptData def napari_process( @@ -248,14 +248,16 @@ def main(): @fill_in_docstring def __init__( self, - interval: IntervalData = "0:01", + interrupts: InterruptData = "0:01", + *, close: bool = True, show_time: bool = False, + interval=None, ): """ Args: - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} close (bool): Flag indicating whether the napari window is closed automatically at the end of the simulation. If `False`, the tracker blocks when `finalize` is @@ -263,8 +265,7 @@ def __init__( show_time (bool): Whether to indicate the time """ - # initialize the tracker - super().__init__(interval=interval) + super().__init__(interrupts=interrupts, interval=interval) self.close = close self.show_time = show_time diff --git a/pde/trackers/interrupts.py b/pde/trackers/interrupts.py index 884e00e5..8dcee4ae 100644 --- a/pde/trackers/interrupts.py +++ b/pde/trackers/interrupts.py @@ -242,11 +242,11 @@ def next(self, t: float) -> float: return super().next(t) -IntervalData = Union[InterruptsBase, float, str, Sequence[float], np.ndarray] +InterruptData = Union[InterruptsBase, float, str, Sequence[float], np.ndarray] -def interval_to_interrupts(data: IntervalData) -> InterruptsBase: - """create interrupt class from various data formats specifying time intervals +def parse_interrupt(data: InterruptData) -> InterruptsBase: + """create interrupt class from various data formats Args: data (str or number or :class:`InterruptsBase`): @@ -256,7 +256,7 @@ def interval_to_interrupts(data: IntervalData) -> InterruptsBase: interpreted as :class:`FixedInterrupts`. Returns: - :class:`InterruptsBase`: An instance that represents the time intervals + :class:`InterruptsBase`: An instance that represents the interrupt """ if isinstance(data, InterruptsBase): return data diff --git a/pde/trackers/trackers.py b/pde/trackers/trackers.py index f4e96621..05f9e336 100644 --- a/pde/trackers/trackers.py +++ b/pde/trackers/trackers.py @@ -41,7 +41,7 @@ from ..tools.parse_duration import parse_duration from ..tools.typing import Real from .base import FinishedSimulation, InfoDict, TrackerBase -from .interrupts import IntervalData, RealtimeInterrupts +from .interrupts import InterruptData, RealtimeInterrupts if TYPE_CHECKING: import pandas @@ -69,7 +69,13 @@ def check_simulation(state, time): """ @fill_in_docstring - def __init__(self, func: Callable, interval: IntervalData = 1): + def __init__( + self, + func: Callable, + interrupts: InterruptData = 1, + *, + interval=None, + ): """ Args: func: @@ -81,10 +87,10 @@ def __init__(self, func: Callable, interval: IntervalData = 1): should be stored. The function can thus adjust the state by modifying it in-place and it can even interrupt the simulation by raising the special exception :class:`StopIteration`. - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} """ - super().__init__(interval=interval) + super().__init__(interrupts=interrupts, interval=interval) self._callback = func self._num_args = len(inspect.signature(func).parameters) if not 0 < self._num_args < 3: @@ -116,16 +122,17 @@ class ProgressTracker(TrackerBase): @fill_in_docstring def __init__( self, - interval: Optional[IntervalData] = None, + interrupts: Optional[InterruptData] = None, *, fancy: bool = True, ndigits: int = 5, leave: bool = True, + interval=None, ): """ Args: - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} The default value `None` updates the progress bar approximately every (real) second. fancy (bool): @@ -137,10 +144,10 @@ def __init__( Whether to leave the progress bar after the simulation has finished (default: True) """ - if interval is None: - interval = RealtimeInterrupts(duration=1) # print every second by default + if interrupts is None: + interrupts = RealtimeInterrupts(duration=1) # print every second by default - super().__init__(interval=interval) + super().__init__(interrupts=interrupts, interval=interval) self.fancy = fancy self.ndigits = ndigits self.leave = leave @@ -237,16 +244,22 @@ class PrintTracker(TrackerBase): name = "print" @fill_in_docstring - def __init__(self, interval: IntervalData = 1, stream: IO[str] = sys.stdout): + def __init__( + self, + interrupts: InterruptData = 1, + stream: IO[str] = sys.stdout, + *, + interval=None, + ): """ Args: - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} stream: The stream used for printing """ - super().__init__(interval=interval) + super().__init__(interrupts=interrupts, interval=interval) self.stream = stream def handle(self, field: FieldBase, t: float) -> None: @@ -286,7 +299,7 @@ class PlotTracker(TrackerBase): @fill_in_docstring def __init__( self, - interval: IntervalData = 1, + interrupts: InterruptData = 1, *, title: str | Callable = "Time: {time:g}", output_file: Optional[str] = None, @@ -295,11 +308,12 @@ def __init__( tight_layout: bool = False, max_fps: float = math.inf, plot_args: Optional[Dict[str, Any]] = None, + interval=None, ): """ Args: - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} title (str or callable): Title text of the figure. If this is a string, it is shown with a potential placeholder named `time` being replaced by the current @@ -353,7 +367,7 @@ def __init__( from ..visualization.movies import Movie # @Reimport # initialize the tracker - super().__init__(interval=interval) + super().__init__(interrupts=interrupts, interval=interval) self.title = title self.output_file = output_file self.tight_layout = tight_layout @@ -530,16 +544,17 @@ class LivePlotTracker(PlotTracker): @fill_in_docstring def __init__( self, - interval: IntervalData = "0:03", + interrupts: InterruptData = "0:03", *, show: bool = True, max_fps: float = 2, + interval=None, **kwargs, ): """ Args: - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} title (str): Text to show in the title. The current time point will be appended to this text, so include a space for optimal results. @@ -567,7 +582,13 @@ def __init__( instance, the value `{'ax_style': {'ylim': (0, 1)}}` enforces the y-axis to lie between 0 and 1. """ - super().__init__(interval=interval, show=show, max_fps=max_fps, **kwargs) + super().__init__( + interrupts=interrupts, + interval=interval, + show=show, + max_fps=max_fps, + **kwargs, + ) class DataTracker(CallbackTracker): @@ -598,7 +619,12 @@ def get_statistics(state, time): @fill_in_docstring def __init__( - self, func: Callable, interval: IntervalData = 1, filename: Optional[str] = None + self, + func: Callable, + interrupts: InterruptData = 1, + *, + filename: Optional[str] = None, + interval=None, ): """ Args: @@ -613,8 +639,8 @@ def __init__( Typical return values of the function are either a single number, a numpy array, a list of number, or a dictionary to return multiple numbers with assigned labels. - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} filename (str): A path to a file to which the data is written at the end of the tracking. The data format will be determined by the extension @@ -622,7 +648,7 @@ def __init__( storing a tuple `(self.times, self.data)`, whereas any other data format requires :mod:`pandas`. """ - super().__init__(func=func, interval=interval) + super().__init__(func=func, interrupts=interrupts, interval=interval) self.filename = filename self.times: List[float] = [] self.data: List[Any] = [] @@ -722,17 +748,18 @@ class SteadyStateTracker(TrackerBase): @fill_in_docstring def __init__( self, - interval: Optional[IntervalData] = None, + interrupts: Optional[InterruptData] = None, atol: float = 1e-8, rtol: float = 1e-5, *, progress: bool = False, evolution_rate: Optional[Callable[[np.ndarray, float], np.ndarray]] = None, + interval=None, ): """ Args: - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} The default value `None` checks for the steady state approximately every (real) second. atol (float): @@ -748,9 +775,9 @@ def __init__( can be less accurate. A suitable form of the function is returned by `eq.make_pde_rhs(state)` when `eq` is the PDE class. """ - if interval is None: - interval = RealtimeInterrupts(duration=1) - super().__init__(interval=interval) + if interrupts is None: + interrupts = RealtimeInterrupts(duration=1) + super().__init__(interrupts=interrupts, interval=interval) self.atol = atol self.rtol = rtol self.evolution_rate = evolution_rate @@ -832,7 +859,9 @@ class RuntimeTracker(TrackerBase): """Tracker interrupting the simulation once a duration has passed""" @fill_in_docstring - def __init__(self, max_runtime: Real | str, interval: IntervalData = 1): + def __init__( + self, max_runtime: Real | str, interrupts: InterruptData = 1, *, interval=None + ): """ Args: max_runtime (float or str): @@ -840,10 +869,10 @@ def __init__(self, max_runtime: Real | str, interval: IntervalData = 1): simulation is interrupted. Values can be either given as a number (interpreted as seconds) or as a string, which is then parsed using the function :func:`~pde.tools.parse_duration.parse_duration`. - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} """ - super().__init__(interval=interval) + super().__init__(interrupts=interrupts, interval=interval) try: self.max_runtime = float(max_runtime) @@ -885,17 +914,17 @@ class ConsistencyTracker(TrackerBase): name = "consistency" @fill_in_docstring - def __init__(self, interval: Optional[IntervalData] = None): + def __init__(self, interrupts: Optional[InterruptData] = None, *, interval=None): """ Args: - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} The default value `None` checks for consistency approximately every (real) second. """ - if interval is None: - interval = RealtimeInterrupts(duration=1) - super().__init__(interval=interval) + if interrupts is None: + interrupts = RealtimeInterrupts(duration=1) + super().__init__(interrupts=interrupts, interval=interval) def handle(self, field: FieldBase, t: float) -> None: """handle data supplied to this tracker @@ -917,18 +946,23 @@ class MaterialConservationTracker(TrackerBase): @fill_in_docstring def __init__( - self, interval: IntervalData = 1, atol: float = 1e-4, rtol: float = 1e-4 + self, + interrupts: InterruptData = 1, + atol: float = 1e-4, + rtol: float = 1e-4, + *, + interval=None, ): """ Args: - interval: - {ARG_TRACKER_INTERVAL} + interrupts: + {ARG_TRACKER_INTERRUPT} atol (float): Absolute tolerance for amount deviations rtol (float): Relative tolerance for amount deviations """ - super().__init__(interval=interval) + super().__init__(interrupts=interrupts, interval=interval) self.atol = atol self.rtol = rtol diff --git a/tests/storage/test_file_storages.py b/tests/storage/test_file_storages.py index 759641a9..a44ae5e6 100644 --- a/tests/storage/test_file_storages.py +++ b/tests/storage/test_file_storages.py @@ -87,7 +87,7 @@ def test_simulation_persistence(compression, tmp_path, rng): pde = DiffusionPDE() grid = UnitGrid([16, 16]) # generate grid state = ScalarField.random_uniform(grid, 0.2, 0.3, rng=rng) - pde.solve(state, t_range=0.11, dt=0.001, tracker=storage.tracker(interval=0.05)) + pde.solve(state, t_range=0.11, dt=0.001, tracker=storage.tracker(interrupts=0.05)) storage.close() # read the data diff --git a/tests/storage/test_generic_storages.py b/tests/storage/test_generic_storages.py index 745fb9b9..b985f3c7 100644 --- a/tests/storage/test_generic_storages.py +++ b/tests/storage/test_generic_storages.py @@ -72,7 +72,7 @@ def test_storage_truncation(tmp_path, rng): storages = [MemoryStorage()] if module_available("h5py"): storages.append(FileStorage(file)) - tracker_list = [s.tracker(interval=0.01) for s in storages] + tracker_list = [s.tracker(interrupts=0.01) for s in storages] grid = UnitGrid([8, 8]) state = ScalarField.random_uniform(grid, 0.2, 0.3, rng=rng) diff --git a/tests/trackers/test_interrupts.py b/tests/trackers/test_interrupts.py index 21a6843b..50706688 100644 --- a/tests/trackers/test_interrupts.py +++ b/tests/trackers/test_interrupts.py @@ -10,7 +10,7 @@ FixedInterrupts, LogarithmicInterrupts, RealtimeInterrupts, - interval_to_interrupts, + parse_interrupt, ) @@ -35,7 +35,7 @@ def test_interrupt_constant(): assert ival3.next(3) == pytest.approx(6) assert ival3.dt == 2 - ival = interval_to_interrupts(2) + ival = parse_interrupt(2) assert ival.initialize(1) == pytest.approx(1) assert ival.next(3) == pytest.approx(3) assert ival.next(3) == pytest.approx(5) @@ -67,7 +67,7 @@ def test_interrupt_logarithmic(): def test_interrupt_realtime(): """test the RealtimeInterrupts class""" - for ival in [RealtimeInterrupts("0:01"), interval_to_interrupts("0:01")]: + for ival in [RealtimeInterrupts("0:01"), parse_interrupt("0:01")]: assert ival.initialize(0) == pytest.approx(0) i1, i2, i3 = ival.next(1), ival.next(1), ival.next(1) assert i3 > i2 > i1 > 0 @@ -95,10 +95,10 @@ def test_interrupt_fixed(): assert ival.next(6) == pytest.approx(7) assert ival.dt == 6 - ival = interval_to_interrupts([1, 3]) + ival = parse_interrupt([1, 3]) assert np.isinf(ival.initialize(4)) - ival = interval_to_interrupts(np.arange(3)) + ival = parse_interrupt(np.arange(3)) assert ival.initialize(0) == pytest.approx(0) assert ival.dt == 0 assert ival.next(0) == pytest.approx(1) diff --git a/tests/trackers/test_trackers.py b/tests/trackers/test_trackers.py index d11d25e2..08438fa0 100644 --- a/tests/trackers/test_trackers.py +++ b/tests/trackers/test_trackers.py @@ -30,7 +30,7 @@ def get_title(state, t): tracker = trackers.PlotTracker( output_file=output_file, title=get_title, - interval=0.1, + interrupts=0.1, show=False, tight_layout=True, ) @@ -49,7 +49,7 @@ def test_plot_movie_tracker(tmp_path, rng): state = ScalarField.random_uniform(grid, rng=rng) eq = DiffusionPDE() tracker = trackers.PlotTracker( - movie=output_file, interval=0.1, show=False, tight_layout=True + movie=output_file, interrupts=0.1, show=False, tight_layout=True ) eq.solve(state, t_range=0.5, dt=0.005, tracker=tracker, backend="numpy") @@ -59,7 +59,7 @@ def test_plot_movie_tracker(tmp_path, rng): def test_simple_progress(): """simple test for basic progress bar""" - pbar = trackers.ProgressTracker(interval=1) + pbar = trackers.ProgressTracker(interrupts=1) field = ScalarField(UnitGrid([3])) pbar.initialize(field) pbar.handle(field, 2) @@ -77,15 +77,15 @@ def get_data(state): return {"integral": state.integral} devnull = open(os.devnull, "w") - data = trackers.DataTracker(get_data, interval=0.1) + data = trackers.DataTracker(get_data, interrupts=0.1) tracker_list = [ - trackers.PrintTracker(interval=0.1, stream=devnull), - trackers.CallbackTracker(store_time, interval=0.1), + trackers.PrintTracker(interrupts=0.1, stream=devnull), + trackers.CallbackTracker(store_time, interrupts=0.1), None, # should be ignored data, ] if module_available("matplotlib"): - tracker_list.append(trackers.PlotTracker(interval=0.1, show=False)) + tracker_list.append(trackers.PlotTracker(interrupts=0.1, show=False)) grid = UnitGrid([16, 16]) state = ScalarField.random_uniform(grid, 0.2, 0.3, rng=rng) @@ -114,8 +114,8 @@ def get_mean_data(state): grid = UnitGrid([4, 4]) state = ScalarField.random_uniform(grid, 0.2, 0.3, rng=rng) eq = DiffusionPDE() - data_tracker = trackers.DataTracker(get_mean_data, interval=0.1) - callback_tracker = trackers.CallbackTracker(store_mean_data, interval=0.1) + data_tracker = trackers.DataTracker(get_mean_data, interrupts=0.1) + callback_tracker = trackers.CallbackTracker(store_mean_data, interrupts=0.1) tracker_list = [data_tracker, callback_tracker] eq.solve(state, t_range=0.5, dt=0.005, tracker=tracker_list, backend="numpy") @@ -132,8 +132,8 @@ def get_time(state, t): grid = UnitGrid([4, 4]) state = ScalarField.random_uniform(grid, 0.2, 0.3, rng=rng) eq = DiffusionPDE() - data_tracker = trackers.DataTracker(get_time, interval=0.1) - tracker_list = [trackers.CallbackTracker(store_time, interval=0.1), data_tracker] + data_tracker = trackers.DataTracker(get_time, interrupts=0.1) + tracker_list = [trackers.CallbackTracker(store_time, interrupts=0.1), data_tracker] eq.solve(state, t_range=0.5, dt=0.005, tracker=tracker_list, backend="numpy") ts = np.arange(0, 0.55, 0.1) @@ -168,7 +168,7 @@ def test_steady_state_tracker(): # use basic form tracker = trackers.SteadyStateTracker(atol=0.05, rtol=0.05, progress=True) - eq.solve(c0, 1e4, dt=0.1, tracker=[tracker, storage.tracker(interval=1e2)]) + eq.solve(c0, 1e4, dt=0.1, tracker=[tracker, storage.tracker(interrupts=1e2)]) assert len(storage) < 20 # finished early # use form with the evolution rate supplied @@ -178,7 +178,7 @@ def test_steady_state_tracker(): progress=True, evolution_rate=eq.make_pde_rhs(c0, backend="numpy"), ) - eq.solve(c0, 1e4, dt=0.1, tracker=[tracker, storage.tracker(interval=1e2)]) + eq.solve(c0, 1e4, dt=0.1, tracker=[tracker, storage.tracker(interrupts=1e2)]) assert len(storage) < 20 # finished early @@ -189,7 +189,7 @@ def test_small_tracker_dt(rng): eq = DiffusionPDE() c0 = ScalarField.random_uniform(UnitGrid([4, 4]), 0.1, 0.2, rng=rng) eq.solve( - c0, 1e-2, dt=1e-3, solver="explicit", tracker=storage.tracker(interval=1e-4) + c0, 1e-2, dt=1e-3, solver="explicit", tracker=storage.tracker(interrupts=1e-4) ) assert len(storage) == 11 @@ -238,10 +238,10 @@ def test_get_named_trackers(): def test_double_tracker(rng): """simple test for using a custom tracker twice""" - interval = ConstantInterrupts(1) + interrupts = ConstantInterrupts(1) times1, times2 = [], [] - t1 = trackers.CallbackTracker(lambda s, t: times1.append(t), interval=interval) - t2 = trackers.CallbackTracker(lambda s, t: times2.append(t), interval=interval) + t1 = trackers.CallbackTracker(lambda s, t: times1.append(t), interrupts=interrupts) + t2 = trackers.CallbackTracker(lambda s, t: times2.append(t), interrupts=interrupts) field = ScalarField.random_uniform(UnitGrid([3]), rng=rng) DiffusionPDE().solve(field, t_range=4, dt=0.1, tracker=[t1, t2]) diff --git a/tests/visualization/test_movies.py b/tests/visualization/test_movies.py index 7ca7b41f..6df5d79e 100644 --- a/tests/visualization/test_movies.py +++ b/tests/visualization/test_movies.py @@ -41,7 +41,7 @@ def test_movie_scalar(movie_func, tmp_path, rng): state = ScalarField.random_uniform(UnitGrid([4, 4]), rng=rng) eq = DiffusionPDE() storage = MemoryStorage() - tracker = storage.tracker(interval=1) + tracker = storage.tracker(interrupts=1) eq.solve(state, t_range=2, dt=1e-2, backend="numpy", tracker=tracker) # check creating the movie