diff --git a/README.md b/README.md index 648f25b..eef9203 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,18 @@ if __name__ == '__main__': - [37. Axial_attention Attention Usage](#37-Axial_attention-Attention-Usage) + - [38. Frequency Channel Attention Usage](#38-Frequency-Channel-Attention-Usage) + + - [39. Attention Augmented Convolutional Networks Usage](#39-Attention-Augmented-Convolutional-Networks-Usage) + + - [40. Global Context Attention Usage](#40-Global-Context-Attention-Usage) + + - [41. Linear Context Transform Attention Usage](#41-Linear-Context-Transform-Attention-Usage) + + - [42. Gated Channel Transformation Usage](#42-Gated-Channel-Transformation-Usage) + + - [43. Gaussian Context Attention Usage](#43-Gaussian-Context-Attention-Usage) + - [Backbone Series](#Backbone-series) - [1. ResNet Usage](#1-ResNet-Usage) @@ -427,10 +439,10 @@ print(output.shape) ### 3. Simplified Self Attention Usage #### 3.1. Paper -[None]() +[SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks (ICML 2021)](https://proceedings.mlr.press/v139/yang21o/yang21o.pdf) #### 3.2. Overview -![](./model/img/SSA.png) +![](./model/img/SimAttention.png) #### 3.3. Usage Code ```python @@ -1184,7 +1196,7 @@ if __name__ == '__main__': ``` -- +*** ### 31. ACmix Attention Usage @@ -1205,6 +1217,7 @@ if __name__ == '__main__': print(output.shape) ``` +*** ### 32. MobileViTv2 Attention Usage @@ -1232,6 +1245,7 @@ if __name__ == '__main__': print(output.shape) ``` +*** ### 33. DAT Attention Usage @@ -1276,6 +1290,7 @@ if __name__ == '__main__': print(output[0].shape) ``` +*** ### 34. CrossFormer Attention Usage @@ -1313,6 +1328,7 @@ if __name__ == '__main__': print(output.shape) ``` +*** ### 35. MOATransformer Attention Usage @@ -1350,6 +1366,7 @@ if __name__ == '__main__': print(output.shape) ``` +*** ### 36. CrissCrossAttention Attention Usage @@ -1370,6 +1387,7 @@ if __name__ == '__main__': print(outputs.shape) ``` +*** ### 37. Axial_attention Attention Usage @@ -1393,6 +1411,158 @@ if __name__ == '__main__': outputs = model(input) print(outputs.shape) +``` +*** + +### 38. Frequency Channel Attention Usage + +#### 38.1. Paper + +[FcaNet: Frequency Channel Attention Networks (ICCV 2021)](https://arxiv.org/abs/2012.11879) + +#### 38.2. Overview + +![](./model/img/FCANet.png) + +#### 38.3. Usage Code + +```python +from model.attention.FCA import MultiSpectralAttentionLayer +import torch + +if __name__ == "__main__": + input = torch.randn(32, 128, 64, 64) # (b, c, h, w) + fca_layer = MultiSpectralAttentionLayer(channel = 128, dct_h = 64, dct_w = 64, reduction = 16, freq_sel_method = 'top16') + output = fca_layer(input) + print(output.shape) + +``` +*** + +### 39. Attention Augmented Convolutional Networks Usage + +#### 39.1. Paper + +[Attention Augmented Convolutional Networks (ICCV 2019)](https://arxiv.org/abs/1904.09925) + +#### 39.2. Overview + +![](./model/img/AAAttention.png) + +#### 39.3. Usage Code + +```python +from model.attention.AAAttention import AugmentedConv +import torch + +if __name__ == "__main__": + input = torch.randn((16, 3, 32, 32)) + augmented_conv = AugmentedConv(in_channels=3, out_channels=64, kernel_size=3, dk=40, dv=4, Nh=4, relative=True, stride=2, shape=16) + output = augmented_conv(input) + print(output.shape) + +``` +*** + +### 40. Global Context Attention Usage + +#### 40.1. 
Paper + +[GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond (ICCVW 2019 Best Paper)](https://arxiv.org/abs/1904.11492) + +[Global Context Networks (TPAMI 2020)](https://arxiv.org/abs/2012.13375) + +#### 40.2. Overview + +![](./model/img/GCNet.png) + +#### 40.3. Usage Code + +```python +from model.attention.GCAttention import GCModule +import torch + +if __name__ == "__main__": + input = torch.randn(16, 64, 32, 32) + gc_layer = GCModule(64) + output = gc_layer(input) + print(output.shape) + +``` +*** + +### 41. Linear Context Transform Attention Usage + +#### 41.1. Paper + +[Linear Context Transform Block (AAAI 2020)](https://arxiv.org/pdf/1909.03834v2) + +#### 41.2. Overview + +![](./model/img/LCTAttention.png) + +#### 41.3. Usage Code + +```python +from model.attention.LCTAttention import LCT +import torch + +if __name__ == "__main__": + x = torch.randn(16, 64, 32, 32) + attn = LCT(64, 8) + y = attn(x) + print(y.shape) + +``` +*** + +### 42. Gated Channel Transformation Usage + +#### 42.1. Paper + +[Gated Channel Transformation for Visual Recognition (CVPR 2020)](https://openaccess.thecvf.com/content_CVPR_2020/papers/Yang_Gated_Channel_Transformation_for_Visual_Recognition_CVPR_2020_paper.pdf) + +#### 42.2. Overview + +![](./model/img/GCT.png) + +#### 42.3. Usage Code + +```python +from model.attention.GCTAttention import GCT +import torch + +if __name__ == "__main__": + input = torch.randn(16, 64, 32, 32) + gct_layer = GCT(64) + output = gct_layer(input) + print(output.shape) + +``` +*** + +### 43. Gaussian Context Attention Usage + +#### 43.1. Paper + +[Gaussian Context Transformer (CVPR 2021)](https://openaccess.thecvf.com//content/CVPR2021/papers/Ruan_Gaussian_Context_Transformer_CVPR_2021_paper.pdf) + +#### 43.2. Overview + +![](./model/img/GaussianCA.png) + +#### 43.3. Usage Code + +```python +from model.attention.GaussianAttention import GCA +import torch + +if __name__ == "__main__": + input = torch.randn(16, 64, 32, 32) + gca_layer = GCA(64) + output = gca_layer(input) + print(output.shape) + ``` *** diff --git a/model/__pycache__/__init__.cpython-38.pyc b/model/__pycache__/__init__.cpython-38.pyc index f85fe00..cfe1928 100644 Binary files a/model/__pycache__/__init__.cpython-38.pyc and b/model/__pycache__/__init__.cpython-38.pyc differ diff --git a/model/attention/AAAttention.py b/model/attention/AAAttention.py new file mode 100644 index 0000000..d771044 --- /dev/null +++ b/model/attention/AAAttention.py @@ -0,0 +1,137 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AugmentedConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dk, dv, Nh, shape=0, relative=False, stride=1): + super(AugmentedConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.dk = dk + self.dv = dv + self.Nh = Nh + self.shape = shape + self.relative = relative + self.stride = stride + self.padding = (self.kernel_size - 1) // 2 + + assert self.Nh != 0, "integer division or modulo by zero, Nh >= 1" + assert self.dk % self.Nh == 0, "dk should be divided by Nh. (example: out_channels: 20, dk: 40, Nh: 4)" + assert self.dv % self.Nh == 0, "dv should be divided by Nh. (example: out_channels: 20, dv: 4, Nh: 4)" + assert stride in [1, 2], str(stride) + " Up to 2 strides are allowed." 
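+        # Note: forward() concatenates conv_out (out_channels - dv channels) with the attention output (dv channels), so out_channels must be strictly greater than dv.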
+ + self.conv_out = nn.Conv2d(self.in_channels, self.out_channels - self.dv, self.kernel_size, stride=stride, padding=self.padding) + + self.qkv_conv = nn.Conv2d(self.in_channels, 2 * self.dk + self.dv, kernel_size=self.kernel_size, stride=stride, padding=self.padding) + + self.attn_out = nn.Conv2d(self.dv, self.dv, kernel_size=1, stride=1) + + if self.relative: + self.key_rel_w = nn.Parameter(torch.randn((2 * self.shape - 1, dk // Nh), requires_grad=True)) + self.key_rel_h = nn.Parameter(torch.randn((2 * self.shape - 1, dk // Nh), requires_grad=True)) + + def forward(self, x): + # Input x + # (batch_size, channels, height, width) + # batch, _, height, width = x.size() + + # conv_out + # (batch_size, out_channels, height, width) + conv_out = self.conv_out(x) + batch, _, height, width = conv_out.size() + + # flat_q, flat_k, flat_v + # (batch_size, Nh, height * width, dvh or dkh) + # dvh = dv / Nh, dkh = dk / Nh + # q, k, v + # (batch_size, Nh, height, width, dv or dk) + flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh) + logits = torch.matmul(flat_q.transpose(2, 3), flat_k) + if self.relative: + h_rel_logits, w_rel_logits = self.relative_logits(q) + logits += h_rel_logits + logits += w_rel_logits + weights = F.softmax(logits, dim=-1) + + # attn_out + # (batch, Nh, height * width, dvh) + attn_out = torch.matmul(weights, flat_v.transpose(2, 3)) + attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width)) + # combine_heads_2d + # (batch, out_channels, height, width) + attn_out = self.combine_heads_2d(attn_out) + attn_out = self.attn_out(attn_out) + return torch.cat((conv_out, attn_out), dim=1) + + def compute_flat_qkv(self, x, dk, dv, Nh): + qkv = self.qkv_conv(x) + N, _, H, W = qkv.size() + q, k, v = torch.split(qkv, [dk, dk, dv], dim=1) + q = self.split_heads_2d(q, Nh) + k = self.split_heads_2d(k, Nh) + v = self.split_heads_2d(v, Nh) + + dkh = dk // Nh + q = q * (dkh ** -0.5) + flat_q = torch.reshape(q, (N, Nh, dk // Nh, H * W)) + flat_k = torch.reshape(k, (N, Nh, dk // Nh, H * W)) + flat_v = torch.reshape(v, (N, Nh, dv // Nh, H * W)) + return flat_q, flat_k, flat_v, q, k, v + + def split_heads_2d(self, x, Nh): + batch, channels, height, width = x.size() + ret_shape = (batch, Nh, channels // Nh, height, width) + split = torch.reshape(x, ret_shape) + return split + + def combine_heads_2d(self, x): + batch, Nh, dv, H, W = x.size() + ret_shape = (batch, Nh * dv, H, W) + return torch.reshape(x, ret_shape) + + def relative_logits(self, q): + B, Nh, dk, H, W = q.size() + q = torch.transpose(q, 2, 4).transpose(2, 3) + + rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, H, W, Nh, "w") + rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), self.key_rel_h, W, H, Nh, "h") + + return rel_logits_h, rel_logits_w + + def relative_logits_1d(self, q, rel_k, H, W, Nh, case): + rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k) + rel_logits = torch.reshape(rel_logits, (-1, Nh * H, W, 2 * W - 1)) + rel_logits = self.rel_to_abs(rel_logits) + + rel_logits = torch.reshape(rel_logits, (-1, Nh, H, W, W)) + rel_logits = torch.unsqueeze(rel_logits, dim=3) + rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1)) + + if case == "w": + rel_logits = torch.transpose(rel_logits, 3, 4) + elif case == "h": + rel_logits = torch.transpose(rel_logits, 2, 4).transpose(4, 5).transpose(3, 5) + rel_logits = torch.reshape(rel_logits, (-1, Nh, H * W, H * W)) + return rel_logits + + def rel_to_abs(self, x): + B, Nh, L, _ = x.size() + + col_pad = 
torch.zeros((B, Nh, L, 1)).to(x) + x = torch.cat((x, col_pad), dim=3) + + flat_x = torch.reshape(x, (B, Nh, L * 2 * L)) + flat_pad = torch.zeros((B, Nh, L - 1)).to(x) + flat_x_padded = torch.cat((flat_x, flat_pad), dim=2) + + final_x = torch.reshape(flat_x_padded, (B, Nh, L + 1, 2 * L - 1)) + final_x = final_x[:, :, :L, L - 1:] + return final_x + +if __name__ == "__main__": + input = torch.randn((16, 3, 32, 32)) + augmented_conv = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4, relative=True, stride=2, shape=16) + output = augmented_conv(input) + print(output.shape) diff --git a/model/attention/FCA.py b/model/attention/FCA.py new file mode 100644 index 0000000..20dfa7e --- /dev/null +++ b/model/attention/FCA.py @@ -0,0 +1,126 @@ +import math +import torch +import torch.nn as nn + + +def get_freq_indices(method): + assert method in ['top1','top2','top4','top8','top16','top32', + 'bot1','bot2','bot4','bot8','bot16','bot32', + 'low1','low2','low4','low8','low16','low32'] + num_freq = int(method[3:]) + if 'top' in method: + all_top_indices_x = [0,0,6,0,0,1,1,4,5,1,3,0,0,0,3,2,4,6,3,5,5,2,6,5,5,3,3,4,2,2,6,1] + all_top_indices_y = [0,1,0,5,2,0,2,0,0,6,0,4,6,3,5,2,6,3,3,3,5,1,1,2,4,2,1,1,3,0,5,3] + mapper_x = all_top_indices_x[:num_freq] + mapper_y = all_top_indices_y[:num_freq] + elif 'low' in method: + all_low_indices_x = [0,0,1,1,0,2,2,1,2,0,3,4,0,1,3,0,1,2,3,4,5,0,1,2,3,4,5,6,1,2,3,4] + all_low_indices_y = [0,1,0,1,2,0,1,2,2,3,0,0,4,3,1,5,4,3,2,1,0,6,5,4,3,2,1,0,6,5,4,3] + mapper_x = all_low_indices_x[:num_freq] + mapper_y = all_low_indices_y[:num_freq] + elif 'bot' in method: + all_bot_indices_x = [6,1,3,3,2,4,1,2,4,4,5,1,4,6,2,5,6,1,6,2,2,4,3,3,5,5,6,2,5,5,3,6] + all_bot_indices_y = [6,4,4,6,6,3,1,4,4,5,6,5,2,2,5,1,4,3,5,0,3,1,1,2,4,2,1,1,5,3,3,3] + mapper_x = all_bot_indices_x[:num_freq] + mapper_y = all_bot_indices_y[:num_freq] + else: + raise NotImplementedError + return mapper_x, mapper_y + + +class MultiSpectralAttentionLayer(torch.nn.Module): + def __init__(self, channel, dct_h, dct_w, reduction = 16, freq_sel_method = 'top16'): + super(MultiSpectralAttentionLayer, self).__init__() + self.reduction = reduction + self.dct_h = dct_h + self.dct_w = dct_w + + mapper_x, mapper_y = get_freq_indices(freq_sel_method) + self.num_split = len(mapper_x) + mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x] + mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y] + # make the frequencies in different sizes are identical to a 7x7 frequency space + # eg, (2,2) in 14x14 is identical to (1,1) in 7x7 + + self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, mapper_x, mapper_y, channel) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + n,c,h,w = x.shape + x_pooled = x + if h != self.dct_h or w != self.dct_w: + x_pooled = torch.nn.functional.adaptive_avg_pool2d(x, (self.dct_h, self.dct_w)) + # If you have concerns about one-line-change, don't worry. :) + # In the ImageNet models, this line will never be triggered. + # This is for compatibility in instance segmentation and object detection. 
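+        # dct_layer weights each channel by its assigned 2D DCT basis and sums over H x W, producing one frequency coefficient per channel for the FC gating below.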
+        y = self.dct_layer(x_pooled)
+
+        y = self.fc(y).view(n, c, 1, 1)
+        return x * y.expand_as(x)
+
+
+class MultiSpectralDCTLayer(nn.Module):
+    """
+    Generate dct filters
+    """
+    def __init__(self, height, width, mapper_x, mapper_y, channel):
+        super(MultiSpectralDCTLayer, self).__init__()
+
+        assert len(mapper_x) == len(mapper_y)
+        assert channel % len(mapper_x) == 0
+
+        self.num_freq = len(mapper_x)
+
+        # fixed DCT init
+        self.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))
+
+        # fixed random init
+        # self.register_buffer('weight', torch.rand(channel, height, width))
+
+        # learnable DCT init
+        # self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))
+
+        # learnable random init
+        # self.register_parameter('weight', torch.rand(channel, height, width))
+
+        # num_freq, h, w
+
+    def forward(self, x):
+        assert len(x.shape) == 4, 'x must be 4 dimensions, but got ' + str(len(x.shape))
+        # n, c, h, w = x.shape
+
+        x = x * self.weight
+
+        result = torch.sum(x, dim=[2,3])
+        return result
+
+    def build_filter(self, pos, freq, POS):
+        result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
+        if freq == 0:
+            return result
+        else:
+            return result * math.sqrt(2)
+
+    def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
+        dct_filter = torch.zeros(channel, tile_size_x, tile_size_y)
+
+        c_part = channel // len(mapper_x)
+
+        for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
+            for t_x in range(tile_size_x):
+                for t_y in range(tile_size_y):
+                    dct_filter[i * c_part: (i+1)*c_part, t_x, t_y] = self.build_filter(t_x, u_x, tile_size_x) * self.build_filter(t_y, v_y, tile_size_y)
+
+        return dct_filter
+
+
+if __name__ == "__main__":
+    input = torch.randn(32, 128, 64, 64)  # (b, c, h, w)
+    fca_layer = MultiSpectralAttentionLayer(channel = 128, dct_h = 64, dct_w = 64, reduction = 16, freq_sel_method = 'top16')
+    output = fca_layer(input)
+    print(output.shape)
diff --git a/model/attention/GCAttention.py b/model/attention/GCAttention.py
new file mode 100644
index 0000000..8f6f53c
--- /dev/null
+++ b/model/attention/GCAttention.py
@@ -0,0 +1,36 @@
+import torch
+from torch import nn
+
+
+class GCModule(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super().__init__()
+        self.conv = nn.Conv2d(channel, 1, kernel_size=1)
+        self.softmax = nn.Softmax(dim=2)
+        self.transform = nn.Sequential(
+            nn.Conv2d(channel, channel // reduction, kernel_size=1),
+            nn.LayerNorm([channel // reduction, 1, 1]),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(channel // reduction, channel, kernel_size=1)
+        )
+
+    def context_modeling(self, x):
+        b, c, h, w = x.shape
+        input_x = x
+        input_x = input_x.reshape(b, c, h * w)
+        context = self.conv(x)
+        context = self.softmax(context.reshape(b, 1, h * w)).transpose(1, 2)  # normalize the attention map over the h*w positions
+        out = torch.matmul(input_x, context)
+        out = out.reshape(b, c, 1, 1)
+        return out
+
+    def forward(self, x):
+        context = self.context_modeling(x)
+        y = self.transform(context)
+        return x + y
+
+if __name__ == "__main__":
+    input = torch.randn(16, 64, 32, 32)
+    gc_layer = GCModule(64)
+    output = gc_layer(input)
+    print(output.shape)
\ No newline at end of file
diff --git a/model/attention/GCTAttention.py b/model/attention/GCTAttention.py
new file mode 100644
index 0000000..f91f945
--- /dev/null
+++ b/model/attention/GCTAttention.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn.functional as F
+import math
+from torch import nn
+
+
+class GCT(nn.Module):
+
+    def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False):
+        super(GCT, self).__init__()
+
+        self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))
+        self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
+        self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
+        self.epsilon = epsilon
+        self.mode = mode
+        self.after_relu = after_relu
+
+    def forward(self, x):
+
+        if self.mode == 'l2':
+            embedding = (x.pow(2).sum((2,3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
+            norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
+
+        elif self.mode == 'l1':
+            if not self.after_relu:
+                _x = torch.abs(x)
+            else:
+                _x = x
+            embedding = _x.sum((2,3), keepdim=True) * self.alpha
+            norm = self.gamma / (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)
+        else:
+            raise ValueError('Unknown mode: ' + self.mode)
+
+        gate = 1. + torch.tanh(embedding * norm + self.beta)
+
+        return x * gate
+
+
+if __name__ == "__main__":
+    input = torch.randn(16, 64, 32, 32)
+    gct_layer = GCT(64)
+    output = gct_layer(input)
+    print(output.shape)
\ No newline at end of file
diff --git a/model/attention/GaussianAttention.py b/model/attention/GaussianAttention.py
new file mode 100644
index 0000000..d342d97
--- /dev/null
+++ b/model/attention/GaussianAttention.py
@@ -0,0 +1,26 @@
+import torch
+from torch import nn
+
+
+class GCA(nn.Module):
+    def __init__(self, channels, c=2, eps=1e-5):
+        super().__init__()
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.eps = eps
+        self.c = c
+
+    def forward(self, x):
+        y = self.avgpool(x)
+        mean = y.mean(dim=1, keepdim=True)
+        mean_x2 = (y ** 2).mean(dim=1, keepdim=True)
+        var = mean_x2 - mean ** 2
+        y_norm = (y - mean) / torch.sqrt(var + self.eps)
+        y_transform = torch.exp(-(y_norm ** 2) / (2 * self.c ** 2))  # Gaussian excitation exp(-y^2 / (2c^2)) with amplitude 1 and mean 0
+        return x * y_transform.expand_as(x)
+
+
+if __name__ == "__main__":
+    input = torch.randn(16, 64, 32, 32)
+    gca_layer = GCA(64)
+    output = gca_layer(input)
+    print(output.shape)
\ No newline at end of file
diff --git a/model/attention/LCTAttention.py b/model/attention/LCTAttention.py
new file mode 100644
index 0000000..2bb0fc6
--- /dev/null
+++ b/model/attention/LCTAttention.py
@@ -0,0 +1,34 @@
+import torch
+from torch import nn
+
+
+class LCT(nn.Module):
+    def __init__(self, channels, groups, eps=1e-5):
+        super().__init__()
+        assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
+        self.groups = groups
+        self.channels = channels
+        self.eps = eps
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.w = nn.Parameter(torch.ones(channels))
+        self.b = nn.Parameter(torch.zeros(channels))
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        batch_size = x.shape[0]
+        y = self.avgpool(x).view(batch_size, self.groups, -1)
+        mean = y.mean(dim=-1, keepdim=True)
+        mean_x2 = (y ** 2).mean(dim=-1, keepdim=True)
+        var = mean_x2 - mean ** 2
+        y_norm = (y - mean) / torch.sqrt(var + self.eps)
+        y_norm = y_norm.reshape(batch_size, self.channels, 1, 1)
+        y_norm = self.w.reshape(1, -1, 1, 1) * y_norm + self.b.reshape(1, -1, 1, 1)
+        y_norm = self.sigmoid(y_norm)
+        return x * y_norm.expand_as(x)
+
+
+if __name__ == "__main__":
+    x = torch.randn(16, 64, 32, 32)
+    attn = LCT(64, 8)
+    y = attn(x)
+    print(y.shape)
\ No newline at end of file
diff --git a/model/attention/__pycache__/AAAttention.cpython-38.pyc b/model/attention/__pycache__/AAAttention.cpython-38.pyc
new file mode 100644
index 0000000..11d1cf8
Binary files /dev/null and b/model/attention/__pycache__/AAAttention.cpython-38.pyc differ
diff --git
a/model/attention/__pycache__/Axial_attention.cpython-38.pyc b/model/attention/__pycache__/Axial_attention.cpython-38.pyc new file mode 100644 index 0000000..35e8dc6 Binary files /dev/null and b/model/attention/__pycache__/Axial_attention.cpython-38.pyc differ diff --git a/model/attention/__pycache__/FCA.cpython-38.pyc b/model/attention/__pycache__/FCA.cpython-38.pyc new file mode 100644 index 0000000..8f03c65 Binary files /dev/null and b/model/attention/__pycache__/FCA.cpython-38.pyc differ diff --git a/model/attention/__pycache__/GCAttention.cpython-38.pyc b/model/attention/__pycache__/GCAttention.cpython-38.pyc new file mode 100644 index 0000000..6ff96fc Binary files /dev/null and b/model/attention/__pycache__/GCAttention.cpython-38.pyc differ diff --git a/model/attention/__pycache__/GCTAttention.cpython-38.pyc b/model/attention/__pycache__/GCTAttention.cpython-38.pyc new file mode 100644 index 0000000..4219c26 Binary files /dev/null and b/model/attention/__pycache__/GCTAttention.cpython-38.pyc differ diff --git a/model/attention/__pycache__/GaussianAttention.cpython-38.pyc b/model/attention/__pycache__/GaussianAttention.cpython-38.pyc new file mode 100644 index 0000000..8e329d2 Binary files /dev/null and b/model/attention/__pycache__/GaussianAttention.cpython-38.pyc differ diff --git a/model/attention/__pycache__/LCTAttention.cpython-38.pyc b/model/attention/__pycache__/LCTAttention.cpython-38.pyc new file mode 100644 index 0000000..1140a28 Binary files /dev/null and b/model/attention/__pycache__/LCTAttention.cpython-38.pyc differ diff --git a/model/img/AAAttention.png b/model/img/AAAttention.png new file mode 100644 index 0000000..ec003ae Binary files /dev/null and b/model/img/AAAttention.png differ diff --git a/model/img/FCANet.png b/model/img/FCANet.png new file mode 100644 index 0000000..b9ba864 Binary files /dev/null and b/model/img/FCANet.png differ diff --git a/model/img/GCNet.png b/model/img/GCNet.png new file mode 100644 index 0000000..29b53fd Binary files /dev/null and b/model/img/GCNet.png differ diff --git a/model/img/GCT.png b/model/img/GCT.png new file mode 100644 index 0000000..c89e772 Binary files /dev/null and b/model/img/GCT.png differ diff --git a/model/img/GaussianCA.png b/model/img/GaussianCA.png new file mode 100644 index 0000000..cad72c6 Binary files /dev/null and b/model/img/GaussianCA.png differ diff --git a/model/img/LCTAttention.png b/model/img/LCTAttention.png new file mode 100644 index 0000000..baa94f5 Binary files /dev/null and b/model/img/LCTAttention.png differ diff --git a/model/img/SimAttention.png b/model/img/SimAttention.png new file mode 100644 index 0000000..f4ee921 Binary files /dev/null and b/model/img/SimAttention.png differ