Advertisement

每日Attention学习13——Adaptive Feature Fusion

阅读量:
模块出处

[AAAI 23] [link] [

复制代码
](https://github.com/HUuxiaobin/HitNet) High-Resolution Iterative Feedback Network for Camoufaged Object Detection

* * *

##### 模块名称

Adaptive Feature Fusion

* * *

##### 模块作用

多尺度特征融合

* * *

##### 模块结构
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-02/tpALJyFeBzGTEDx5j0gWqmnsvP9h.jpeg)

* * *

##### 模块代码
复制代码
import torch
import torch.nn as nn

class AFF(nn.Module):
def __init__(self, ch_in=32, reduction=16):
    super(AFF, self).__init__()
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    self.fc = nn.Sequential(
        nn.Linear(ch_in, ch_in // reduction, bias=False),
        nn.ReLU(inplace=True),
        nn.Linear(ch_in // reduction, ch_in, bias=False),
        nn.Sigmoid()
    )
    self.fc_wight = nn.Sequential(
        nn.Linear(ch_in, ch_in // reduction, bias=False),
        nn.ReLU(inplace=True),
        nn.Linear(ch_in // reduction, 1, bias=False),
        nn.Sigmoid()
    )

def forward(self, x_h, x_l):
    # 增强high level特征, 在Squeeze and Excite上引入了一个额外的h_weight($\alpha_1$)
    b, c, _, _ = x_h.size()
    y_h = self.avg_pool(x_h).view(b, c)
    h_weight = self.fc_wight(y_h)
    y_h = self.fc(y_h).view(b, c, 1, 1)
    x_fusion_h = x_h * y_h.expand_as(x_h)
    x_fusion_h = torch.mul(x_fusion_h, h_weight.view(b, 1, 1, 1))

提升基础层特征,在Squeeze and Excite架构中增加了额外的加权参数l_weight(\alpha_2)。具体而言:

  • 首先获取输入张量x_l的尺寸信息b(batch size)、c(通道数)、以及空间维度信息(_ , _)= x_l.size()。

  • 通过全局平均池化操作将x_l转换为二维特征向量y_l,并对其进行全连接层处理以获得加权参数l_weight(\alpha_2)。

  • 将重塑后的y_l特征图与原始输入x_l进行按元素乘法操作生成初步融合特征x_fusion_initial。

  • 最后将重塑后的y_l特征图与初步融合特征x_fusion_initial进行深度加权融合以生成最终融合特征x_fusion_final。

    复制代码
      # 多级特征融合
      x_fusion = x_fusion_h + x_fusion_l
      return x_fusion

if name == 'main':
attention_forests = AttentionForest()
high_feature_map = torch.randn([1, 32, 16, 16], dtype=torch.float32)
low_feature_map = torch.randn([1, 32//4, (8//2), (8//2)], dtype=torch.float32)
output_feature_map = attention_forests(
high_feature_map,
low_feature_map
)
print(f"输出特征图大小:{output_feature_map.shape}")

复制代码

全部评论 (0)

还没有任何评论哟~