【GAM】《Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions》

arXiv-2021
Table of Contents
- 1 Background and Motivation
- 2 Related Work
- 3 Advantages / Contributions
- 4 Method
- 5 Experiments
  - 5.1 Datasets and Metrics
  - 5.2 Classification on CIFAR-100 and ImageNet datasets
  - 5.3 Ablation studies
- 6 Conclusion (own)

1 Background and Motivation
Attention mechanisms have proven effective at improving performance across vision tasks by capturing salient information.
However, previous approaches overlook the importance of retaining information in both the channel and spatial dimensions, which limits cross-dimension interactions.
The authors propose the Global Attention Mechanism (GAM), which retains this information and uses it to build the interaction between the channel and spatial dimensions.
2 Related Work
- SENet
- CBAM
- BAM
- TAM (Triplet Attention Module; channel, height, width)
Like TAM, GAM also aims to capture significant features across all three dimensions.
3 Advantages / Contributions
- Proposes a novel GAM attention module that further strengthens channel-spatial interactions.
- Evaluates the module on standard benchmarks, with solid classification results on CIFAR-100 and ImageNet-1K.
4 Method
Overall pipeline

[Figure: the input feature map passes through the channel attention submodule and then the spatial attention submodule.]
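In the paper's notation, given an input feature map $F_1$, the channel submodule $M_c$ and the spatial submodule $M_s$ are applied in sequence:

$$F_2 = M_c(F_1) \otimes F_1, \qquad F_3 = M_s(F_2) \otimes F_2$$

where $\otimes$ denotes element-wise multiplication.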
Now for the finer details:

Compared with CBAM:
- channel attention retains the H and W dimensions (nothing is pooled away);
- spatial attention retains the C dimension.
This makes the parameter count very large, so in the spatial attention branch the paper introduces channel groups and channel shuffle to keep the computation affordable (a quick parameter count is sketched below).
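A rough parameter count for the first 7×7 convolution of the spatial branch shows what grouping saves; c1 = 512 and rate = 4 below are assumed example values, not numbers from the paper.

```python
from torch import nn

c1, rate = 512, 4  # assumed example values

# Ungrouped: each of the c1//rate output channels sees all c1 input channels.
dense = nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3)
# Grouped: channels are split into `rate` groups, dividing the weight count by rate.
grouped = nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(dense))    # 128*512*49 + 128 = 3,211,392
print(n_params(grouped))  # 128*128*49 + 128 =   802,944
```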

Even so, the parameter count balloons; let's look at the code:
```python
import torch
from torch import nn


def channel_shuffle(x, groups=2):
    # Interleave channels across groups (as in ShuffleNet); used here after the
    # grouped spatial convolutions to mix information across channel groups.
    B, C, H, W = x.size()
    out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
    return out.view(B, C, H, W)


class GAMAttention(nn.Module):
    # https://paperswithcode.com/paper/global-attention-mechanism-retain-information
    def __init__(self, c1, c2, group=True, rate=4):
        super().__init__()
        # Channel attention: an MLP applied per spatial position on the permuted
        # (B, H*W, C) tensor, so the H and W dimensions are retained.
        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1),
        )
        # Spatial attention: two 7x7 convolutions that keep the channel dimension;
        # grouped convolutions (group=True) cut the parameter count.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate)
            if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate)
            if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3),
            nn.BatchNorm2d(c2),
        )

    def forward(self, x):
        b, c, h, w = x.shape
        # (B, C, H, W) -> (B, H*W, C) so the MLP acts on the channel dimension.
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)  # back to (B, C, H, W)
        x = x * x_channel_att  # note: no sigmoid on the channel branch
        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffle
        out = x * x_spatial_att
        return out
```
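Continuing from the block above, a quick sanity check (the input size is an assumed toy example):

```python
x = torch.randn(2, 64, 32, 32)    # toy input: batch 2, 64 channels, 32x32 spatial
gam = GAMAttention(c1=64, c2=64)  # c1 == c2 so the final element-wise product is valid
y = gam(x)
print(y.shape)                                    # torch.Size([2, 64, 32, 32])
print(sum(p.numel() for p in gam.parameters()))   # total parameters in the block
```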

Note that in this code the channel attention branch does not apply a sigmoid, and the GAM block as a whole has no residual connection.
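If one wanted the channel branch to end with a sigmoid as in the paper's channel attention submodule, the change is a single line in `forward`; this is only a sketch of that variant, not the reference code:

```python
# variant of the channel-gating line in GAMAttention.forward:
x = x * x_channel_att.sigmoid()  # gate with sigmoid(M_c) instead of the raw MLP output
```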
5 Experiments
5.1 Datasets and Metrics
- CIFAR-100
- ImageNet-1K
Metrics: top-1 and top-5 error.
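As a reminder of how these metrics are computed, a minimal top-k error sketch in PyTorch; the function name and toy shapes are illustrative assumptions, not the paper's evaluation code:

```python
import torch

def topk_error(logits, targets, k=1):
    # logits: (N, num_classes), targets: (N,)
    # A sample counts as correct if its target is among the k highest logits.
    topk = logits.topk(k, dim=1).indices               # (N, k)
    correct = (topk == targets.unsqueeze(1)).any(dim=1)
    return 1.0 - correct.float().mean().item()

logits = torch.randn(8, 100)              # e.g. 100 classes for CIFAR-100
targets = torch.randint(0, 100, (8,))
print(topk_error(logits, targets, k=1), topk_error(logits, targets, k=5))
```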
5.2 Classification on CIFAR-100 and ImageNet datasets


GAM does give the best accuracy, but the parameter count also grows sharply.
5.3 Ablation studies
The ablation toggles the channel attention and spatial attention submodules individually.

Using both together is clearly the strongest.

Max-pooling may actually hurt spatial attention, depending on the network architecture.
6 Conclusion (own)
GAM enhances the interaction between the channel and spatial dimensions: channel attention retains the spatial information (H and W), and spatial attention retains the channel information (C).
