Skip to content

Commit

Permalink
[FIX] Fix missing base estimators when calling load (#45)
Browse files Browse the repository at this point in the history
* fix save and load

* remove randomness in unit tests
  • Loading branch information
xuyxu authored Feb 25, 2021
1 parent 1ede2a9 commit 4f49c0c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Changelog
[Ver 0.1.*]
-----------

* |Fix| Fix missing base estimators when calling :meth:`load()` for all ensembles | `@xuyxu <https://github.com/xuyxu>`__
* |MajorFeature| Add methods on model deserialization :meth:`load()` for all ensembles | `@mttgdd <https://github.com/mttgdd>`__

[Beta]
Expand Down
41 changes: 30 additions & 11 deletions torchensemble/tests/test_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch.utils.data import TensorDataset, DataLoader

import torchensemble
from torchensemble.utils import io
from torchensemble.utils.logging import set_logger


Expand All @@ -24,6 +25,10 @@
torchensemble.AdversarialTrainingRegressor]


# Remove randomness
np.random.seed(0)
torch.manual_seed(0)

set_logger("pytest_all_models")


Expand Down Expand Up @@ -66,12 +71,10 @@ def forward(self, X):

# Testing data
X_test = torch.Tensor(np.array(([0.5, 0.5],
[0.6, 0.6],
[0.7, 0.7],
[0.8, 0.8])))
[0.6, 0.6])))

y_test_clf = torch.LongTensor(np.array(([1, 1, 0, 0])))
y_test_reg = torch.FloatTensor(np.array(([0.5, 0.6, 0.7, 0.8])))
y_test_clf = torch.LongTensor(np.array(([1, 0])))
y_test_reg = torch.FloatTensor(np.array(([0.5, 0.6])))
y_test_reg = y_test_reg.view(-1, 1)


Expand All @@ -94,9 +97,9 @@ def test_clf(clf):

# Prepare data
train = TensorDataset(X_train, y_train_clf)
train_loader = DataLoader(train, batch_size=2)
train_loader = DataLoader(train, batch_size=2, shuffle=False)
test = TensorDataset(X_test, y_test_clf)
test_loader = DataLoader(test, batch_size=2)
test_loader = DataLoader(test, batch_size=2, shuffle=False)

# Snapshot ensemble needs more epochs
if isinstance(model, torchensemble.SnapshotEnsembleClassifier):
Expand All @@ -109,7 +112,15 @@ def test_clf(clf):
save_model=True)

# Test
model.predict(test_loader)
prev_acc = model.predict(test_loader)

# Reload
new_model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False)
io.load(new_model)

post_acc = new_model.predict(test_loader)

assert prev_acc == post_acc # ensure the same performance


@pytest.mark.parametrize("reg", all_reg)
Expand All @@ -131,9 +142,9 @@ def test_reg(reg):

# Prepare data
train = TensorDataset(X_train, y_train_reg)
train_loader = DataLoader(train, batch_size=2)
train_loader = DataLoader(train, batch_size=2, shuffle=False)
test = TensorDataset(X_test, y_test_reg)
test_loader = DataLoader(test, batch_size=2)
test_loader = DataLoader(test, batch_size=2, shuffle=False)

# Snapshot ensemble needs more epochs
if isinstance(model, torchensemble.SnapshotEnsembleRegressor):
Expand All @@ -146,7 +157,15 @@ def test_reg(reg):
save_model=True)

# Test
model.predict(test_loader)
prev_mse = model.predict(test_loader)

# Reload
new_model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False)
io.load(new_model)

post_mse = new_model.predict(test_loader)

assert prev_mse == post_mse # ensure the same performance


@pytest.mark.parametrize("method", all_clf + all_reg)
Expand Down
15 changes: 13 additions & 2 deletions torchensemble/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ def save(model, save_dir, logger):
filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,
model.base_estimator_.__name__,
model.n_estimators)
state = {"model": model.state_dict()}

# The actual number of base estimators in some ensembles is not the same as
# `n_estimators`.
state = {"n_estimators": len(model.estimators_),
"model": model.state_dict()}
save_dir = os.path.join(save_dir, filename)

logger.info("Saving the model to `{}`".format(save_dir))
Expand All @@ -39,4 +43,11 @@ def load(model, save_dir="./", logger=None):
if logger:
logger.info("Loading the model from `{}`".format(save_dir))

model.load_state_dict(torch.load(save_dir)["model"])
state = torch.load(save_dir)
n_estimators = state["n_estimators"]
model_params = state["model"]

# Pre-allocate and load all base estimators
for _ in range(n_estimators):
model.estimators_.append(model._make_estimator())
model.load_state_dict(model_params)

0 comments on commit 4f49c0c

Please sign in to comment.