Skip to content

Commit

Permalink
fix: bug in doubling estimators_ if model loaded from save_dir + feat…
Browse files Browse the repository at this point in the history
…ure for having callback after each epoch (#166)

* fix: dont instantiate if estimators loaded from save_dir

* feat: add on_epoch_end_cb to call at end of each epoch

---------

Co-authored-by: [email protected] <[email protected]>
  • Loading branch information
h2soheili and [email protected] authored Jun 16, 2024
1 parent 6726a99 commit 4366765
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions torchensemble/soft_gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,14 @@ def fit(
test_loader=None,
save_model=True,
save_dir=None,
on_epoch_end_cb=None
):

# Instantiate base estimators and set attributes
for _ in range(self.n_estimators):
self.estimators_.append(self._make_estimator())
# dont instantiate if estimators loaded from save_dir
if len(self.estimators_) != self.n_estimators:
for _ in range(self.n_estimators):
self.estimators_.append(self._make_estimator())
self._validate_parameters(epochs, log_interval)
self.n_outputs = self._decide_n_outputs(train_loader)

Expand Down Expand Up @@ -295,6 +298,9 @@ def fit(
else:
scheduler.step()

# Call on epoch end
if on_epoch_end_cb:
on_epoch_end_cb(epoch)
if save_model and not test_loader:
io.save(self, save_dir, self.logger)

Expand Down Expand Up @@ -390,6 +396,7 @@ def fit(
test_loader=None,
save_model=True,
save_dir=None,
on_epoch_end_cb=None
):
super().fit(
train_loader=train_loader,
Expand All @@ -399,6 +406,7 @@ def fit(
test_loader=test_loader,
save_model=save_model,
save_dir=save_dir,
on_epoch_end_cb=on_epoch_end_cb,
)

@torchensemble_model_doc(
Expand Down

0 comments on commit 4366765

Please sign in to comment.