error: NotImplementedError: Could not run 'aten::_amp_foreach_non_finite_check_and_unscale_' with arguments from the 'CPU' backend. #3

Open
daixiangzi opened this issue Dec 5, 2023 · 24 comments

Comments

@daixiangzi

When I use:

import torch
from torch.cuda.amp import GradScaler

amp = GradScaler(init_scale=512, growth_interval=100)
for img, label in train_iter:
    with torch.cuda.amp.autocast(True):
        label = backbone(img)
    # (the loss computation is omitted in this snippet)
    amp.scale(loss).backward()
    amp.unscale_(opt)

@lucidrains
Owner

lucidrains commented Dec 5, 2023

@daixiangzi that looks like an unrelated error that is specific to cpu with the amp.scale? could you retry on cuda?

@daixiangzi
Author

daixiangzi commented Dec 6, 2023

@daixiangzi that looks like an unrelated error that is specific to cpu with the amp.scale? could you retry on cuda?

When I use a single machine with multiple GPUs, I can run your code, but when I use multiple machines with multiple GPUs, the error appears.

@daixiangzi
Author

@daixiangzi that looks like an unrelated error that is specific to cpu with the amp.scale? could you retry on cuda?
I have always used it on cuda

@lucidrains
Owner

are you using fsdp?

@daixiangzi
Author

daixiangzi commented Dec 7, 2023

fsdp

No, I use DDP.

@lucidrains
Owner

ah, hard for me to debug. i only have a single machine

@lucidrains
Owner

@daixiangzi could you try disabling mixed precision just for the MoE block?

with torch.cuda.amp.autocast(enabled = False):
    # ... your moe forward
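
A small sketch of what disabling autocast for one block looks like in practice. It runs on CPU with `torch.autocast(device_type="cpu", ...)` as a stand-in for `torch.cuda.amp.autocast`, and the `nn.Linear` layer is a hypothetical placeholder for the MoE block. Tensors produced in the surrounding autocast region may already be in reduced precision, so this sketch casts them back to float32 before the full-precision block.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(8, 8)  # hypothetical stand-in for the MoE block
x = torch.randn(2, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Inside the autocast region, the linear layer runs in reduced precision.
    h = layer(x)

    with torch.autocast(device_type="cpu", enabled=False):
        # Autocast is off here; cast the incoming tensor explicitly so the
        # block sees float32 inputs that match its float32 weights.
        y = layer(h.float())

print(h.dtype, y.dtype)
```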

@lucidrains
Owner

finally seeing the limits of pytorch

in jax, properly working moe is just a few lines of code

@daixiangzi
Author

@daixiangzi could you try disabling mixed precision just for the MoE block?

with torch.cuda.amp.autocast(enabled = False):
    # ... your moe forward

ok

@daixiangzi
Author

@daixiangzi could you try disabling mixed precision just for the MoE block?

with torch.cuda.amp.autocast(enabled = False):
    # ... your moe forward

I tried it just now, but it did not work.

@daixiangzi
Author

daixiangzi commented Jan 4, 2024

Define the MoE layer:

self.softmoe = SoftMoE(dim=512, num_experts=64, is_distributed=True, offload_unused_experts_to_cpu=True)

Forward:

with torch.cuda.amp.autocast(enabled=False):
    x = x + self.softmoe(self.ln_3(x).permute(1, 0, 2)).permute(1, 0, 2)

@daixiangzi
Author

NotImplementedError: Could not run 'aten::_amp_foreach_non_finite_check_and_unscale_' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process
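
One way to narrow this down, offered as a diagnostic sketch rather than a confirmed cause: the message says the unscale op was asked to run on CPU tensors, which would happen if some gradients live on the CPU when `amp.unscale_(opt)` is called. Whether `offload_unused_experts_to_cpu=True` is what puts them there is an assumption this thread does not confirm. The helper name `find_cpu_grads` is hypothetical, and the demo below uses a tiny CPU-only model purely to show the check.

```python
import torch
import torch.nn as nn


def find_cpu_grads(model: nn.Module):
    """Return the names of parameters whose gradient tensor lives on the CPU."""
    return [
        name
        for name, param in model.named_parameters()
        if param.grad is not None and param.grad.device.type == "cpu"
    ]


# Tiny CPU-only demo: every gradient of this model is on the CPU.
model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()

cpu_grad_names = find_cpu_grads(model)
print(cpu_grad_names)
```

In a real training loop the call would go right after `backward()` and before `unscale_`, on each rank, to see whether any parameter names are reported.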

@daixiangzi
Author

In fact, I trained the DP version on a single machine with multiple GPUs, but the result is still bad. I have a little doubt about the following.

I looked at the original repo: https://github.com/google-research/vmoe/blob/dfb9ee01ce6dfc5a8228b406e768ee325dd18fcd/vmoe/nn/vit_moe.py#L109C27-L109C45
Your code:

x = self.norm(x)
logits = einsum('b n d, e s d -> b n e s', x, slot_embeds)
...
slots = einsum('b n d, b n e s -> b e s d', x, dispatch_weights)

The original repo's code seems to be:

norm_x = self.norm(x)
logits = einsum('b n d, e s d -> b n e s', norm_x, slot_embeds)
....
slots = einsum('b n d, b n e s -> b e s d', x, dispatch_weights)
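
The difference between the two variants above can be sketched as follows. This is a minimal illustration, not the code of either repository: the function name `soft_moe_slots`, the `dispatch_on_normed` flag, and the tensor sizes are all hypothetical, and `nn.LayerNorm` stands in for whichever norm the block uses. Both variants compute the logits from the normalized input; they differ only in which tensor is mixed into the slots.

```python
import torch
import torch.nn as nn


def soft_moe_slots(x, slot_embeds, norm, dispatch_on_normed):
    """Compute Soft MoE slots, dispatching either normalized or raw tokens."""
    norm_x = norm(x)
    # logits: [batch, tokens, experts, slots]
    logits = torch.einsum("bnd,esd->bnes", norm_x, slot_embeds)
    # Dispatch weights: softmax over the token axis.
    dispatch_weights = logits.softmax(dim=1)
    # The two variants differ only in the tensor that gets dispatched.
    source = norm_x if dispatch_on_normed else x
    return torch.einsum("bnd,bnes->besd", source, dispatch_weights)


torch.manual_seed(0)
x = torch.randn(2, 5, 8) * 3 + 1      # deliberately not zero-mean / unit-variance
slot_embeds = torch.randn(4, 1, 8)    # 4 experts, 1 slot each
norm = nn.LayerNorm(8)

slots_normed = soft_moe_slots(x, slot_embeds, norm, dispatch_on_normed=True)
slots_raw = soft_moe_slots(x, slot_embeds, norm, dispatch_on_normed=False)
print(slots_normed.shape, (slots_normed - slots_raw).abs().max())
```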

@lucidrains
Owner

@daixiangzi hmm, have you tried it their way? conventionally we always work with pre-normalized values, so that would be weird they dispatch based on unnormalized input into the block

@daixiangzi
Author

@daixiangzi hmm, have you tried it their way? conventionally we always work with pre-normalized values, so that would be weird they dispatch based on unnormalized input into the block

No, I will try it. But I still have some doubts about the effectiveness of SoftMoE, mainly because I have been tuning it for a long time.

@daixiangzi
Author

For vision tasks, there are currently many open-source MoE implementations, but in fact there are not enough experiments to demonstrate the effectiveness of this approach.

@lucidrains
Owner

@daixiangzi yea it happens

thanks for sharing your results

just to make sure we don't miss anything, would you like to also do a quick run where i substituted the rmsnorm with layernorm? (set https://github.com/lucidrains/soft-moe-pytorch/blob/main/soft_moe_pytorch/soft_moe.py#L287 to True)

@daixiangzi
Author

@daixiangzi yea it happens

thanks for sharing your results

just to make sure we don't miss anything, would you like to also do a quick run where i substituted the rmsnorm with layernorm? (set https://github.com/lucidrains/soft-moe-pytorch/blob/main/soft_moe_pytorch/soft_moe.py#L287 to True)

Actually, before this, I tried my own version where the norm was LayerNorm, but the results were still not as good as the baseline.

@lucidrains
Owner

@daixiangzi ah ok, good to know

i'll take your experience as a datapoint

@daixiangzi
Author

My own SoftMoE version:

        #x: [b,50,768]
        #phi [768, 128, slots]
        phi = self.phi
        phi = self.scale*self.normalize(phi,axis=0)
        #logits:[b, 50, 128, 1]
        
        logits = einsum(self.normalize(x,axis=-1), phi, "b m d, d n p -> b m n p")
        #add noise
        if self.training:
            normal = torch.randn(logits.shape, dtype=logits.dtype,device=logits.device)
            logits = logits +(1.0 * normal)
        #dispatch_weights:[b, token_num, 128, 1])
        dispatch_weights = logits.softmax(dim=1)  # denoted 'D' in the paper
        # NOTE: The 'torch.softmax' function does not support multiple values for the
        # 'dim' argument (unlike jax), so we are forced to flatten the last two dimensions.
        # Then, we rearrange the Tensor into its original shape.
        # softmax over experts: combine_weights: [b, token_num, 128, 1]
        combine_weights = rearrange(logits.flatten(start_dim=2).softmax(dim=-1),"b m (n p) -> b m n p",n=self.num_experts)

        # NOTE: To save memory, I don't rename the intermediate tensors Y, Ys, Xs.
        # Instead, I just overwrite the 'x' variable.  The names from the paper are
        # included in a comment for each line below.
        
        x = einsum(x, dispatch_weights, "b m d, b m n p -> b n p d")  # Xs
        #x:[b,128, 1, out_d]
        x = self.experts(x)  # Ys
        #combine_weights:[b, 50, 128, 1]
        x = einsum(x, combine_weights, "b n p d, b m n p -> b m d")  # Y
        #[b, 128, 1, 384]
        return x
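
The dispatch and combine steps in the snippet above can be checked in isolation with a small standalone sketch. It uses `torch.einsum` instead of the einops-style `einsum` in the snippet, picks small hypothetical sizes, skips the noise and normalization, and replaces the experts with the identity, so it only illustrates the shapes and the two softmax axes.

```python
import torch

torch.manual_seed(0)
b, m, d = 2, 50, 16   # batch, tokens, feature dim (hypothetical sizes)
n, p = 8, 1           # experts, slots per expert

x = torch.randn(b, m, d)
phi = torch.randn(d, n, p)

# logits: [b, m, n, p]
logits = torch.einsum("bmd,dnp->bmnp", x, phi)

# Dispatch weights ('D' in the paper): softmax over the token axis.
dispatch_weights = logits.softmax(dim=1)

# Combine weights: softmax over all (expert, slot) pairs, done by flattening
# the last two axes and restoring the original shape afterwards.
combine_weights = logits.flatten(start_dim=2).softmax(dim=-1).reshape(b, m, n, p)

slots = torch.einsum("bmd,bmnp->bnpd", x, dispatch_weights)       # Xs
expert_out = slots                                                # identity experts (assumption)
y = torch.einsum("bnpd,bmnp->bmd", expert_out, combine_weights)   # Y

print(slots.shape, y.shape)
```

Each slot is a convex combination of tokens (dispatch weights sum to 1 over tokens), and each output token is a convex combination of slot outputs (combine weights sum to 1 over experts and slots).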

@daixiangzi
Author

In the SoftMoE paper, self.normalize should be an L2 norm, for example:

def normalize(self, x: torch.Tensor, axis: int):
    return x * torch.rsqrt(torch.square(x).sum(dim=axis, keepdim=True) + 1e-6)
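
As a quick sanity check, the same formula can be written as a free function and compared against `torch.nn.functional.normalize`. The function name `l2_normalize` and the test tensor are hypothetical. The two differ slightly in how the epsilon is applied, so the comparison uses a tolerance.

```python
import torch
import torch.nn.functional as F


def l2_normalize(x: torch.Tensor, axis: int, eps: float = 1e-6) -> torch.Tensor:
    """Scale x to (approximately) unit L2 norm along the given axis."""
    return x * torch.rsqrt(torch.square(x).sum(dim=axis, keepdim=True) + eps)


torch.manual_seed(0)
x = torch.randn(4, 16)

out = l2_normalize(x, axis=-1)
ref = F.normalize(x, p=2, dim=-1)

# Largest elementwise difference between the two implementations.
print((out - ref).abs().max())
```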

@lucidrains
Owner

in softmoe paper: self.normalize should l2_norm, for example:

 def normalize(self,x:torch.Tensor,axis:int):
        return x*torch.rsqrt(torch.square(x).sum(dim=axis, keepdim=True) + 1e-6)

ohh i did it correct then the first time

@daixiangzi
Author

in softmoe paper: self.normalize should l2_norm, for example:

 def normalize(self,x:torch.Tensor,axis:int):
        return x*torch.rsqrt(torch.square(x).sum(dim=axis, keepdim=True) + 1e-6)

ohh i did it correct then the first time

ok

@xiaoshuomin

NotImplementedError: Could not run 'aten::_amp_foreach_non_finite_check_and_unscale_' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process

Hi, I met the same problem as you. Have you solved this problem?
