
Daily Attention Study 10: Scale-Aware Modulation

##### Module Source

[ICCV 23] [link] [code](https://github.com/AFeng-x/SMT) Scale-Aware Modulation Meet Transformer

* * *

##### Module Name

Scale-Aware Modulation (SAM)

* * *

##### Module Purpose

Improved self-attention: a convolution-based modulation used in place of the self-attention operation.

* * *

##### Module Structure
![SAM module structure](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-19/keLPFv4w82pMZ15RcADGWbdalhIs.png)
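
The same data flow can also be written as formulas. This is our own summary reconstructed from the module code, not the paper's notation. In it, $h$ stands for `ca_num_heads`, $r$ for `expand_ratio`, and $N = HW$; bias terms are omitted. $\mathrm{Interleave}$ is our name for the reshape that places channel $g$ of head $i$ at channel position $gh + i$.

$$
\begin{aligned}
v &= x W_v, \qquad s = x W_s, \qquad x \in \mathbb{R}^{N \times C}, \\
s &= [\,s_0, \dots, s_{h-1}\,], \qquad s_i \in \mathbb{R}^{H \times W \times C/h}, \\
\tilde{s}_i &= \mathrm{DWConv}_{k_i \times k_i}(s_i), \qquad k_i = 3 + 2i, \quad i = 0, \dots, h-1, \\
\tilde{s} &= \mathrm{Interleave}(\tilde{s}_0, \dots, \tilde{s}_{h-1}) \in \mathbb{R}^{H \times W \times C}, \\
m &= \mathrm{Conv}_{1\times1}^{\,rC \to C}\Big(\mathrm{GELU}\big(\mathrm{BN}(\mathrm{GConv}_{1\times1}^{\,C \to rC}(\tilde{s}))\big)\Big), \\
y &= (m \odot v)\, W_p .
\end{aligned}
$$

$\mathrm{GConv}_{1\times1}$ uses $C/h$ groups, so each group receives exactly one channel from every head. This is where features from the different kernel sizes are aggregated.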

* * *

##### Module Code
```python
import torch
import torch.nn as nn


class SAM(nn.Module):
    def __init__(self, dim, ca_num_heads=4, sa_num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., expand_ratio=2):
        super().__init__()
        # sa_num_heads, qk_scale and attn_drop are kept for signature compatibility;
        # this excerpt (the convolutional-modulation path) does not use them.
        self.ca_attention = 1
        self.dim = dim
        self.ca_num_heads = ca_num_heads
        self.sa_num_heads = sa_num_heads
        assert dim % ca_num_heads == 0, f"dim {dim} should be divided by num_heads {ca_num_heads}."
        assert dim % sa_num_heads == 0, f"dim {dim} should be divided by num_heads {sa_num_heads}."
        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.split_groups = self.dim // ca_num_heads  # channels per head
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.s = nn.Linear(dim, dim, bias=qkv_bias)
        # MHMC: one depthwise conv per head, kernel sizes 3, 5, 7, ... with "same" padding
        for i in range(self.ca_num_heads):
            local_conv = nn.Conv2d(dim // self.ca_num_heads, dim // self.ca_num_heads,
                                   kernel_size=(3 + i * 2), padding=(1 + i), stride=1,
                                   groups=dim // self.ca_num_heads)
            setattr(self, f"local_conv_{i + 1}", local_conv)
        # SAA: grouped 1x1 conv (each group sees one channel per head) -> BN -> GELU -> 1x1 conv
        self.proj0 = nn.Conv2d(dim, dim * expand_ratio, kernel_size=1, padding=0, stride=1,
                               groups=self.split_groups)
        self.bn = nn.BatchNorm2d(dim * expand_ratio)
        self.proj1 = nn.Conv2d(dim * expand_ratio, dim, kernel_size=1, padding=0, stride=1)

    def forward(self, x, H, W):
        # In: x is [B, N, C] with N = H * W
        B, N, C = x.shape
        v = self.v(x)
        # s: [ca_num_heads, B, C // ca_num_heads, H, W]
        s = self.s(x).reshape(B, H, W, self.ca_num_heads, C // self.ca_num_heads).permute(3, 0, 4, 1, 2)

        # Multi-Head Mixed Convolution (MHMC)
        for i in range(self.ca_num_heads):
            local_conv = getattr(self, f"local_conv_{i + 1}")
            # [B, split_groups, 1, H, W]
            s_i = local_conv(s[i]).reshape(B, self.split_groups, -1, H, W)
            if i == 0:
                s_out = s_i
            else:
                # concatenate along dim 2, which interleaves the channels of the heads
                s_out = torch.cat([s_out, s_i], 2)
        s_out = s_out.reshape(B, C, H, W)

        # Scale-Aware Aggregation (SAA)
        s_out = self.proj1(self.act(self.bn(self.proj0(s_out))))
        self.modulator = s_out
        s_out = s_out.reshape(B, C, N).permute(0, 2, 1)  # back to [B, N, C]
        x = s_out * v  # modulate the value projection

        # Out
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


if __name__ == '__main__':
    x = torch.randn([3, 1024, 256])  # B, N, C
    sam = SAM(dim=256)
    out = sam(x, H=32, W=32)  # N = H * W
    print(out.shape)  # torch.Size([3, 1024, 256])
```
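
The reshape and concatenation in the MHMC loop are easy to misread. The short check below makes the channel order visible. It is a sketch with small, made-up sizes. Each head's output is filled with its own head index, the same reshape/cat/reshape steps as `SAM.forward` are applied (concatenating all heads at once, which gives the same result as the incremental loop), and the code reads off which head each channel came from. The variable names are illustrative only.

```python
import torch

B, C, H, W = 1, 8, 2, 2     # illustrative sizes
heads = 4                   # plays the role of ca_num_heads
split_groups = C // heads   # channels per head (2 here)

# Fill each head's output with its head index so we can track where channels land.
outs = [torch.full((B, split_groups, H, W), float(i)) for i in range(heads)]

# Same reshape -> cat(dim=2) -> reshape sequence as in SAM.forward
s_out = torch.cat([o.reshape(B, split_groups, -1, H, W) for o in outs], dim=2)
s_out = s_out.reshape(B, C, H, W)

# Head index of every channel after the final reshape
channel_heads = s_out[0, :, 0, 0].tolist()
print(channel_heads)   # [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]

# proj0 uses groups=split_groups, so group g receives `heads` consecutive channels
proj0_groups = [channel_heads[g * heads:(g + 1) * heads] for g in range(split_groups)]
print(proj0_groups)    # [[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0]]
```

Every group of the grouped 1x1 convolution `proj0` sees one channel from each head. This is how SAA mixes features produced with different kernel sizes without using a dense convolution.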
* * *

##### Original Paper Description

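We propose a novel convolutional modulation, termed Scale-Aware Modulation (SAM), which incorporates two new modules: Multi-Head Mixed Convolution (MHMC) and Scale-Aware Aggregation (SAA). The MHMC module is designed to enlarge the receptive field and capture multi-scale features simultaneously. The SAA module is designed to effectively aggregate features across different heads while keeping the architecture lightweight.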
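
Both claims can be related to the module code with a small check. This sketch reuses the defaults from the example (`dim=256`, `ca_num_heads=4`, `expand_ratio=2`). It shows that the MHMC kernels grow per head while preserving the spatial size, and it counts parameters of the SAA projection. The `dense` convolution is a hypothetical baseline added only for comparison; it is not part of SAM.

```python
import torch
import torch.nn as nn

dim, heads, expand_ratio = 256, 4, 2  # defaults used in the SAM example
split_groups = dim // heads           # 64 channels per head


def n_params(module):
    """Total number of learnable parameters (weights + biases)."""
    return sum(p.numel() for p in module.parameters())


# MHMC: one depthwise conv per head, kernel 3, 5, 7, 9 with "same" padding
mhmc = nn.ModuleList([
    nn.Conv2d(split_groups, split_groups, kernel_size=3 + 2 * i, padding=1 + i, groups=split_groups)
    for i in range(heads)
])
feat = torch.randn(1, split_groups, 32, 32)   # one head's slice of a 32x32 feature map
print([c.kernel_size for c in mhmc])          # [(3, 3), (5, 5), (7, 7), (9, 9)]
print([tuple(c(feat).shape) for c in mhmc])   # every head keeps the 32x32 size

# SAA: the first 1x1 conv as configured in SAM (groups=split_groups) ...
grouped = nn.Conv2d(dim, dim * expand_ratio, kernel_size=1, groups=split_groups)
# ... versus a hypothetical dense 1x1 conv with the same input/output channels
dense = nn.Conv2d(dim, dim * expand_ratio, kernel_size=1)
print(n_params(mhmc), n_params(grouped), n_params(dense))  # 10752 2560 131584
```

With these settings the grouped projection holds about 2% of the parameters of the dense alternative. This is one concrete source of the lightweight aggregation, though the paper's own argument may be framed differently.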
