
Deep Learning Paper: MobileNeXt: Rethinking Bottleneck Structure for Efficient Mobile Network Design, with a PyTorch Implementation


This post examines efficient mobile network design through the MobileNeXt architecture and walks through a PyTorch implementation.
PDF: https://arxiv.org/pdf/2007.02269.pdf
PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

1 Overview

Comparing the ImageNet classification performance of MobileNeXt and MobileNetV2 shows that MobileNeXt retains a clear advantage.

(Figure: ImageNet classification results of MobileNeXt vs. MobileNetV2)

2 Sandglass Block

The paper compares ResNet, MobileNetV2, and the proposed MobileNeXt.

(Figure: building blocks of ResNet, MobileNetV2, and MobileNeXt)

The authors take a close look at the bottleneck design in MobileNetV2 and analyze its potential problems: in the inverted residual block, expanding the channel dimension first and then reducing it makes gradient propagation across layers less effective; compressing features from a high-dimensional space into a low-dimensional one loses information; and this in turn causes gradient confusion (manifesting as vanishing or exploding gradients).
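To make the contrast concrete, here is a minimal sketch of MobileNetV2's inverted residual ordering (expand, then depthwise, then reduce). This is illustrative code, not the paper's; the function name and the `expand` parameter are made up for this sketch. Note that the shortcut in MobileNetV2 connects the low-dimensional endpoints:

    import torch.nn as nn

    # Illustrative only: MobileNetV2-style inverted residual body,
    # expand (1x1, ReLU6) -> depthwise (3x3, ReLU6) -> reduce (1x1, linear).
    def inverted_residual_body(in_ch, out_ch, stride, expand=6):
        mid = in_ch * expand  # widen first, compress last
        return nn.Sequential(
            nn.Conv2d(in_ch, mid, 1, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU6(inplace=True),
            nn.Conv2d(mid, mid, 3, stride, 1, groups=mid, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU6(inplace=True),
            nn.Conv2d(mid, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),  # linear bottleneck: no ReLU here
        )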

The Sandglass Block addresses these issues with the following design:

(Figure: structure of the Sandglass Block)

Sandglass Block: depthwise 3x3 (with ReLU) + pointwise 1x1 (without ReLU) + pointwise 1x1 (with ReLU) + depthwise 3x3 (without ReLU). The two pointwise convolutions first reduce and then restore the channel dimension, so the residual shortcut connects the high-dimensional endpoints.

PyTorch code

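The block below relies on small Conv+BN(+ReLU) wrapper layers (`Conv3x3BNReLU`, `Conv1x1BN`, `Conv1x1BNReLU`, `Conv3x3BN`) defined elsewhere in the linked repository. They are not shown in the original post; the following is a minimal sketch of them, assuming bias-free convolutions and ReLU6 activations, the usual convention in MobileNet-style code:

    import torch
    import torch.nn as nn

    # Assumed helper wrappers: conv + batch norm (+ ReLU6).
    # Sketches, not necessarily the repository's exact code.
    def Conv3x3BNReLU(in_channels, out_channels, stride, groups):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                      padding=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU6(inplace=True),
        )

    def Conv3x3BN(in_channels, out_channels, stride, groups):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                      padding=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def Conv1x1BNReLU(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU6(inplace=True),
        )

    def Conv1x1BN(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )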
    class SandglassBlock(nn.Module):
        def __init__(self, in_channels, out_channels, stride, expansion_factor=6):
            super(SandglassBlock, self).__init__()
            self.stride = stride
            # Unlike MobileNetV2, the sandglass block is narrow in the middle:
            # expansion_factor acts as a channel *reduction* ratio here.
            mid_channels = in_channels // expansion_factor
            self.identity = stride == 1 and in_channels == out_channels

            self.bottleneck = nn.Sequential(
                # depthwise 3x3, with activation
                Conv3x3BNReLU(in_channels, in_channels, 1, groups=in_channels),
                # pointwise 1x1 reduction, no activation
                Conv1x1BN(in_channels, mid_channels),
                # pointwise 1x1 expansion, with activation
                Conv1x1BNReLU(mid_channels, out_channels),
                # depthwise 3x3 carrying the stride, no activation
                Conv3x3BN(out_channels, out_channels, stride, groups=out_channels),
            )

        def forward(self, x):
            out = self.bottleneck(x)
            # Residual shortcut only when shapes match (stride 1, equal channels).
            if self.identity:
                return out + x
            return out

3 MobileNeXt Architecture

(Figure: MobileNeXt architecture specification)

PyTorch code

    class MobileNetXt(nn.Module):
        def __init__(self, num_classes=1000):
            super(MobileNetXt, self).__init__()

            # Stem: plain 3x3 convolution, stride 2
            self.first_conv = Conv3x3BNReLU(3, 32, 2, groups=1)

            self.layer1 = self.make_layer(in_channels=32, out_channels=96, stride=2, expansion_factor=2, block_num=1)
            self.layer2 = self.make_layer(in_channels=96, out_channels=144, stride=1, expansion_factor=6, block_num=1)
            self.layer3 = self.make_layer(in_channels=144, out_channels=192, stride=2, expansion_factor=6, block_num=3)
            self.layer4 = self.make_layer(in_channels=192, out_channels=288, stride=2, expansion_factor=6, block_num=3)
            self.layer5 = self.make_layer(in_channels=288, out_channels=384, stride=1, expansion_factor=6, block_num=4)
            self.layer6 = self.make_layer(in_channels=384, out_channels=576, stride=2, expansion_factor=6, block_num=4)
            self.layer7 = self.make_layer(in_channels=576, out_channels=960, stride=1, expansion_factor=6, block_num=2)
            self.layer8 = self.make_layer(in_channels=960, out_channels=1280, stride=1, expansion_factor=6, block_num=1)

            # 7x7 global average pooling assumes a 224x224 input image
            self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
            self.dropout = nn.Dropout(p=0.2)
            self.linear = nn.Linear(in_features=1280, out_features=num_classes)

            self.init_params()

        def make_layer(self, in_channels, out_channels, stride, expansion_factor, block_num):
            # The first block of a stage handles the stride and channel change;
            # the remaining blocks keep resolution and width.
            layers = [SandglassBlock(in_channels, out_channels, stride, expansion_factor)]
            for _ in range(1, block_num):
                layers.append(SandglassBlock(out_channels, out_channels, 1, expansion_factor))
            return nn.Sequential(*layers)

        def init_params(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight)
                    if m.bias is not None:  # convs followed by BN are bias-free
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

        def forward(self, x):
            x = self.first_conv(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.layer5(x)
            x = self.layer6(x)
            x = self.layer7(x)
            x = self.layer8(x)
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            x = self.dropout(x)
            return self.linear(x)
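A quick smoke test for the full model (a sketch; it assumes a 224x224 input, which is what the fixed 7x7 average-pooling kernel expects):

    if __name__ == '__main__':
        model = MobileNetXt(num_classes=1000)
        x = torch.randn(1, 3, 224, 224)  # dummy ImageNet-sized batch
        print(model(x).shape)            # expected: torch.Size([1, 1000])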

4 Experiments

4-1 Comparison with MobileNetV2

(Figure: accuracy/efficiency comparison with MobileNetV2)

4-2 Post-Training Quantization (PTQ) Results

(Figure: post-training quantization results)
