
Daily Attention Learning 17: Multi-Kernel Inverted Residual

Module Source

[ICLR 25 Submission] [link] UltraLightUNet: Rethinking U-shaped Network with Multi-kernel Lightweight Convolutions for Medical Image Segmentation


Module Name

Multi-Kernel Inverted Residual (MKIR)


Module Role

An ultra-lightweight encoder block


Module Structure

(Structure diagrams of the MKIR block from the paper; images not reproduced here.)


Module Highlights
  • Inverted-residual-style design: before the MKDC (multi-kernel depthwise convolution), a pointwise Conv-BN-ReLU6 expands the channel count; after the MKDC, a pointwise Conv-BN projects it back down (or keeps it unchanged). This pattern was first introduced in MobileNetV2.
  • Depthwise convolutions with several kernel sizes extract features at multiple scales; a depthwise convolution is far cheaper than a standard convolution with the same kernel size.
  • ReLU6 is used instead of ReLU. ReLU6 clamps activations to the range [0, 6], which is better suited to low-precision (e.g., int8) inference in lightweight models.
  • Features from the branches are fused by element-wise addition rather than concatenation; concatenation would multiply the channel count and increase the cost of the following pointwise convolution.
  • The fused features are passed through a channel shuffle to encourage further interaction between branch features, playing a role similar to a 1×1 convolution (see the short sketch below).
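To make the channel shuffle step concrete, here is a minimal, self-contained sketch (my own illustration, not code from the paper): each channel is tagged with its own index, so shuffling with 2 groups visibly interleaves the two halves, and channels that originate in different groups end up adjacent before the following pointwise convolution.

    import torch

    def channel_shuffle(x, groups):
        # (B, C, H, W) -> (B, g, C/g, H, W) -> swap the group and channel axes -> flatten back
        b, c, h, w = x.size()
        x = x.view(b, groups, c // groups, h, w)
        x = torch.transpose(x, 1, 2).contiguous()
        return x.view(b, -1, h, w)

    # tag each channel with its index so the reordering is visible
    x = torch.arange(6).view(1, 6, 1, 1)
    print(x.flatten().tolist())                      # [0, 1, 2, 3, 4, 5]
    print(channel_shuffle(x, 2).flatten().tolist())  # [0, 3, 1, 4, 2, 5]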

Module Code
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    def gcd(a, b):
        # greatest common divisor, used to pick the group count for channel shuffle
        while b:
            a, b = b, a % b
        return a


    def channel_shuffle(x, groups):
        # (B, C, H, W) -> (B, g, C/g, H, W) -> swap the group and channel axes -> flatten back
        batchsize, num_channels, height, width = x.size()
        channels_per_group = num_channels // groups
        x = x.view(batchsize, groups, channels_per_group, height, width)
        x = torch.transpose(x, 1, 2).contiguous()
        x = x.view(batchsize, -1, height, width)
        return x


    class MultiKernelDepthwiseConv(nn.Module):
        def __init__(self, in_channels, kernel_sizes, stride, dw_parallel=True):
            super(MultiKernelDepthwiseConv, self).__init__()
            self.in_channels = in_channels
            self.dw_parallel = dw_parallel
            # one depthwise Conv-BN-ReLU6 branch per kernel size
            self.dwconvs = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(self.in_channels, self.in_channels, kernel_size, stride, kernel_size // 2, groups=self.in_channels, bias=False),
                    nn.BatchNorm2d(self.in_channels),
                    nn.ReLU6(inplace=True)
                )
                for kernel_size in kernel_sizes
            ])

        def forward(self, x):
            outputs = []
            for dwconv in self.dwconvs:
                dw_out = dwconv(x)
                outputs.append(dw_out)
                if not self.dw_parallel:
                    # sequential mode: accumulate the branch output into the input of the next branch
                    x = x + dw_out
            return outputs


    class MKIR(nn.Module):
        def __init__(self, in_c, out_c, stride=1, expansion_factor=2, dw_parallel=True, add=True, kernel_sizes=[1, 3, 5]):
            super(MKIR, self).__init__()
            self.stride = stride
            self.in_c = in_c
            self.out_c = out_c
            self.kernel_sizes = kernel_sizes
            self.add = add
            self.n_scales = len(kernel_sizes)
            self.use_skip_connection = (self.stride == 1)
            self.ex_c = int(self.in_c * expansion_factor)
            # pointwise expansion: Conv-BN-ReLU6 raises the channel count to ex_c
            self.pconv1 = nn.Sequential(
                nn.Conv2d(self.in_c, self.ex_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(self.ex_c),
                nn.ReLU6(inplace=True)
            )
            self.multi_scale_dwconv = MultiKernelDepthwiseConv(self.ex_c, self.kernel_sizes, self.stride, dw_parallel=dw_parallel)
            if self.add:
                # branches are summed, so the channel count stays at ex_c
                self.combined_channels = self.ex_c * 1
            else:
                # branches are concatenated, so the channel count grows with the number of scales
                self.combined_channels = self.ex_c * self.n_scales
            # pointwise projection: Conv-BN brings channels back down to out_c (no activation)
            self.pconv2 = nn.Sequential(
                nn.Conv2d(self.combined_channels, self.out_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(self.out_c),
            )
            if self.use_skip_connection and (self.in_c != self.out_c):
                # 1x1 convolution to match channels on the residual path
                self.conv1x1 = nn.Conv2d(self.in_c, self.out_c, 1, 1, 0, bias=False)

        def forward(self, x):
            pout1 = self.pconv1(x)
            dwconv_outs = self.multi_scale_dwconv(pout1)
            if self.add:
                # fuse the multi-kernel branches by element-wise addition
                dout = 0
                for dwout in dwconv_outs:
                    dout = dout + dwout
            else:
                dout = torch.cat(dwconv_outs, dim=1)
            # shuffle channels to mix information across branches
            dout = channel_shuffle(dout, gcd(self.combined_channels, self.out_c))
            out = self.pconv2(dout)
            if self.use_skip_connection:
                if self.in_c != self.out_c:
                    x = self.conv1x1(x)
                return x + out
            else:
                return out


    if __name__ == '__main__':
        x = torch.randn([1, 64, 44, 44])
        mkir = MKIR(in_c=64, out_c=128)
        out = mkir(x)
        print(out.shape)  # torch.Size([1, 128, 44, 44])
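
As a quick check on the lightweight claim, the short snippet below (my own addition, not from the paper) counts the parameters of the MKIR block instantiated above and compares them with a plain 3×3 convolution mapping the same 64 input channels to 128 output channels. With the default settings (expansion_factor=2, kernel_sizes=[1, 3, 5], add=True), the MKIR block comes out to roughly half the parameters of the plain convolution.

    import torch.nn as nn

    def count_params(m: nn.Module) -> int:
        return sum(p.numel() for p in m.parameters())

    mkir = MKIR(in_c=64, out_c=128)                                   # uses the MKIR class defined above
    plain = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)  # baseline standard convolution

    print(count_params(mkir))   # roughly half the baseline with the defaults above
    print(count_params(plain))  # 64 * 128 * 3 * 3 = 73728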
