Skip to content

Commit

Permalink
Remove analyzer.fit
Browse files Browse the repository at this point in the history
Remove dead code left from PatternNet and PatternAttribution.
  • Loading branch information
adrhill committed Sep 14, 2022
1 parent ddc0890 commit 384dc61
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 100 deletions.
21 changes: 6 additions & 15 deletions examples/introduction_development.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -240,20 +240,6 @@
" def __init__(self, model, disable_model_checks=False):\n",
" pass\n",
"\n",
" def fit(self, *args, **kwargs):\n",
" \"\"\"\n",
" Train analyzer like a Keras model.\n",
" Does not need to be implemented.\n",
" \"\"\"\n",
" pass\n",
"\n",
" def fit_generator(self, *args, **kwargs):\n",
" \"\"\"\n",
" Train analyzer like a Keras model.\n",
" Does not need to be implemented.\n",
" \"\"\"\n",
" pass\n",
"\n",
" def analyze(self, X):\n",
" \"\"\"\n",
" Analyze the behavior of model on input `X`.\n",
Expand Down Expand Up @@ -977,7 +963,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.9.7 ('innvestigate-pdNhrmV2-py3.9')",
"language": "python",
"name": "python3"
},
Expand All @@ -992,6 +978,11 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"vscode": {
"interpreter": {
"hash": "6d5a31c854ee050015bee53401f829bd4e104bd774bbc6a8c859288cc9ab765e"
}
}
},
"nbformat": 4,
Expand Down
83 changes: 0 additions & 83 deletions src/innvestigate/analyzer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class AnalyzerBase(metaclass=ABCMeta):
>>> model = create_keras_model()
>>> a = Analyzer(model)
>>> a.fit(X_train) # If analyzer needs training.
>>> analysis = a.analyze(X_test)
>>>
>>> state = a.save()
Expand Down Expand Up @@ -127,39 +126,6 @@ def _do_model_checks(self) -> None:

self._model_check_done = True

def fit(self, *_args, disable_no_training_warning: bool = False, **_kwargs):
"""
Stub that eats arguments. If an analyzer needs training
include :class:`TrainerMixin`.
:param disable_no_training_warning: Do not warn if this function is
called despite no training is needed.
"""
if not disable_no_training_warning:
# issue warning if no training is foreseen, but fit() is still called.
warnings.warn(
"This analyzer does not need to be trained." " Still fit() is called.",
RuntimeWarning,
)

def fit_generator(
self, *_args, disable_no_training_warning: bool = False, **_kwargs
):
"""
Stub that eats arguments. If an analyzer needs training
include :class:`TrainerMixin`.
:param disable_no_training_warning: Do not warn if this function is
called despite no training is needed.
"""
if not disable_no_training_warning:
# issue warning if no training is foreseen, but fit() is still called.
warnings.warn(
"This analyzer does not need to be trained."
" Still fit_generator() is called.",
RuntimeWarning,
)

@abstractmethod
def analyze(
self, X: OptionalList[np.ndarray], *args: Any, **kwargs: Any
Expand Down Expand Up @@ -248,52 +214,3 @@ def load_npz(fname):
class_name = npz_file["class_name"].item()
state = npz_file["state"].item()
return AnalyzerBase.load(class_name, state)


###############################################################################


class TrainerMixin:
"""Mixin for analyzer that adapt to data.
This convenience interface exposes a Keras like training routing
to the user.
"""

# TODO: extend with Y
def fit(self, X: np.ndarray | None = None, batch_size: int = 32, **kwargs) -> None:
"""
Takes the same parameters as Keras's :func:`model.fit` function.
"""
generator = isequence.BatchSequence(X, batch_size)
return self._fit_generator(generator, **kwargs) # type: ignore

def fit_generator(self, *args, **kwargs):
"""
Takes the same parameters as Keras's :func:`model.fit_generator`
function.
"""
return self._fit_generator(*args, **kwargs)

def _fit_generator(self, *_args, **_kwargs):
raise NotImplementedError()


class OneEpochTrainerMixin(TrainerMixin):
"""Exposes the same interface and functionality as :class:`TrainerMixin`
except that the training is limited to one epoch.
"""

def fit(self, *args, **kwargs) -> None:
"""
Same interface as :func:`fit` of :class:`TrainerMixin` except that
the parameter epoch is fixed to 1.
"""
return super().fit(*args, epochs=1, **kwargs)

def fit_generator(self, *args, steps: int = None, **kwargs):
"""
Same interface as :func:`fit_generator` of :class:`TrainerMixin` except that
the parameter epoch is fixed to 1.
"""
return super().fit_generator(*args, steps_per_epoch=steps, epochs=1, **kwargs)
2 changes: 0 additions & 2 deletions tests/dryrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,10 @@ def _apply_test(self, model: Model) -> None:
# Generate random training input
input_shape = model.input_shape[1:]
x = np.random.rand(1, *input_shape).astype(np.float32)
x_fit = np.random.rand(16, *input_shape).astype(np.float32)
# Call model with test input
model.predict(x)
# Get analyzer.
analyzer = self._method(model)
analyzer.fit(x_fit)
# Generate random test input
analysis = analyzer.analyze(x)
assert tuple(analysis.shape) == (1,) + input_shape
Expand Down

0 comments on commit 384dc61

Please sign in to comment.