Advertisement

【GAM】《Global Attention Mechanism:Retain Information to Enhance Channel-Spatial Interactions》

阅读量:
在这里插入图片描述

arXiv-2021


文章目录

  • 1 背景介绍与动机分析
  • 2 相关研究综述
  • 3 方法优势及贡献
  • 4 方法细节描述
  • 5 实验研究
    • 5.1 数据集选择与评估指标

    • 5.2 基于CIFAR-100和ImageNet数据集的分类性能分析

    • 5.3 基线分析研究

    • 6 Conclusion(own)


1 Background and Motivation

在提升pointing technique的有效性方面表现出色的attention机制能够有效地捕捉关键信息。
然而previous approaches未能充分意识到在both channel和spatial维度上保持信息的重要性以促进cross-dimension interactions.

作者开发了名为 Global Attention Mechanism 的技术细节,并通过该机制建立 Channel 和 Spatial 之间的互动关系

  • SENet
  • CBAM
  • BAM
  • TAM(channel,height,width)Triplet attention module

作者也是 capturing significant features across all three dimensions

3 Advantages / Contributions

  • 开发了 novel GAM 注意力机制,并更进一步地提升 channel-spatial interactions 的表现。
  • 该模型在 standard benchmarks 框架下进行了评估,并在 CIFAR-100 和 ImageNet-1K 数据集上取得了令人满意的实验结果。

4 Method

整体流程

在这里插入图片描述

Attention

Attention

在这里插入图片描述

下面看看更多细节的地方:

在这里插入图片描述

相比于 CBAM

做 channel attention 的时候保留了 H 和 W

做 spatial attention 的时候保留了 C

参数规模过大,在处理具有空间关注机制(spatial attention)的情况下,则引入了通道分组(channel groups)和通道交错(channel shuffle)以优化计算效率

在这里插入图片描述

参数量暴增,看看代码

复制代码
    import numpy as np
    import torch
    from torch import nn
    from torch.nn import init
     
    class GAMAttention(nn.Module):
       #https://paperswithcode.com/paper/global-attention-mechanism-retain-information
    def __init__(self, c1, c2, group=True,rate=4):
        super(GAMAttention, self).__init__()
        
        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1)
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(c1, c1//rate, kernel_size=7, padding=3,groups=rate)if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3), 
            nn.BatchNorm2d(int(c1 /rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1//rate, c2, kernel_size=7, padding=3,groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3), 
            nn.BatchNorm2d(c2)
        )
     
    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        x = x * x_channel_att
     
        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att=channel_shuffle(x_spatial_att,4) #last shuffle 
        out = x * x_spatial_att
        return out  
     
    def channel_shuffle(x, groups=2):
        B, C, H, W = x.size()
        out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
        out=out.view(B, C, H, W) 
        return out
    
    
    py
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-16/lxJv6TRZpQWaIr8FSk5cU4EKzusG.png)

可以观察到在 channel attention模块中未采用sigmoid函数,在整个GAM模块中同样未引入残差连接机制

5 Experiments

5.1 Datasets and Metrics

  • CIFAR100
  • ImageNet 1K

top1 and top5 error

5.2 Classification on CIFAR-100 and ImageNet datasets

在这里插入图片描述

效果确实是最好,但是参数量也是激增呀

5.3 Ablation studies

消融了下 channel attention 和 spatial attention

在这里插入图片描述

二合一比较猛

在这里插入图片描述

It is potentially the case that max-pooling could contribute negatively to spatial attention, depending on the specific neural architecture.

6 Conclusion(own)

Enhance the interactions between channels and spatial features. When processing channels, we retained the spatial information. Similarly, when handling spatial features, we preserved the channel data.

全部评论 (0)

还没有任何评论哟~