Deep Learning Paper: A ConvNet for the 2020s and Its PyTorch Implementation
A ConvNet for the 2020s
PDF: https://arxiv.org/pdf/2201.03545.pdf
PyTorch code: https://github.com/shanglianlm0525/CvPytorch
PyTorch code: https://github.com/shanglianlm0525/PyTorch-Networks
A hands-on walkthrough of modernizing a model, taking ResNet-50 from 76.1% to 82.0% top-1 accuracy one step at a time.
Modernizing a ConvNet: a Roadmap

1 Training Techniques (76.1—>78.8)
Adopt a modern training recipe: the AdamW optimizer; augmentation and regularization techniques including Mixup, CutMix, RandAugment, Random Erasing, Stochastic Depth and Label Smoothing; and carefully tuned hyperparameters. A minimal sketch of such a setup follows.
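Below is a minimal sketch of what this recipe looks like in plain PyTorch/torchvision. The hyperparameter values are illustrative rather than the authors' exact configuration, and Mixup/CutMix, RandAugment and Random Erasing would normally come from the timm library instead of being hand-written, so they are only noted in comments.

    import torch
    import torch.nn as nn
    from torchvision.models import resnet50

    model = resnet50()  # the baseline being modernized
    # AdamW with weight decay, a cosine schedule over 300 epochs, and label smoothing
    optimizer = torch.optim.AdamW(model.parameters(), lr=4e-3, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # Mixup/CutMix, RandAugment, Random Erasing and Stochastic Depth are applied on top of
    # this loop in practice (e.g. via timm's Mixup and create_transform helpers).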

2 Macro Design (78.8—>79.5)
Macro-level adjustments to the network structure.
2-1 Changing the stage compute ratio (78.8—>79.4)
Adjust the number of blocks per stage. ResNet-50 distributes its blocks as (3, 4, 6, 3) over the four stages; following Swin-T's stage compute ratio of 1:1:3:1, ConvNeXt-T uses (3, 3, 9, 3) instead. Larger models follow Swin's larger variants and adopt a 1:1:9:1 ratio. The two configurations are compared below.
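For concreteness, here is the depth comparison; `depths` is the same argument taken by the ConvNeXt constructor in the full listing at the end of this post.

    # ResNet-50 stage depths vs. ConvNeXt-T stage depths (Swin-T's 1:1:3:1 compute ratio)
    resnet50_depths = [3, 4, 6, 3]
    convnext_t_depths = [3, 3, 9, 3]
    # In the full code below: model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])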
2-2 A patchify stem (79.4—>79.5)
ViT-style architectures begin by patchifying the input: the image is split into non-overlapping patches, and each patch is embedded as a token. By contrast, the stem of a standard ResNet is a stride-2 7×7 convolution followed by max pooling.
Borrowing Swin-T's design, ConvNeXt uses a stride-4 4×4 convolution as its stem, so the sliding windows never overlap and each step processes exactly one patch.
    # Standard ResNet stem: stride-2 7x7 conv (padding=3) followed by max pooling
    stem = nn.Sequential(
        nn.Conv2d(in_chans, dims[0], kernel_size=7, stride=2, padding=3),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    )

    # ConvNeXt "patchify" stem: non-overlapping stride-4 4x4 conv
    stem = nn.Sequential(
        nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
        LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
    )

3 ResNeXt-ify (79.5—>80.5)
Borrow ResNeXt's FLOPs/accuracy trade-off and its guiding principle of "use more groups, expand width": the 3×3 convolution in the bottleneck is replaced by a depthwise convolution (a grouped convolution with one group per channel), and the network width is increased to 96 channels to match Swin-T. A parameter-count comparison follows below.
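The parameter count makes the effect of depthwise convolution obvious. The snippet below is a standalone illustration (not from the paper's code) at the Tiny model's 96-channel width.

    import torch.nn as nn

    dim = 96  # ConvNeXt-T stage-1 width
    dense_3x3 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)                   # ordinary 3x3 conv
    depthwise_3x3 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)   # one group per channel

    count = lambda m: sum(p.numel() for p in m.parameters())
    print(count(dense_3x3), count(depthwise_3x3))  # 83040 vs. 960 parameters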
4 Inverted Bottleneck (80.5—>80.6)
The bottleneck block of a classic ResNet follows a wide-narrow-wide channel pattern to keep computation low. MobileNetV2 instead uses an inverted bottleneck with a narrow-wide-narrow pattern, which avoids losing information by squeezing features through a low-dimensional space when moving between feature spaces. The Transformer MLP block follows the same idea: its hidden layer is four times wider than its input and output. The two channel flows are sketched below.
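The channel flow of the two designs, shown schematically for a 96-channel block (normalization, activations and the residual connection are omitted for clarity):

    import torch.nn as nn

    dim = 96
    # ResNet-style bottleneck: wide -> narrow -> wide (e.g. 384 -> 96 -> 384)
    bottleneck = nn.Sequential(
        nn.Conv2d(4 * dim, dim, kernel_size=1),
        nn.Conv2d(dim, dim, kernel_size=3, padding=1),
        nn.Conv2d(dim, 4 * dim, kernel_size=1),
    )
    # Inverted bottleneck (MobileNetV2 / Transformer-MLP style): narrow -> wide -> narrow (96 -> 384 -> 96)
    inverted_bottleneck = nn.Sequential(
        nn.Conv2d(dim, 4 * dim, kernel_size=1),
        nn.Conv2d(4 * dim, 4 * dim, kernel_size=3, padding=1, groups=4 * dim),
        nn.Conv2d(4 * dim, dim, kernel_size=1),
    )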

5 Large Kernel Sizes (80.6—>80.6)
Swin-T computes self-attention within 7×7 windows, which is considerably larger than the 3×3 kernels used throughout ResNe(X)t, so the authors revisit large convolution kernels for the ConvNet.
5-1 Moving up depthwise conv layer (80.6—>79.9)
Because the inverted bottleneck expands the hidden width, placing a large-kernel convolution there would make it much more expensive. The depthwise convolution is therefore moved up to the start of the block, analogous to a Transformer, where the self-attention (mixing) layer comes before the MLP. The large-kernel convolution then operates on the narrow dim channels rather than the expanded 4×dim ones, which reduces the FLOPs, but accuracy temporarily drops to 79.9%. The reordered block is sketched below.
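Compared with the inverted bottleneck sketch above, the depthwise convolution moves to the front of the block and now operates on `dim` channels instead of `4 * dim` (again a schematic, with norm/activation/residual omitted):

    import torch.nn as nn

    dim = 96
    # Depthwise conv first (on dim channels), then the 1x1 expand / 1x1 reduce pair
    moved_up = nn.Sequential(
        nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim),   # cheap even with a large kernel
        nn.Conv2d(dim, 4 * dim, kernel_size=1),
        nn.Conv2d(4 * dim, dim, kernel_size=1),
    )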
5-2 Increasing the kernel size (79.9—>80.6)
Growing the depthwise kernel from 3×3 toward 7×7 steadily improves accuracy, recovering it to 80.6%. The experiments show the benefit saturates at 7×7; larger kernels bring no further gain. See the parameter count below.
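Because the convolution is depthwise, enlarging the kernel adds very few parameters, which is what makes this step affordable (a standalone illustration at the 96-channel width):

    import torch.nn as nn

    dim = 96
    count = lambda m: sum(p.numel() for p in m.parameters())
    for k in (3, 5, 7, 9):
        dw = nn.Conv2d(dim, dim, kernel_size=k, padding=k // 2, groups=dim)
        print(k, count(dw))  # 3: 960, 5: 2496, 7: 4800, 9: 7872 parameters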
6 Micro Design (80.6—>82.0)
A few remaining micro-level adjustments to the block design.

6-1 Replacing ReLU with GELU (80.6—>80.6)
GELU is widely used in NLP models and can be viewed as a smoother variant of ReLU. The substitution is made mainly to keep the comparison with Transformer baselines consistent; it does not change accuracy (still 80.6%). See the short illustration below.
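A quick way to see the "smoother ReLU" point, using PyTorch's built-in functions (GELU(x) = x · Φ(x), with Φ the standard normal CDF):

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-3, 3, 7)
    print(F.relu(x))  # hard cut-off at 0
    print(F.gelu(x))  # smooth gating: negative inputs are damped rather than zeroed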
6-2 Fewer activation functions (80.6—>81.3)
Mimicking the Transformer block, which contains only a single activation in its MLP, each block keeps one GELU between the two 1×1 convolutions and drops all other activations. This alone improves accuracy by 0.7%, to 81.3%.
6-3 Fewer normalization layers (81.3—>81.4)
Transformer blocks likewise use few normalization layers, so ConvNeXt keeps a single BatchNorm layer per block, placed before the 1×1 convolutions, and removes the others. The experiments also show that adding an extra normalization layer at the start of each block brings no improvement.
6-4 Substituting LN for BN (81.4—>81.5)
Transformers use Layer Normalization (LN), and BatchNorm (BN) has many intricacies that can be detrimental to performance, so every remaining BN layer is replaced with LN.
6-5 Separate downsampling layers (81.5—>82.0)
In a standard ResNet, downsampling is handled by the stride-2 3×3 convolution at the start of each stage, with a stride-2 1×1 convolution on the shortcut of the residual block. Swin Transformer instead inserts a separate downsampling layer between stages. To mimic this, ConvNeXt downsamples with a standalone stride-2 2×2 convolution. Used naively this destabilizes training, so LayerNorm layers are added after the stem, before each downsampling convolution, and after the final global average pooling to restore stability.
    self.downsample_layers = nn.ModuleList()
    # The stem can itself be viewed as a downsampling layer, so it is stored in
    # downsample_layers together with the others and accessed by index at inference time.
    stem = nn.Sequential(
        nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
        LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
    )
    self.downsample_layers.append(stem)
    for i in range(3):
        downsample_layer = nn.Sequential(
            LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
            nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
        )
        self.downsample_layers.append(downsample_layer)
    # Because the network alternates downsample-stage-downsample-stage,
    # the LN in the stem and the LN of the following downsampling layer are never adjacent.

PyTorch code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
        This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
        changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
        'survival rate' as the argument.
        """
        if drop_prob == 0. or not training:
            return x
        keep_prob = 1 - drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor


    class DropPath(nn.Module):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
        """
        def __init__(self, drop_prob=None, scale_by_keep=True):
            super(DropPath, self).__init__()
            self.drop_prob = drop_prob
            self.scale_by_keep = scale_by_keep

        def forward(self, x):
            return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


    class Block(nn.Module):
        r""" ConvNeXt Block. There are two equivalent implementations:
        (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
        (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
        We use (2) as we find it slightly faster in PyTorch

        Args:
            dim (int): Number of input channels.
            drop_path (float): Stochastic depth rate. Default: 0.0
            layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        """

        def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
            super().__init__()
            self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
            self.norm = LayerNorm(dim, eps=1e-6)
            self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
            self.act = nn.GELU()
            self.pwconv2 = nn.Linear(4 * dim, dim)
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                      requires_grad=True) if layer_scale_init_value > 0 else None
            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        def forward(self, x):
            input = x
            x = self.dwconv(x)
            x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
            x = self.norm(x)
            x = self.pwconv1(x)
            x = self.act(x)
            x = self.pwconv2(x)
            if self.gamma is not None:
                x = self.gamma * x
            x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

            x = input + self.drop_path(x)
            return x


    class LayerNorm(nn.Module):
        r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
        The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
        shape (batch_size, height, width, channels) while channels_first corresponds to inputs
        with shape (batch_size, channels, height, width).
        """

        def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
            self.eps = eps
            self.data_format = data_format
            if self.data_format not in ["channels_last", "channels_first"]:
                raise NotImplementedError
            self.normalized_shape = (normalized_shape,)

        def forward(self, x):
            if self.data_format == "channels_last":
                return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
            elif self.data_format == "channels_first":
                u = x.mean(1, keepdim=True)
                s = (x - u).pow(2).mean(1, keepdim=True)
                x = (x - u) / torch.sqrt(s + self.eps)
                x = self.weight[:, None, None] * x + self.bias[:, None, None]
                return x


    class ConvNeXt(nn.Module):
        r""" ConvNeXt
            A PyTorch impl of : `A ConvNet for the 2020s`  -
              https://arxiv.org/pdf/2201.03545.pdf

        Args:
            in_chans (int): Number of input image channels. Default: 3
            num_classes (int): Number of classes for classification head. Default: 1000
            depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
            dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
            drop_path_rate (float): Stochastic depth rate. Default: 0.
            layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
            head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
        """

        def __init__(self, in_chans=3, num_classes=1000,
                     depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                     layer_scale_init_value=1e-6, head_init_scale=1.,
                     ):
            super().__init__()

            self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
            stem = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
            )
            self.downsample_layers.append(stem)
            for i in range(3):
                downsample_layer = nn.Sequential(
                    LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                    nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
                )
                self.downsample_layers.append(downsample_layer)

            self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
            dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
            cur = 0
            for i in range(4):
                stage = nn.Sequential(
                    *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                            layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
                )
                self.stages.append(stage)
                cur += depths[i]

            self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
            self.head = nn.Linear(dims[-1], num_classes)

            self.apply(self._init_weights)
            self.head.weight.data.mul_(head_init_scale)
            self.head.bias.data.mul_(head_init_scale)

        def _init_weights(self, m):
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=.02)
                nn.init.constant_(m.bias, 0)

        def forward_features(self, x):
            for i in range(4):
                x = self.downsample_layers[i](x)
                x = self.stages[i](x)
            return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

        def forward(self, x):
            x = self.forward_features(x)
            x = self.head(x)
            return x
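A quick smoke test (not part of the original listing): build the Tiny configuration and run a dummy forward pass.

    # Instantiate ConvNeXt-T and check the output shape
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # torch.Size([1, 1000])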