Skip to content

Commit

Permalink
Improved handling of initial data (#587)
Browse files Browse the repository at this point in the history
* Improved handling of initial data
* Added logging for post_step_hook
  • Loading branch information
david-zwicker authored Aug 5, 2024
1 parent 0bc569e commit e9cbc35
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions pde/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def _make_post_step_hook(self, state: FieldBase) -> StepperHook:
try:
# look for the definition of a hook function
if hasattr(self.pde, "make_modify_after_step"):

# Deprecated on 2024-08-02
warnings.warn(
"`make_modify_after_step` has been replaced by `make_post_step_hook`",
Expand All @@ -156,12 +155,16 @@ def post_step_hook(

# create zero of correct type
self._post_step_data_init = np.dtype(state.dtype).type(0)
self._logger.info(
"Created post-step hook from `make_modify_after_step`"
)

else:
# get hook function and initial data from PDE
post_step_hook, self._post_step_data_init = (
self.pde.make_post_step_hook(state)
)
self._logger.info("Created post-step hook from PDE")

except NotImplementedError:
pass # no hook function defined on the PDE
Expand All @@ -175,20 +178,17 @@ def post_step_hook(
"""Default hook function does nothing."""

self._post_step_data_init = None
self._logger.debug("No post-step hook defined")

else:
if np.isscalar(self._post_step_data_init):
self._post_step_data_init = np.array(self._post_step_data_init)
if not isinstance(self._post_step_data_init, np.ndarray):
raise TypeError(
"The intial data provided for the post-step hook must be a number "
"or a numpy array."
)
# ensure that the initial values is a mutable array
self._post_step_data_init = np.array(self._post_step_data_init, copy=True)

self._post_step_data_type = nb.typeof(self._post_step_data_init)
if self._compiled:
sig_hook = (nb.typeof(state.data), nb.float64, self._post_step_data_type)
post_step_hook = jit(sig_hook)(post_step_hook)
self._logger.debug("Compiled post-step hook")

return post_step_hook # type: ignore

Expand Down

0 comments on commit e9cbc35

Please sign in to comment.