Advertisement

每日Attention学习4——Spatial Attention Module

阅读量:
模块出处

[link] [

复制代码
](https://github.com/iCVTEAM/CTDNet) [MM 21] Complementary Trilateral Decoder for Fast and Accurate Salient Object Detection

* * *

##### 模块名称

Spatial Attention Module (SAM)

* * *

##### 模块作用

空间注意力

* * *

##### 模块结构
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/Sdg73W1Au0eXLoEtiIlrj8nGHVaP.jpeg)

* * *

##### 模块代码
复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1, bias=False):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias)


class SAM(nn.Module):
def __init__(self, in_chan, out_chan):
    super(SAM, self).__init__()
    self.conv_atten = conv3x3(2, 1)
    self.conv = conv3x3(in_chan, out_chan)
    self.bn = nn.BatchNorm2d(out_chan)

def forward(self, x):
    avg_out = torch.mean(x, dim=1, keepdim=True)
    max_out, _ = torch.max(x, dim=1, keepdim=True)
    atten = torch.cat([avg_out, max_out], dim=1)
    atten = torch.sigmoid(self.conv_atten(atten))
    out = torch.mul(x, atten)
    out = F.relu(self.bn(self.conv(out)), inplace=True)
    return out


if __name__ == '__main__':
x = torch.randn([1, 256, 16, 16])
sam = SAM(in_chan=256, out_chan=64)
out = sam(x)
print(out.shape)  # 1, 64, 16, 16


python
复制代码
* * *

##### 原文表述

我们设计了空间注意力模块 (SAM),以有效地完善特征(见图 3)。我们首先沿通道轴使用平均和最大运算,分别生成两个不同的单通道空间图$S_{avg}$和$S_{max}$。然后,我们将它们连接起来,通过3×3卷积和sigmoid函数计算出空间注意力图。空间注意力图$M_{sa}$可以通过元素级相乘从空间维度对特征重新加权。最后,细化后的特征被送入3×3卷积层,将通道压缩至64。

全部评论 (0)

还没有任何评论哟~