
dots + relative position bias: dimension mismatch in Transformer block #20

Open
Uljibuh opened this issue Mar 14, 2024 · 0 comments


Uljibuh commented Mar 14, 2024

I tried to run CoAtNet without modifying anything:

```python
import torch

def coatnet_0():
    num_blocks = [1, 1, 1, 1, 1]          # L: number of blocks per stage
    channels = [64, 96, 192, 384, 768]    # D: channels per stage
    return CoAtNet((50, 50), 3, num_blocks, channels, num_classes=3)

img = torch.randn(1, 3, 50, 50)

net = coatnet_0()
out = net(img)
print(out.shape)
```

The error seems to be that `dots` is computed from the downsampled feature map, while the relative position bias is calculated from the original image size. I don't think the attention mechanism itself is miscalculating `dots` or the relative position bias.
I think the issue is in how attention is implemented in the Transformer block: somehow the attention there does not use the downsampled size for the relative position bias, and calculates it from the original image size instead.
How can I solve this issue?
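The size mismatch can be reproduced with plain shape arithmetic. The sketch below rests on two assumptions that the issue itself does not confirm: that every stage halves the feature map with a layer behaving like `nn.MaxPool2d(3, 2, 1)` (the `pool1`/`pool2` seen in the traceback), and that each stage is told its size as `image_size // 2**k`, as in the upstream coatnet-pytorch code. `downsample` and `stage_sizes` are hypothetical helpers.

```python
def downsample(n, kernel=3, stride=2, padding=1):
    """Output length of Conv2d/MaxPool2d along one axis (dilation=1, ceil_mode=False)."""
    return (n + 2 * padding - kernel) // stride + 1


def stage_sizes(n, num_stages=5):
    """Return (actual, floor_assumed, ceil_assumed) side lengths for each stage.

    actual        -- what the stride-2 layers really produce
    floor_assumed -- size passed to the stage as n // 2**k (assumed, as upstream)
    ceil_assumed  -- hypothetical alternative using ceiling division
    """
    actual, floor_assumed, ceil_assumed = [], [], []
    size = n
    for k in range(1, num_stages + 1):
        size = downsample(size)
        actual.append(size)
        floor_assumed.append(n // 2 ** k)
        ceil_assumed.append(-(-n // 2 ** k))  # ceil(n / 2**k) without floats
    return actual, floor_assumed, ceil_assumed


actual, floor_assumed, ceil_assumed = stage_sizes(50)
print(actual)         # [25, 13, 7, 4, 2]
print(floor_assumed)  # [25, 12, 6, 3, 1]
print(ceil_assumed)   # [25, 13, 7, 4, 2]
```

Under these assumptions, the fourth stage (`s3`, the one in the traceback) receives a 4×4 map (16 tokens) from a 50×50 input, while a bias built for `50 // 16 = 3` covers only 3×3 = 9 tokens, which matches the `16` vs `9` in the error. For inputs divisible by 32, such as 224, the lists agree. Ceiling division is one possible way to keep them in sync; it is a suggestion, not something the repository does.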


```
RuntimeError                              Traceback (most recent call last)
Cell In[15], line 1
----> 1 out = net(img)
      2 print(out.shape)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[10], line 25, in CoAtNet.forward(self, x)
     23 x = self.s1(x)
     24 x = self.s2(x)
---> 25 x = self.s3(x)
     26 x = self.s4(x)
     28 x = self.pool(x).view(-1, x.shape[1])

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
    213 def forward(self, input):
    214     for module in self:
--> 215         input = module(input)
    216     return input

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[9], line 31, in Transformer.forward(self, x)
     29 def forward(self, x):
     30     if self.downsample:
---> 31         x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
     32     else:
     33         x = x + self.attn(x)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
    213 def forward(self, input):
    214     for module in self:
--> 215         input = module(input)
    216     return input

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[4], line 8, in PreNorm.forward(self, x, **kwargs)
      7 def forward(self, x, **kwargs):
----> 8     return self.fn(self.norm(x), **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[8], line 47, in Attention.forward(self, x)
     43 relative_bias = self.relative_bias_table.gather(
     44     0, self.relative_index.repeat(1, self.heads))
     45 relative_bias = rearrange(
     46     relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
---> 47 dots = dots + relative_bias
     49 attn = self.attend(dots)
     50 out = torch.matmul(attn, v)

RuntimeError: The size of tensor a (16) must match the size of tensor b (9) at non-singleton dimension 3
```
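The failing line adds a bias of shape `(1, heads, N, N)`, where N is the token count of the window the bias table was built for, to `dots` of shape `(1, heads, M, M)`, where M is the token count actually present. The pure-Python sketch below rebuilds a relative-position index (assuming the usual Swin-style construction, which the `relative_bias_table.gather(...)` call suggests but the issue does not show) and checks the two shapes against the NumPy/PyTorch broadcasting rule. `relative_index`, `broadcastable`, and `heads = 8` are hypothetical illustrations.

```python
from itertools import product


def relative_index(ih, iw):
    """Index into a (2*ih - 1) * (2*iw - 1) bias table for every (query, key) token pair.

    Tokens are flattened row-major, matching a 'b c ih iw -> b (ih iw) c' rearrange.
    Returns the N x N index matrix (N = ih * iw) and the table size.
    """
    coords = list(product(range(ih), range(iw)))  # [(row, col), ...] in row-major order
    table_size = (2 * ih - 1) * (2 * iw - 1)
    index = []
    for qy, qx in coords:
        row = []
        for ky, kx in coords:
            dy = qy - ky + ih - 1  # shift row offset into [0, 2*ih - 2]
            dx = qx - kx + iw - 1  # shift column offset into [0, 2*iw - 2]
            row.append(dy * (2 * iw - 1) + dx)
        index.append(row)
    return index, table_size


def broadcastable(shape_a, shape_b):
    """True if two shapes broadcast: each trailing dim pair must match or contain a 1."""
    for a, b in zip(reversed(shape_a), reversed(shape_b)):
        if a != b and a != 1 and b != 1:
            return False
    return True


heads = 8  # hypothetical head count; any value gives the same result
bias, _ = relative_index(3, 3)                 # bias built for a 3x3 window (9 tokens)
bias_shape = (1, heads, len(bias), len(bias))  # (1, 8, 9, 9)
dots_shape = (1, heads, 4 * 4, 4 * 4)          # a 4x4 feature map has 16 tokens
print(broadcastable(dots_shape, bias_shape))   # False: 16 vs 9 in the last two dims
```

Building the index for the window the attention layer actually receives (4×4 here) would give a bias of shape `(1, heads, 16, 16)`, which adds cleanly to `dots`.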
