Skip to content

Commit

Permalink
Allow changing default time step (#413)
Browse files Browse the repository at this point in the history
* Additional minor changes in solvers to provide more flexibility when
subclassing
  • Loading branch information
david-zwicker authored May 26, 2023
1 parent 03268b1 commit 91ae49e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
39 changes: 23 additions & 16 deletions pde/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
class SolverBase(metaclass=ABCMeta):
"""base class for solvers"""

dt_default: float = 1e-3
"""float: default time step used if no time step was specified"""

_modify_state_after_step: bool = True
"""bool: flag choosing whether the `modify_after_step` hook of the PDE is called"""

Expand All @@ -47,10 +50,9 @@ def __init__(self, pde: PDEBase, *, backend: str = "auto"):
"""
self.pde = pde
self.backend = backend
self.info: Dict[str, Any] = {
"class": self.__class__.__name__,
"pde_class": self.pde.__class__.__name__,
}
self.info: Dict[str, Any] = {"class": self.__class__.__name__}
if self.pde:
self.info["pde_class"] = self.pde.__class__.__name__
self._logger = logging.getLogger(self.__class__.__name__)

def __init_subclass__(cls, **kwargs): # @NoSelf
Expand Down Expand Up @@ -160,7 +162,7 @@ def _make_pde_rhs(
time. The function returns the deterministic evolution rate and (if
applicable) a realization of the associated noise.
"""
if self.pde.is_sde:
if getattr(self.pde, "is_sde"):
raise RuntimeError(
f"Cannot create a deterministic stepper for a stochastic equation"
)
Expand Down Expand Up @@ -237,6 +239,7 @@ def _make_fixed_stepper(
Time step of the explicit stepping.
"""
single_step = self._make_single_step_fixed_dt(state, dt)
modify_state_after_step = self._modify_state_after_step
modify_after_step = self._make_modify_after_step(state)

if self._compiled:
Expand All @@ -252,7 +255,8 @@ def fixed_stepper(
# calculate the right hand side
t = t_start + i * dt
single_step(state_data, t)
modifications += modify_after_step(state_data)
if modify_state_after_step:
modifications += modify_after_step(state_data)

return t + dt, modifications

Expand All @@ -272,8 +276,7 @@ def make_stepper(
An example for the state from which the grid and other information can
be extracted
dt (float):
Time step of the explicit stepping. If `None`, this solver specifies
1e-3 as a default value.
Time step used (Uses :attr:`SolverBase.dt_default` if `None`)
Returns:
Function that can be called to advance the `state` from time `t_start` to
Expand All @@ -283,7 +286,7 @@ def make_stepper(
# support `None` as a default value, so the controller can signal that
# the solver should use a default time step
if dt is None:
dt = 1e-3
dt = self.dt_default
self._logger.warning(
"Explicit stepper with a fixed time step did not receive any "
f"initial value for `dt`. Using dt={dt}, but specifying a value or "
Expand All @@ -293,8 +296,10 @@ def make_stepper(

self.info["dt"] = dt_float
self.info["steps"] = 0
self.info["stochastic"] = self.pde.is_sde
self.info["state_modifications"] = 0.0
self.info["stochastic"] = getattr(self.pde, "is_sde", False)
# we don't access self.pde directly since we might want to reuse the solver
# infrastructure for more general cases where a PDE is not defined

# create stepper with fixed steps
fixed_stepper = self._make_fixed_stepper(state, dt_float)
Expand Down Expand Up @@ -430,7 +435,7 @@ def _make_single_step_error_estimate(
An example for the state from which the grid and other information can
be extracted
"""
if self.pde.is_sde:
if getattr(self.pde, "is_sde"):
raise RuntimeError("Cannot use adaptive stepper with stochastic equation")

single_step = self._make_single_step_variable_dt(state)
Expand Down Expand Up @@ -485,6 +490,7 @@ def _make_adaptive_stepper(
# obtain functions determining how the PDE is evolved
single_step_error = self._make_single_step_error_estimate(state)
modify_after_step = self._make_modify_after_step(state)
modify_state_after_step = self._modify_state_after_step
sync_errors = self._make_error_synchronizer()

# obtain auxiliary functions
Expand Down Expand Up @@ -525,7 +531,8 @@ def adaptive_stepper(
steps += 1
t += dt_step
state_data[...] = new_state
modifications += modify_after_step(state_data)
if modify_state_after_step:
modifications += modify_after_step(state_data)
if dt_stats is not None:
dt_stats.add(dt_step)

Expand Down Expand Up @@ -561,8 +568,8 @@ def make_stepper(
An example for the state from which the grid and other information can
be extracted
dt (float):
Time step of the explicit stepping. If `None`, this solver specifies
1e-3 as a default value.
Time step used (Uses :attr:`SolverBase.dt_default` if `None`). This sets
the initial time step for adaptive solvers.
Returns:
Function that can be called to advance the `state` from time `t_start` to
Expand All @@ -576,14 +583,14 @@ def make_stepper(
# support `None` as a default value, so the controller can signal that
# the solver should use a default time step
if dt is None:
dt_float = 1e-3
dt_float = self.dt_default
else:
dt_float = float(dt) # explicit casting to help type checking

self.info["dt"] = dt_float
self.info["dt_adaptive"] = True
self.info["steps"] = 0
self.info["stochastic"] = self.pde.is_sde
self.info["stochastic"] = getattr(self.pde, "is_sde", False)
self.info["state_modifications"] = 0.0

# create stepper with adaptive steps
Expand Down
2 changes: 1 addition & 1 deletion pde/solvers/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def run(
node will return the state. All other nodes return None.
"""
# copy the initial state to not modify the supplied one
if hasattr(self.solver, "pde") and self.solver.pde.complex_valued:
if getattr(self.solver, "pde") and self.solver.pde.complex_valued:
self._logger.info("Convert state to complex numbers")
state: TState = initial_state.copy(dtype=complex)
else:
Expand Down

0 comments on commit 91ae49e

Please sign in to comment.