
Deep Learning Paper: Learning Spatial Fusion for Single-Shot Object Detection and Its PyTorch Implementation



1 Overview

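This paper proposes a novel data-driven strategy for fusing pyramid features, called adaptively spatial feature fusion (ASFF). ASFF learns to spatially filter out conflicting information coming from the other feature levels. This suppresses the inconsistency between levels, which otherwise shows up as conflicting gradients during backpropagation. The result is more scale-invariant features and better overall object detection performance.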


2 Adaptively Spatial Feature Fusion (ASFF)


2-1 Feature Resizing

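For levels that need upsampling, take ASFF-3 as an example. A 1×1 convolution is first applied to the level-1 feature map to bring its channel count down to that of level 3. The result is then upsampled by interpolation to the level-3 resolution. Level 2 is handled the same way, with a smaller upsampling factor. For levels that need downsampling, take ASFF-1 as an example, which uses two paths. From level 2 to level 1 (a 1/2 ratio), a single 3×3 convolution with stride 2 changes the channel count and the resolution at the same time. From level 3 to level 1 (a 1/4 ratio), a stride-2 max-pooling layer is added before that stride-2 convolution.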
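The resizing paths described above can be checked with plain PyTorch modules. The sketch below is a minimal shape check, not the paper's code, and it rests on a few assumptions. The pyramid has 512 channels at 16×16 for level 1, 256 at 32×32 for level 2, and 128 at 64×64 for level 3, the same sizes as the `__main__` test in the PyTorch code below. BatchNorm and activations are omitted. Nearest-neighbour interpolation and a 3×3 pooling window are assumed choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed pyramid: level 1 is the coarsest map with the most channels.
c1, c2, c3 = 512, 256, 128
x1 = torch.randn(1, c1, 16, 16)  # level 1
x2 = torch.randn(1, c2, 32, 32)  # level 2
x3 = torch.randn(1, c3, 64, 64)  # level 3

# Upsampling (ASFF-3, level 1 -> 3): 1x1 conv to level-3 channels, then 4x interpolation.
compress_1_3 = nn.Conv2d(c1, c3, kernel_size=1, bias=False)
x1_to_3 = F.interpolate(compress_1_3(x1), scale_factor=4, mode="nearest")

# Downsampling by 2 (ASFF-1, level 2 -> 1): one 3x3 stride-2 conv
# changes channels and resolution at the same time.
down_2_1 = nn.Conv2d(c2, c1, kernel_size=3, stride=2, padding=1, bias=False)
x2_to_1 = down_2_1(x2)

# Downsampling by 4 (ASFF-1, level 3 -> 1): stride-2 max pooling first,
# then the same kind of 3x3 stride-2 conv.
down_3_1 = nn.Sequential(
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    nn.Conv2d(c3, c1, kernel_size=3, stride=2, padding=1, bias=False),
)
x3_to_1 = down_3_1(x3)

print(x1_to_3.shape)  # torch.Size([1, 128, 64, 64])
print(x2_to_1.shape)  # torch.Size([1, 512, 16, 16])
print(x3_to_1.shape)  # torch.Size([1, 512, 16, 16])
```

With padding 1, each stride-2 operation maps a side of length s to s/2 for even s. That is why one conv gives 1/2 and pooling plus conv gives 1/4.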

2-2 Adaptive Fusion


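The weights α, β and γ are obtained by applying 1×1 convolutions to the resized level-1 to level-3 feature maps. At every spatial position, they are then normalized with a softmax, so each weight lies in [0, 1] and the three sum to 1.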
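Written out, the fusion at output level l reads as follows. This is our reading of the ASFF paper's formulation. Here x^{n→l}_{ij} is the feature vector at position (i, j) of level n after it has been resized to level l, and the λ maps are the logits produced by the 1×1 convolutions. β and γ are defined analogously to α. Each weight is a single scalar per position, shared across all channels.

```latex
\begin{aligned}
y^{l}_{ij} &= \alpha^{l}_{ij}\, x^{1\to l}_{ij} + \beta^{l}_{ij}\, x^{2\to l}_{ij} + \gamma^{l}_{ij}\, x^{3\to l}_{ij}, \\
\alpha^{l}_{ij} &= \frac{e^{\lambda^{l}_{\alpha_{ij}}}}{e^{\lambda^{l}_{\alpha_{ij}}} + e^{\lambda^{l}_{\beta_{ij}}} + e^{\lambda^{l}_{\gamma_{ij}}}},
\qquad \alpha^{l}_{ij} + \beta^{l}_{ij} + \gamma^{l}_{ij} = 1, \quad \alpha^{l}_{ij}, \beta^{l}_{ij}, \gamma^{l}_{ij} \in [0, 1].
\end{aligned}
```

In the PyTorch code below, the `weight1`/`weight2`/`weight3` branches together with `weight_level` play the role of the λ maps. `self.softmax` over `dim=1` then turns them into α, β and γ.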

3 Experimental Comparison

3-1 Comparison with Concatenation and Element-wise Sum
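The three fusion schemes compared here differ only in how the resized maps are merged. The sketch below illustrates the mechanics on random tensors. It is not the paper's exact set of baselines. `concat_proj` is a hypothetical 1×1 convolution that projects the concatenated channels back to C. `logit_head` is a hypothetical single 1×1 convolution that produces the ASFF-style weight logits; it is simplified relative to the per-level weight branches in the code below.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, c, h, w = 2, 64, 8, 8
# Three feature maps assumed to be already resized to the same level l.
feats = [torch.randn(n, c, h, w) for _ in range(3)]
stacked = torch.cat(feats, dim=1)  # (n, 3c, h, w)

# Baseline "concat": stack along channels, then a hypothetical 1x1 conv back to c.
concat_proj = nn.Conv2d(3 * c, c, kernel_size=1)
fused_concat = concat_proj(stacked)

# Baseline "element-wise sum": every level counts equally at every position.
fused_sum = feats[0] + feats[1] + feats[2]

# ASFF-style: a hypothetical 1x1 head predicts 3 logits per pixel; softmax over
# the level axis turns them into spatially varying weights that sum to 1.
logit_head = nn.Conv2d(3 * c, 3, kernel_size=1)
weights = torch.softmax(logit_head(stacked), dim=1)  # (n, 3, h, w)
fused_asff = sum(weights[:, k:k + 1] * feats[k] for k in range(3))

print(fused_concat.shape, fused_sum.shape, fused_asff.shape)
```

All three outputs have the same shape. The difference is in the weights. The sum treats the levels identically everywhere, while the ASFF-style weights have shape (N, 3, H, W) and can change from pixel to pixel.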


3-2 Adding Other Object Detection Enhancements


[38] Jiaqi Wang et al. Region proposal by guided anchoring. In CVPR, 2019.
[41] Jiahui Yu et al. UnitBox: An advanced object detection network. In ACM MM, 2016.
[43] Zhi Zhang et al. Bag of freebies for training object detection neural networks. arXiv preprint arXiv:1902.04103, 2019.

4 ASFF Visualization


PyTorch code:

    import torch
    import torch.nn as nn


    def Conv1x1BnRelu(in_channels, out_channels):
        # 1x1 conv + BN + ReLU6: channel projection without changing resolution
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU6(inplace=True),
        )


    def upSampling1(in_channels, out_channels):
        # 1x1 conv to match the target channels, then 2x nearest-neighbour upsampling
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU6(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
        )


    def upSampling2(in_channels, out_channels):
        # 4x upsampling: upSampling1 (2x) followed by another 2x upsample
        return nn.Sequential(
            upSampling1(in_channels, out_channels),
            nn.Upsample(scale_factor=2, mode='nearest'),
        )


    def downSampling1(in_channels, out_channels):
        # 2x downsampling: one 3x3 stride-2 conv changes channels and resolution together
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU6(inplace=True),
        )


    def downSampling2(in_channels, out_channels):
        # 4x downsampling: stride-2 max pooling before the stride-2 conv
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            downSampling1(in_channels, out_channels),
        )


    class ASFF(nn.Module):
        """Adaptive spatial feature fusion for one output level (1, 2 or 3).

        channel1/2/3 are the channel counts of levels 1/2/3; level 1 is the coarsest map.
        """

        def __init__(self, level, channel1, channel2, channel3, out_channel):
            super().__init__()
            if level not in (1, 2, 3):
                raise ValueError(f"level must be 1, 2 or 3, got {level}")
            self.level = level
            fused_channel = 8  # channels of each per-level weight branch

            if level == 1:
                # Downsample levels 2 and 3 to level 1
                self.level2_1 = downSampling1(channel2, channel1)
                self.level3_1 = downSampling2(channel3, channel1)
                target_channel = channel1
            elif level == 2:
                # Upsample level 1, downsample level 3
                self.level1_2 = upSampling1(channel1, channel2)
                self.level3_2 = downSampling1(channel3, channel2)
                target_channel = channel2
            else:
                # Upsample levels 1 and 2 to level 3
                self.level1_3 = upSampling2(channel1, channel3)
                self.level2_3 = upSampling1(channel2, channel3)
                target_channel = channel3

            # One small branch per level produces features for the weight logits
            self.weight1 = Conv1x1BnRelu(target_channel, fused_channel)
            self.weight2 = Conv1x1BnRelu(target_channel, fused_channel)
            self.weight3 = Conv1x1BnRelu(target_channel, fused_channel)

            self.expand_conv = Conv1x1BnRelu(target_channel, out_channel)

            # Map the concatenated branches to 3 logits per pixel
            self.weight_level = nn.Conv2d(fused_channel * 3, 3, kernel_size=1, stride=1, padding=0)
            # Softmax over the level axis: weights in [0, 1] that sum to 1 at each pixel
            self.softmax = nn.Softmax(dim=1)

        def forward(self, x, y, z):
            # x, y, z are the level-1, level-2 and level-3 feature maps
            if self.level == 1:
                level_x = x
                level_y = self.level2_1(y)
                level_z = self.level3_1(z)
            elif self.level == 2:
                level_x = self.level1_2(x)
                level_y = y
                level_z = self.level3_2(z)
            else:
                level_x = self.level1_3(x)
                level_y = self.level2_3(y)
                level_z = z

            weight1 = self.weight1(level_x)
            weight2 = self.weight2(level_y)
            weight3 = self.weight3(level_z)

            level_weight = torch.cat((weight1, weight2, weight3), dim=1)
            weight_level = self.softmax(self.weight_level(level_weight))  # (N, 3, H, W)

            # Slice with k:k+1 so each (N, 1, H, W) weight map broadcasts over all channels
            fused_level = (level_x * weight_level[:, 0:1, :, :]
                           + level_y * weight_level[:, 1:2, :, :]
                           + level_z * weight_level[:, 2:3, :, :])
            return self.expand_conv(fused_level)


    if __name__ == '__main__':
        model = ASFF(level=3, channel1=512, channel2=256, channel3=128, out_channel=128)
        print(model)

        x = torch.randn(1, 512, 16, 16)
        y = torch.randn(1, 256, 32, 32)
        z = torch.randn(1, 128, 64, 64)
        out = model(x, y, z)
        print(out.shape)  # torch.Size([1, 128, 64, 64])
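One detail in the fusion step is easy to get wrong. The softmax output `weight_level` has shape (N, 3, H, W). Indexing it as `weight_level[:, 0, :, :]` drops the channel axis and gives (N, H, W). PyTorch broadcasts from the trailing dimensions, so against an (N, C, H, W) feature map the batch size N ends up aligned with the channel count C. That only works by accident when N = 1. Slicing as `weight_level[:, 0:1, :, :]` keeps a singleton channel axis, so the per-pixel weight is applied to all C channels. The standalone check below uses made-up sizes to show this.

```python
import torch

n, c, h, w = 2, 16, 4, 4
feature = torch.randn(n, c, h, w)                         # one resized level
weights = torch.softmax(torch.randn(n, 3, h, w), dim=1)   # per-pixel level weights

# Keeping the channel axis: (n, 1, h, w) broadcasts over all c channels.
fused_part = feature * weights[:, 0:1, :, :]

# Dropping it: (n, h, w) is treated as (1, n, h, w), so n is matched against c.
try:
    feature * weights[:, 0, :, :]
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

print(fused_part.shape)  # torch.Size([2, 16, 4, 4])
print(mismatch_raised)   # True
```

With N = 2 and C = 16 the wrong indexing fails loudly. If N happened to equal C, it would run silently and weight the wrong elements.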