Added LBFGS optimizer for Fusion #81

Open
wants to merge 7 commits into base: master
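This change adds `LBFGS` to the accepted optimizer names and reroutes the training step in Fusion (and, for consistency, Bagging and Voting) through a closure, which LBFGS needs because it re-evaluates the loss several times per parameter update. Below is a minimal usage sketch, assuming the usual torchensemble workflow (wrap an estimator class, call `set_optimizer`, then `fit`) and the default classification criterion; the estimator, data, and LBFGS keyword arguments are illustrative, with the kwargs mirroring the updated test.

```python
# Illustrative sketch only: a FusionClassifier trained with the new "LBFGS" option.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from torchensemble import FusionClassifier


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 2))

    def forward(self, x):
        return self.net(x)


# Toy data, purely for illustration.
X, y = torch.randn(256, 20), torch.randint(0, 2, (256,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=64)

model = FusionClassifier(estimator=MLP, n_estimators=3, cuda=False)
# "LBFGS" is forwarded to torch.optim.LBFGS; kwargs as in the updated test.
model.set_optimizer("LBFGS", history_size=7, max_iter=10)
model.fit(train_loader, epochs=2)
```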
Changes from all commits
9 changes: 9 additions & 0 deletions .all-contributorsrc
@@ -70,6 +70,15 @@
"code"
]
},
{
"login": "e-eight",
"name": "Soham Pal",
"avatar_url": "https://avatars.githubusercontent.com/u/3883241?v=4",
"profile": "https://soham.dev",
"contributions": [
"code"
]
},
{
"login": "by256",
"name": "Batuhan Yildirim",
2 changes: 2 additions & 0 deletions CONTRIBUTORS.md
@@ -8,6 +8,8 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<!-- markdownlint-disable -->
<table>
<tr>
<td align="center"><a href="https://github.com/mttgdd"><img src="https://avatars.githubusercontent.com/u/3154919?v=4?s=100" width="100px;" alt=""/><br /><sub><b>Matt Gadd</b></sub></a><br /><a href="https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/commits?author=mttgdd" title="Code">💻</a></td>
<td align="center"><a href="https://soham.dev"><img src="https://avatars.githubusercontent.com/u/3883241?v=4?s=100" width="100px;" alt=""/><br /><sub><b>Soham Pal</b></sub></a><br /><a href="https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/commits?author=e-eight" title="Code">💻</a></td>
<td align="center"><a href="https://by256.github.io/"><img src="https://avatars.githubusercontent.com/u/44163664?v=4?s=100" width="100px;" alt=""/><br /><sub><b>Batuhan Yildirim</b></sub></a><br /><a href="https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/commits?author=by256" title="Code">💻</a></td>
<td align="center"><a href="https://github.com/mttgdd"><img src="https://avatars.githubusercontent.com/u/3154919?v=4?s=100" width="100px;" alt=""/><br /><sub><b>Matt Gadd</b></sub></a><br /><a href="https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/commits?author=mttgdd" title="Code">💻</a></td>
<td align="center"><a href="https://github.com/zzzzwj"><img src="https://avatars.githubusercontent.com/u/23235538?v=4?s=100" width="100px;" alt=""/><br /><sub><b>Wenjie Zhang</b></sub></a><br /><a href="https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/commits?author=zzzzwj" title="Code">💻</a> <a href="https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/commits?author=zzzzwj" title="Tests">⚠️</a></td>
2 changes: 1 addition & 1 deletion torchensemble/_constants.py
@@ -64,7 +64,7 @@
optimizer_name : string
The name of the optimizer, should be one of {``Adadelta``, ``Adagrad``,
``Adam``, ``AdamW``, ``Adamax``, ``ASGD``, ``RMSprop``, ``Rprop``,
``SGD``}.
``SGD``, ``LBFGS``}.
**kwargs : keyword arguments
Keyword arguments on setting the optimizer, should be in the form:
``lr=1e-3, weight_decay=5e-4, ...``. These keyword arguments
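Since `torch.optim.LBFGS` has no required learning rate (its `lr` defaults to 1) and takes its own keyword arguments, the call differs slightly from the first-order optimizers documented above. A short sketch of both forms, assuming (as elsewhere in the library) that `set_optimizer` simply records the choice to be applied when `fit` builds the per-estimator optimizers; the estimator class is illustrative.

```python
# Both call forms on a torchensemble ensemble; kwargs go to the matching torch.optim class.
import torch.nn as nn
from torchensemble import VotingRegressor


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 1)

    def forward(self, x):
        return self.fc(x)


ensemble = VotingRegressor(estimator=Net, n_estimators=2, cuda=False)
# Each call only records the optimizer choice; the later one wins if both are made.
ensemble.set_optimizer("SGD", lr=1e-3, weight_decay=5e-4)     # first-order: lr is required
ensemble.set_optimizer("LBFGS", history_size=7, max_iter=10)  # lr defaults to 1 in torch.optim.LBFGS
```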
36 changes: 24 additions & 12 deletions torchensemble/bagging.py
@@ -6,19 +6,17 @@
"""


import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

import warnings
from joblib import Parallel, delayed

from ._base import BaseClassifier, BaseRegressor
from ._base import torchensemble_model_doc
from ._base import BaseClassifier, BaseRegressor, torchensemble_model_doc
from .utils import io
from .utils import set_module
from .utils import operator as op

from .utils import set_module

__all__ = ["BaggingClassifier", "BaggingRegressor"]

@@ -59,11 +57,20 @@ def _parallel_fit_per_epoch(
sampling_data = [tensor[sampling_mask] for tensor in data]
sampling_target = target[sampling_mask]

optimizer.zero_grad()
def closure():
if torch.is_grad_enabled():
optimizer.zero_grad()
sampling_output = estimator(*sampling_data)
loss = criterion(sampling_output, sampling_target)
if loss.requires_grad:
loss.backward()
return loss

optimizer.step(closure)

# Calculate loss for logging
sampling_output = estimator(*sampling_data)
loss = criterion(sampling_output, sampling_target)
loss.backward()
optimizer.step()
loss = closure()

# Print training status
if batch_idx % log_interval == 0:
@@ -79,15 +86,20 @@
)
print(
msg.format(
idx, epoch, batch_idx, loss, correct, subsample_size
idx,
epoch,
batch_idx,
loss.item(),
correct,
subsample_size,
)
)
else:
msg = (
"Estimator: {:03d} | Epoch: {:03d} | Batch: {:03d}"
" | Loss: {:.5f}"
)
print(msg.format(idx, epoch, batch_idx, loss))
print(msg.format(idx, epoch, batch_idx, loss.item()))

return estimator, optimizer

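The hunks above replace the fixed `zero_grad` / `backward` / `step` sequence with a closure passed to `optimizer.step`, following the closure pattern from the PyTorch optimizer docs: LBFGS may call the closure several times within a single step, while first-order optimizers call it once, so both share one code path. The guards also make the closure safe to call in no-grad contexts, for example when reusing it purely to compute a loss value. A standalone sketch of the pattern outside torchensemble, with illustrative names and data:

```python
# Standalone illustration of the closure-based step (names and data are illustrative).
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.LBFGS(model.parameters(), history_size=7, max_iter=10)

data, target = torch.randn(32, 10), torch.randn(32, 1)


def closure():
    # Skip zero_grad/backward when the closure is evaluated without gradients,
    # e.g. when it is reused only to compute a loss value.
    if torch.is_grad_enabled():
        optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    if loss.requires_grad:
        loss.backward()
    return loss


# LBFGS calls the closure repeatedly inside a single step; SGD, Adam, etc. call it once.
optimizer.step(closure)

# Reusing the closure under no_grad gives a loss for logging without touching gradients.
with torch.no_grad():
    print(f"loss: {closure().item():.5f}")
```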
54 changes: 38 additions & 16 deletions torchensemble/fusion.py
@@ -10,12 +10,10 @@
import torch.nn as nn
import torch.nn.functional as F

from ._base import BaseClassifier, BaseRegressor
from ._base import torchensemble_model_doc
from ._base import BaseClassifier, BaseRegressor, torchensemble_model_doc
from .utils import io
from .utils import set_module
from .utils import operator as op

from .utils import set_module

__all__ = ["FusionClassifier", "FusionRegressor"]

@@ -109,11 +107,20 @@ def fit(
data, target = io.split_data_target(elem, self.device)
batch_size = data[0].size(0)

optimizer.zero_grad()
def closure():
if torch.is_grad_enabled():
optimizer.zero_grad()
output = self._forward(*data)
loss = self._criterion(output, target)
if loss.requires_grad:
loss.backward()
return loss

optimizer.step(closure)

# Calculate loss for logging
output = self._forward(*data)
loss = self._criterion(output, target)
loss.backward()
optimizer.step()
loss = closure()

# Print training status
if batch_idx % log_interval == 0:
@@ -127,12 +134,16 @@
)
self.logger.info(
msg.format(
epoch, batch_idx, loss, correct, batch_size
epoch,
batch_idx,
loss.item(),
correct,
batch_size,
)
)
if self.tb_logger:
self.tb_logger.add_scalar(
"fusion/Train_Loss", loss, total_iters
"fusion/Train_Loss", loss.item(), total_iters
)
total_iters += 1

@@ -257,20 +268,31 @@ def fit(

data, target = io.split_data_target(elem, self.device)

optimizer.zero_grad()
def closure():
if torch.is_grad_enabled():
optimizer.zero_grad()
output = self.forward(*data)
loss = self._criterion(output, target)
if loss.requires_grad:
loss.backward()
return loss

optimizer.step(closure)

# Calculate loss for logging
output = self.forward(*data)
loss = self._criterion(output, target)
loss.backward()
optimizer.step()
loss = closure()

# Print training status
if batch_idx % log_interval == 0:
with torch.no_grad():
msg = "Epoch: {:03d} | Batch: {:03d} | Loss: {:.5f}"
self.logger.info(msg.format(epoch, batch_idx, loss))
self.logger.info(
msg.format(epoch, batch_idx, loss.item())
)
if self.tb_logger:
self.tb_logger.add_scalar(
"fusion/Train_Loss", loss, total_iters
"fusion/Train_Loss", loss.item(), total_iters
)
total_iters += 1

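One cost of the pattern is that the `loss = closure()` call used for logging re-runs the forward pass (and, with gradients enabled, another backward) just to obtain a number. If that extra pass ever matters, a common variant captures the loss computed inside the closure instead; a sketch with illustrative names, not part of this PR:

```python
# Variant sketch (illustrative, not part of this PR): record the loss inside the closure
# so logging does not need a second forward pass.
import torch


def train_step(model, data, target, criterion, optimizer):
    # data is assumed to be a sequence of input tensors, echoing the PR's *data usage.
    last_loss = None

    def closure():
        nonlocal last_loss
        if torch.is_grad_enabled():
            optimizer.zero_grad()
        output = model(*data)
        loss = criterion(output, target)
        if loss.requires_grad:
            loss.backward()
        last_loss = loss.detach()
        return loss

    optimizer.step(closure)
    # last_loss holds the most recent value LBFGS evaluated, which is enough for logging.
    return last_loss
```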
15 changes: 10 additions & 5 deletions torchensemble/tests/test_set_optimizer.py
@@ -1,7 +1,6 @@
import pytest
import torchensemble
import torch.nn as nn

import torchensemble

optimizer_list = [
"Adadelta",
@@ -13,6 +12,7 @@
"RMSprop",
"Rprop",
"SGD",
"LBFGS",
]


@@ -33,9 +33,14 @@ def forward(self, X):
@pytest.mark.parametrize("optimizer_name", optimizer_list)
def test_set_optimizer_normal(optimizer_name):
model = MLP()
torchensemble.utils.set_module.set_optimizer(
model, optimizer_name, lr=1e-3
)
if optimizer_name != "LBFGS":
torchensemble.utils.set_module.set_optimizer(
model, optimizer_name, lr=1e-3
)
else:
torchensemble.utils.set_module.set_optimizer(
model, optimizer_name, history_size=7, max_iter=10
)


def test_set_optimizer_Unknown():
1 change: 1 addition & 0 deletions torchensemble/utils/set_module.py
@@ -18,6 +18,7 @@ def set_optimizer(model, optimizer_name, **kwargs):
"RMSprop",
"Rprop",
"SGD",
"LBFGS",
]
if optimizer_name not in torch_optim_optimizers:
msg = "Unrecognized optimizer: {}, should be one of {}."
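The updated test exercises this helper directly; a minimal sketch of that lower-level path, assuming `set_optimizer` returns the matching `torch.optim` instance once the name passes the check above:

```python
# Illustrative use of the utils helper, mirroring the updated test.
import torch.nn as nn
from torchensemble.utils import set_module

net = nn.Linear(10, 2)
# Assumed to return a torch.optim.LBFGS bound to net's parameters.
optimizer = set_module.set_optimizer(net, "LBFGS", history_size=7, max_iter=10)
# Any name outside the list raises: "Unrecognized optimizer: ..., should be one of {...}."
```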
31 changes: 19 additions & 12 deletions torchensemble/voting.py
@@ -5,19 +5,17 @@
"""


import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

import warnings
from joblib import Parallel, delayed

from ._base import BaseClassifier, BaseRegressor
from ._base import torchensemble_model_doc
from ._base import BaseClassifier, BaseRegressor, torchensemble_model_doc
from .utils import io
from .utils import set_module
from .utils import operator as op

from .utils import set_module

__all__ = ["VotingClassifier", "VotingRegressor"]

@@ -49,11 +47,20 @@ def _parallel_fit_per_epoch(
data, target = io.split_data_target(elem, device)
batch_size = data[0].size(0)

optimizer.zero_grad()
def closure():
if torch.is_grad_enabled():
optimizer.zero_grad()
output = estimator(*data)
loss = criterion(output, target)
if loss.requires_grad:
loss.backward()
return loss

optimizer.step(closure)

# Calculate loss for logging
output = estimator(*data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
loss = closure()

# Print training status
if batch_idx % log_interval == 0:
@@ -69,7 +76,7 @@
)
print(
msg.format(
idx, epoch, batch_idx, loss, correct, batch_size
idx, epoch, batch_idx, loss.item(), correct, batch_size
)
)
# Regression
Expand All @@ -78,7 +85,7 @@ def _parallel_fit_per_epoch(
"Estimator: {:03d} | Epoch: {:03d} | Batch: {:03d}"
" | Loss: {:.5f}"
)
print(msg.format(idx, epoch, batch_idx, loss))
print(msg.format(idx, epoch, batch_idx, loss.item()))

return estimator, optimizer
