
CORAL loss is defined differently from the original paper #17

Open

DenisDsh opened this issue Jul 15, 2018 · 9 comments
@DenisDsh

I noticed that both the covariance and the Frobenius norm are computed differently in your implementation.

You compute the Frobenius norm as below:
# frobenius norm between source and target
loss = torch.mean(torch.mul((xc - xct), (xc - xct)))
However, as stated at http://mathworld.wolfram.com/FrobeniusNorm.html, after squaring each element and summing, one should take the square root of the sum, not the mean of the squared elements.
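For illustration, a minimal sketch of the distinction (X here is just an arbitrary matrix, not from the repo):

import torch

X = torch.randn(5, 5)

# Frobenius norm: square root of the sum of squared elements
fro = torch.sqrt(torch.sum(X * X))
assert torch.allclose(fro, torch.norm(X, p='fro'))

# what the repo computes instead: the mean of the squared elements,
# i.e. the squared norm divided by the element count
mean_sq = torch.mean(X * X)
assert torch.allclose(mean_sq, fro ** 2 / X.numel())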

In the original paper, the covariances are computed as follows:
https://arxiv.org/abs/1607.01719

[Screenshot of the covariance definition from the paper: $$C_S = \frac{1}{n_S - 1}\Big(D_S^\top D_S - \frac{1}{n_S}(\mathbf{1}^\top D_S)^\top(\mathbf{1}^\top D_S)\Big)$$ and likewise for $C_T$]

While in your implementation:

# source covariance 
xm = torch.mean(source, 0, keepdim=True) - source
xc = xm.t() @ xm

# target covariance
xmt = torch.mean(target, 0, keepdim=True) - target
xct = xmt.t() @ xmt  
@yaox12

yaox12 commented Jul 25, 2018

I agree with you. @SSARCandy

This is my implementation; any advice?

def coral_loss(source, target):
    d = source.size(1)
    ns, nt = source.size(0), target.size(0)

    # source covariance
    tmp_s = torch.ones((1, ns)) @ source
    cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
    
    # target covariance
    tmp_t = torch.ones((1, nt)) @ target
    ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)

    # frobenius norm
    loss = (cs - ct).pow(2).sum().sqrt()
    loss = loss / (4 * d * d)

    return loss
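For reference, a hypothetical call on CPU with random features (the shapes are arbitrary, not from the repo):

import torch

source = torch.randn(32, 64)  # (ns, d) source activations
target = torch.randn(48, 64)  # (nt, d) target activations
print(coral_loss(source, target))  # a small positive scalar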

@redhat12345

@yaox12 I tried to run your code but got the following error.

Traceback (most recent call last):
  File "DeepCoral.py", line 117, in <module>
    train(epoch, model)
  File "DeepCoral.py", line 77, in train
    label_source_pred, loss_coral = model(data_source, data_target)
  File "/home/user/pytorch_python3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/user/DATA/DA_pytorch/transferlearning/code/deep/DeepCoral/ResNet.py", line 161, in forward
    loss += CORAL(source, target)
  File "/media/user/DATA/DA_pytorch/transferlearning/code/deep/DeepCoral/Coral.py", line 43, in CORAL
    tmp_s = torch.ones((1, ns)) @ source
TypeError: unsupported operand type(s) for @: 'torch.FloatTensor' and 'Variable'
[6]+ Killed python DeepCoral.py

@yaox12

yaox12 commented Sep 16, 2018

@redhat12345 My code is based on PyTorch >= 0.4, in which torch.Tensor and Variable are merged.
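To make that concrete, a minimal illustration (assuming PyTorch >= 0.4): a plain tensor can carry gradients directly, with no Variable wrapper.

import torch

x = torch.ones(3, requires_grad=True)  # plain tensor, no Variable needed
y = (x * x).sum()
y.backward()
print(x.grad)  # tensor([2., 2., 2.])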

@redhat12345

@yaox12 Even with PyTorch 0.4 I still get this error:

Traceback (most recent call last):
  File "DeepCoral.py", line 147, in <module>
    train(epoch, model)
  File "DeepCoral.py", line 85, in train
    label_source_pred, loss_coral = model(data_source, data_target)
  File "/home/user/pytorch4_python3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/user/DATA/DA_pytorch/transferlearning/code/deep/DeepCoral/ResNet.py", line 161, in forward
    loss += CORAL(source, target)
  File "/media/user/DATA/DA_pytorch/transferlearning/code/deep/DeepCoral/Coral.py", line 63, in CORAL
    tmp_s = torch.ones((1, ns)) @ source
RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #2 'mat2'

@yaox12

yaox12 commented Sep 26, 2018

@redhat12345 If source and target are CUDA tensors, then torch.ones((1, ns)) should be torch.ones((1, ns)).cuda(), and likewise for nt.
I have tried training with this loss and found it often produces NaN. I have no idea why.
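A device-agnostic variant (a sketch using the device keyword, so nothing is hard-coded to CUDA):

tmp_s = torch.ones((1, ns), device=source.device) @ source
tmp_t = torch.ones((1, nt), device=target.device) @ target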

@mrsempress

mrsempress commented Dec 7, 2019

@yaox12, I agree with you. But I think line 14 (the Frobenius norm line) should be:
loss = (cs - ct).pow(2).sum()
because the paper defines
$$\ell_{CORAL}=\frac{1}{4d^2}\|C_S-C_T\|^2_F$$
and the Frobenius norm is
$$\|A\|_F=\sqrt{\sum_{i=1}^m\sum_{j=1}^n |a_{ij}|^2}$$
so
$$\|C_S-C_T\|^2_F$$
should not have the sqrt().
And I think the author's code is also right.

@typhoon1104

Why do you think the author's code is also right?

@yangguangan

In my opinion, the main problem is the calculation of the covariance: in the paper, the covariance is obtained by dividing by (n - 1), but in the code it is effectively divided by n, via torch.mean(torch.mul((xc - xct), (xc - xct))). However, I'm actually not sure which one is the right one.
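To see the two conventions side by side (a sketch; the torch.cov check assumes PyTorch >= 1.10):

import torch

x = torch.randn(32, 8)            # (n, d) features
xm = x - x.mean(0, keepdim=True)  # centered

cov_unbiased = xm.t() @ xm / (x.size(0) - 1)  # paper convention: divide by n - 1
cov_biased = xm.t() @ xm / x.size(0)          # divide-by-n convention

# torch.cov expects variables in rows and divides by n - 1 by default
assert torch.allclose(cov_unbiased, torch.cov(x.t()), atol=1e-6)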

@ch-andrei

ch-andrei commented Jun 2, 2022

TL;DR: no error; this code is "correct", but the magnitude of the loss is not scaled correctly.

1. Deep CORAL uses the squared Frobenius norm, so the sqrt is not necessary. However, the original would use torch.sum rather than torch.mean, so loss / (4 * d * d) should actually be simply loss / 4 (computing the mean already divides by d * d).
2. If you plot the values produced by this code against the original method from the paper, you get the same trends scaled differently, i.e., this code changes the magnitude of the CORAL loss by a ratio of d*d / (b-1)**2, for batch size b and feature dimensionality d.
def coral_loss(source, target):
    # source covariance
    xs = torch.mean(source, 0, keepdim=True) - source
    xs = xs.t() @ xs

    # target covariance
    xt = torch.mean(target, 0, keepdim=True) - target
    xt = xt.t() @ xt

    # mean of squared element-wise differences (squared Frobenius norm / (d * d))
    loss = torch.mean(torch.mul(xs - xt, xs - xt))

    # note: b batch dim, d is feature dim
    # original deep coral implementation differs from the above by a ratio of (d * d / (b-1) / (b-1))
    # loss / (4 * d * d) * (d * d / (b-1) / (b-1)) simplifies to
    b = source.shape[0] - 1  # batch dim
    return loss / (4 * b * b)
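As a numerical sanity check of the claimed equivalence (a sketch; coral_loss_paper below is yaox12's formulation with the sqrt removed, and the batch sizes are kept equal because the rescaling above uses only the source batch dimension):

import torch

def coral_loss_paper(source, target):
    d = source.size(1)
    ns, nt = source.size(0), target.size(0)
    tmp_s = torch.ones((1, ns)) @ source
    cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
    tmp_t = torch.ones((1, nt)) @ target
    ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)
    return (cs - ct).pow(2).sum() / (4 * d * d)

source, target = torch.randn(32, 16), torch.randn(32, 16)
print(coral_loss(source, target) / coral_loss_paper(source, target))  # ≈ 1.0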
