Metric with multiple input runs in an unexpected way. #2940
lyhyl changed the title from "Metric with multiple input runs in an unexpected way." to "Loss metric with multiple input runs in an unexpected way." on May 9, 2023
@lyhyl thanks for reporting this issue! Right now, a workaround could be to replace the tuple structure with a custom container class:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Loss


class TargetsPair:
    # Plain container (not a tuple/list), so the metric does not split it into per-item pairs
    a: torch.Tensor
    b: torch.Tensor

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __len__(self):
        # Loss computes the batch size with len(y) by default
        return len(self.a)


class MyLoss(nn.Module):
    def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None:
        super().__init__()
        self.ca = ca
        self.cb = cb

    def forward(self, y_pred: TargetsPair, y_true: TargetsPair) -> torch.Tensor:
        a_true, b_true = y_true.a, y_true.b
        a_pred, b_pred = y_pred.a, y_pred.b
        return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true)


def prepare_batch(batch, device, non_blocking):
    # Dummy data: x, (a_true, b_true)
    return torch.rand(4, 1), (torch.rand(4, 1), torch.rand(4, 2))


class MyModel(nn.Module):
    def forward(self, x):
        # Dummy two-headed output: (a_pred, b_pred)
        return torch.rand(4, 1), torch.rand(4, 2)


model = MyModel()


def output_transform(output):
    # The evaluator output is (y_pred, y); wrap each tuple in a TargetsPair
    (a_pred, b_pred), (a_true, b_true) = output
    return TargetsPair(a_pred, b_pred), TargetsPair(a_true, b_true)


device = "cpu"
loss = MyLoss(0.5, 1.0)
metrics = {
    "Loss": Loss(loss, output_transform=output_transform)
}
train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch)

data = range(10)
train_evaluator.run(data)
print(train_evaluator.state.metrics["Loss"])
```

In the future, we may introduce a flag into …
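Why the wrapper helps, sketched under the assumption that the splitting check only fires for `collections.abc.Sequence` containers of tensors: `TargetsPair` is not a `Sequence`, so the whole pair reaches `Loss.update`, and its `__len__` supplies the batch size (`Loss` uses `len` as its default `batch_size` callable). `PairContainer` and `PairTuple` below are hypothetical names, and plain lists stand in for tensors. A `dataclass` keeps both properties; a `NamedTuple` does not, because it is a tuple subclass and would be split again.

```python
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, NamedTuple


@dataclass
class PairContainer:
    """Hypothetical dataclass alternative to TargetsPair."""
    a: Any
    b: Any

    def __len__(self):
        # Report the batch dimension, since Loss calls len(y) by default
        return len(self.a)


class PairTuple(NamedTuple):
    """Hypothetical NamedTuple version: still a tuple, hence still a Sequence."""
    a: Any
    b: Any


# Plain lists stand in for batched tensors holding 4 samples
a, b = [0.1, 0.2, 0.3, 0.4], [1, 0, 1, 0]

print(isinstance(PairContainer(a, b), Sequence))  # False: passed through as a whole
print(len(PairContainer(a, b)))                   # 4: batch size seen by Loss
print(isinstance(PairTuple(a, b), Sequence))      # True: would be split per item
print(len(PairTuple(a, b)))                       # 2: number of fields, not samples
```

If you would rather not define `__len__`, `Loss` also accepts a `batch_size` callable (for example `lambda y: len(y.a)`), which should serve the same purpose.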
❓ Questions/Help/Support
My customized loss requires two pairs of inputs:
When I try to log the loss with the `Loss` metric, it crashes at ignite/ignite/metrics/metric.py, line 308 (commit 4825bb6), because the metric treats the inputs as independent pairs of `y_pred` and `y`, which is not what `MyLoss` needs. Digging into the source code, I found that #2055 introduced a new feature that causes this issue.
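To make the splitting concrete, here is a torch-free simulation. `FakeTensor`, `is_list_of_tensors_or_numbers`, and `dispatch` are hypothetical stand-ins written for illustration: they approximate the per-pair splitting that appears to live in `Metric.iteration_completed`, but omit details such as length validation and tensor conversion, so treat this as a sketch rather than ignite's actual code.

```python
from collections.abc import Sequence
from numbers import Number


class FakeTensor:
    """Stand-in for torch.Tensor: like a real tensor, it is not a Sequence."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


def is_list_of_tensors_or_numbers(x):
    # Hypothetical mirror of the check: a Sequence whose items are all tensors/numbers
    return isinstance(x, Sequence) and all(isinstance(t, (FakeTensor, Number)) for t in x)


def dispatch(output, update):
    """Simplified sketch of the per-pair splitting (not ignite's code)."""
    if isinstance(output, Sequence) and all(is_list_of_tensors_or_numbers(o) for o in output):
        y_preds, ys = output
        # Each (prediction, target) pair reaches update() on its own
        for y_pred, y in zip(y_preds, ys):
            update((y_pred, y))
    else:
        # Anything else is forwarded unchanged
        update(output)


a_pred, b_pred, a_true, b_true = (FakeTensor(n) for n in ("a_pred", "b_pred", "a_true", "b_true"))

calls = []
dispatch(((a_pred, b_pred), (a_true, b_true)), calls.append)
print(calls)  # [(a_pred, a_true), (b_pred, b_true)]
```

With two tuples of tensors, `update()` receives `(a_pred, a_true)` and `(b_pred, b_true)` separately, so a loss that expects both heads at once never sees them together.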
So, what are the best practices for dealing with losses that take multiple inputs?