diff --git a/examples/introduction_development.ipynb b/examples/introduction_development.ipynb index 3b9ceb61..7e6fad03 100644 --- a/examples/introduction_development.ipynb +++ b/examples/introduction_development.ipynb @@ -240,20 +240,6 @@ " def __init__(self, model, disable_model_checks=False):\n", " pass\n", "\n", - " def fit(self, *args, **kwargs):\n", - " \"\"\"\n", - " Train analyzer like a Keras model.\n", - " Does not need to be implemented.\n", - " \"\"\"\n", - " pass\n", - "\n", - " def fit_generator(self, *args, **kwargs):\n", - " \"\"\"\n", - " Train analyzer like a Keras model.\n", - " Does not need to be implemented.\n", - " \"\"\"\n", - " pass\n", - "\n", " def analyze(self, X):\n", " \"\"\"\n", " Analyze the behavior of model on input `X`.\n", @@ -977,7 +963,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.9.7 ('innvestigate-pdNhrmV2-py3.9')", "language": "python", "name": "python3" }, @@ -992,6 +978,11 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" + }, + "vscode": { + "interpreter": { + "hash": "6d5a31c854ee050015bee53401f829bd4e104bd774bbc6a8c859288cc9ab765e" + } } }, "nbformat": 4, diff --git a/src/innvestigate/analyzer/base.py b/src/innvestigate/analyzer/base.py index 61cafeca..19b846ee 100644 --- a/src/innvestigate/analyzer/base.py +++ b/src/innvestigate/analyzer/base.py @@ -31,7 +31,6 @@ class AnalyzerBase(metaclass=ABCMeta): >>> model = create_keras_model() >>> a = Analyzer(model) - >>> a.fit(X_train) # If analyzer needs training. >>> analysis = a.analyze(X_test) >>> >>> state = a.save() @@ -127,39 +126,6 @@ def _do_model_checks(self) -> None: self._model_check_done = True - def fit(self, *_args, disable_no_training_warning: bool = False, **_kwargs): - """ - Stub that eats arguments. If an analyzer needs training - include :class:`TrainerMixin`. - - :param disable_no_training_warning: Do not warn if this function is - called despite no training is needed. - """ - if not disable_no_training_warning: - # issue warning if no training is foreseen, but fit() is still called. - warnings.warn( - "This analyzer does not need to be trained." " Still fit() is called.", - RuntimeWarning, - ) - - def fit_generator( - self, *_args, disable_no_training_warning: bool = False, **_kwargs - ): - """ - Stub that eats arguments. If an analyzer needs training - include :class:`TrainerMixin`. - - :param disable_no_training_warning: Do not warn if this function is - called despite no training is needed. - """ - if not disable_no_training_warning: - # issue warning if no training is foreseen, but fit() is still called. - warnings.warn( - "This analyzer does not need to be trained." - " Still fit_generator() is called.", - RuntimeWarning, - ) - @abstractmethod def analyze( self, X: OptionalList[np.ndarray], *args: Any, **kwargs: Any @@ -248,52 +214,3 @@ def load_npz(fname): class_name = npz_file["class_name"].item() state = npz_file["state"].item() return AnalyzerBase.load(class_name, state) - - -############################################################################### - - -class TrainerMixin: - """Mixin for analyzer that adapt to data. - - This convenience interface exposes a Keras like training routing - to the user. - """ - - # TODO: extend with Y - def fit(self, X: np.ndarray | None = None, batch_size: int = 32, **kwargs) -> None: - """ - Takes the same parameters as Keras's :func:`model.fit` function. - """ - generator = isequence.BatchSequence(X, batch_size) - return self._fit_generator(generator, **kwargs) # type: ignore - - def fit_generator(self, *args, **kwargs): - """ - Takes the same parameters as Keras's :func:`model.fit_generator` - function. - """ - return self._fit_generator(*args, **kwargs) - - def _fit_generator(self, *_args, **_kwargs): - raise NotImplementedError() - - -class OneEpochTrainerMixin(TrainerMixin): - """Exposes the same interface and functionality as :class:`TrainerMixin` - except that the training is limited to one epoch. - """ - - def fit(self, *args, **kwargs) -> None: - """ - Same interface as :func:`fit` of :class:`TrainerMixin` except that - the parameter epoch is fixed to 1. - """ - return super().fit(*args, epochs=1, **kwargs) - - def fit_generator(self, *args, steps: int = None, **kwargs): - """ - Same interface as :func:`fit_generator` of :class:`TrainerMixin` except that - the parameter epoch is fixed to 1. - """ - return super().fit_generator(*args, steps_per_epoch=steps, epochs=1, **kwargs) diff --git a/tests/dryrun.py b/tests/dryrun.py index ed7a99d5..51ae6537 100644 --- a/tests/dryrun.py +++ b/tests/dryrun.py @@ -97,12 +97,10 @@ def _apply_test(self, model: Model) -> None: # Generate random training input input_shape = model.input_shape[1:] x = np.random.rand(1, *input_shape).astype(np.float32) - x_fit = np.random.rand(16, *input_shape).astype(np.float32) # Call model with test input model.predict(x) # Get analyzer. analyzer = self._method(model) - analyzer.fit(x_fit) # Generate random test input analysis = analyzer.analyze(x) assert tuple(analysis.shape) == (1,) + input_shape