diff --git a/numpy_ml/neural_nets/layers/layers.py b/numpy_ml/neural_nets/layers/layers.py index 198fe13..a6de9f7 100644 --- a/numpy_ml/neural_nets/layers/layers.py +++ b/numpy_ml/neural_nets/layers/layers.py @@ -79,6 +79,9 @@ def update(self, cur_loss=None): optimizer. Flush all gradients once the update is complete. """ assert self.trainable, "Layer is frozen" + self.optimizer = ( + self.optimizer.copy() if self.optimizer.cur_step == 0 else self.optimizer + ) self.optimizer.step() for k, v in self.gradients.items(): if k in self.parameters: