文章分享2:RepViT: Revisiting Mobile CNN From ViT Perspective
论文地址:https://arxiv.org/abs/2307.09283
代码地址:https://github.com/THU-MIG/RepViT?tab=readme-ov-file
引言
近年来
为了解决这一问题的研究者提出了一种名为RepViT的新架构;该架构是从轻量级ViT视角重新设计的轻量级CNN家族。本文旨在详细阐述设计理念、技术创新以及实验结果。
主要创新点
Repvit创新点主要在论文的第三章,块设计、宏设计和微设计。
A.块设计(block design)
(一)MobileNet Block
如图3(a)所示, MobileNetV3模块架构构成了 RepViT模块设计的基础框架,该架构通过巧妙的设计实现了高效的信息处理与特征提取。具体而言, MobileNetV3模块采用了两组关键组件进行信息处理:首先,模块中采用了一组1×1扩张卷积操作,其主要目标是增加通道数量,并显著提升了模型的表达能力;随后,通过一组深度可分步进卷积操作引入了一组投影层(projection layer),这一步骤的主要目标是实现各通道间的高效信息交互,即作为通道混合器(channel mixer)。在此基础上配置了一组深度可分步进卷积操作,其主要功能是融合空间信息特征;然而这种设计导致了一种结构性的问题:即令牌混合器与通道混合器之间存在耦合关系,在一定程度上制约了模型的高度灵活性与计算效率
(二)RepViT Block
为了缓解MobileNetV3模块的局限性问题,作者开发出了RepViT模块,其架构示于图 3 (b)所示。该模块的主要创新点在于独立开发了令牌聚合器与通道聚合器的具体实现方式:即通过将DW卷积从空间聚合层中迁移位置,使其能够单独执行空间信息处理功能;同时,还将可选的压缩与激励(Squeeze-and-Excitation, SE)层安置在DW卷积之后的位置,这一安排基于SE层依赖于空间信息交互这一前提条件做出。通过上述方法的成功实施,最终实现了两种聚合机制的有效分离。
此外, RepViT Block还应用了结构重参数化技术(structural reparameterization)。在训练阶段,该技术引入了多分支架构,有助于增强模型的学习能力;而当进行推理时,重参数化技术能够将多分支架构简化为单一架构,从而显著降低了计算开销,这对于移动设备上的资源受限环境尤为重要。经过这些优化改进,RepViT Block 在降低MobileNetV3-L模型延迟方面展现了显著的效果,尽管这可能带来一定程度的性能损失,但通过后续优化措施,其整体性能得到了明显提升

repvitblock
import torch.nn as nn
from timm.models.layers import SqueezeExcite
import torch
class Conv2d_BN(torch.nn.Sequential):
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', torch.nn.Conv2d(
a, b, ks, stride, pad, dilation, groups, bias=False))
self.add_module('bn', torch.nn.BatchNorm2d(b))
torch.nn.init.constant_(self.bn.weight, bn_weight_init)
torch.nn.init.constant_(self.bn.bias, 0)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps) ** 0.5
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation,
groups=self.c.groups,
device=c.weight.device)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class Residual(torch.nn.Module):
def __init__(self, m, drop=0.):
super().__init__()
self.m = m
self.drop = drop
def forward(self, x):
if self.training and self.drop > 0:
return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,device=x.device).ge_(self.drop).div(1 - self.drop).detach()
else:
return x + self.m(x)
@torch.no_grad()
def fuse(self):
if isinstance(self.m, Conv2d_BN):
m = self.m.fuse()
assert (m.groups == m.in_channels)
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
m.weight += identity.to(m.weight.device)
return m
elif isinstance(self.m, torch.nn.Conv2d):
m = self.m
assert (m.groups != m.in_channels)
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
m.weight += identity.to(m.weight.device)
return m
else:
return self
class RepVGGDW(torch.nn.Module):
def __init__(self, ed) -> None:
super().__init__()
self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
self.dim = ed
self.bn = torch.nn.BatchNorm2d(ed)
def forward(self, x):
return self.bn((self.conv(x) + self.conv1(x)) + x)
@torch.no_grad()
def fuse(self):
conv = self.conv.fuse()
conv1 = self.conv1
conv_w = conv.weight
conv_b = conv.bias
conv1_w = conv1.weight
conv1_b = conv1.bias
conv1_w = torch.nn.functional.pad(conv1_w, [1, 1, 1, 1])
identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device),
[1, 1, 1, 1])
final_conv_w = conv_w + conv1_w + identity
final_conv_b = conv_b + conv1_b
conv.weight.data.copy_(final_conv_w)
conv.bias.data.copy_(final_conv_b)
bn = self.bn
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = conv.weight * w[:, None, None, None]
b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
(bn.running_var + bn.eps) ** 0.5
conv.weight.data.copy_(w)
conv.bias.data.copy_(b)
return conv
class RepViTBlock(nn.Module):
def __init__(self, inp, oup, kernel_size=3, stride=2,hidden_dim=80, use_se=True, use_hs=True):
super(RepViTBlock, self).__init__()
if stride == 2:
self.token_mixer = nn.Sequential(
Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
)
self.channel_mixer = Residual(nn.Sequential(
# pw
Conv2d_BN(oup, 2 * oup, 1, 1, 0),
nn.GELU() if use_hs else nn.GELU(),
# pw-linear
Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
))
else:
self.token_mixer = nn.Sequential(
RepVGGDW(inp),
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
)
self.channel_mixer = Residual(nn.Sequential(
# pw
Conv2d_BN(oup, hidden_dim, 1, 1, 0),
nn.GELU() if use_hs else nn.GELU(),
# pw-linear
Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
))
def forward(self, x):
return self.channel_mixer(self.token_mixer(x))
B.宏设计 (Macro design)
(一)早期卷积(Early Convolutions)
在网络前端的 stem 部分中,MobileNetV3-L 和 ViTs 实施了不同的设计策略。MobileNetV3-L 嵌入了一个较为复杂的 stem 构架,该架构包含了 3×3 卷积层、深度可分离卷积层以及倒置瓶颈结构(如图4(a)所示)。这种复杂的设计架构虽然有效提升了处理高分辨率图像的能力但它也带来了明显的延迟瓶颈问题。为了缓解这一挑战 RepViT 借鉴了 ViTs 中采用的早期卷积技术方案(如图4(b)所示)并引入了两个步长为2的 3×3 卷积层作为 stem 部分。通过这种更为简洁的设计方案不仅将延迟降低到更短的时间区间(从1.01ms降至0.86ms)还显著提升了模型优化过程中的稳定性 最终使得模型在top-1分类任务中的准确率水平(从71.5%提升至73.9%)。
(二)更深的下采样层(Deeper Downsampling Layers)
在空间下采样操作中,ViTs一般情况下会采用单独设置的空间聚合层(patch merging layer),这种设计有助于提升网络深度的同时,在减少分辨率降低带来的信息损失方面表现出较强的适应性。相比之下,在倒置瓶颈块中仅通过步长为2的深度可分离卷积(DW conv)来实现下采样的MobileNetV3-L架构存在明显局限性:如图4(c)所示,在降低分辨率的同时可能导致网络整体深度不足从而引发信息丢失问题进而影响模型性能;而RepViT则通过更为先进的策略实现了对下采样的优化:如图4(d)所示首先分别采用深度可分离卷积(步长为2)进行空间降采样以及1×1卷积进行通道维度压缩随后在处理过程中不断加深各降采样层并引入前馈网络模块以期能够更好地存储潜在信息;这些改进措施最终使得该模型在top-1准确率上实现了显著提升至75.4%同时将延迟控制在了较为理想的0.96ms水平。
(三)简单分类器(Simple Classifier)
从分类器设计的角度看, MobileNetV3-L架构通过附加额外的1×1卷积层与全连接层,实现了对特征空间的扩展目标,如图4(e)所示。相比之下,在RepViT架构中,为了应对通道数量在后期阶段的增长需求,采用了更为简洁的分类器结构,即仅包含全局平均池化层(global average pooling layer)与全连接层组合而成,如图4(f)所示。尽管这种简化措施导致整体准确率出现微降(仅为0.6%),但显著降低了计算延迟(降至最低0.77ms)。
(四)整体阶段比例(Overall Stage Ratio)
该方法中的阶段比(stage ratio)定义了各层块数量的比例,在模型的设计与性能优化中具有重要意义。研究表明,在第三层采用较多的模块能够较好地平衡模型精度与计算效率。为此,本研究采用了1:1:7:1的层次划分,并将网络深度拓展至2-2-14-2结构以实现更深层的网络架构设计。这一设计优化使模型在Top-1测试集上的准确率达到76.9%,运行时延保持在0.91毫秒水平
如需进一步了解,请参考论文第三章中的相关内容。作者在第三章对相关内容进行了深入阐述。

实验
实验主要基于 ImageNet - 1K 数据集开展图像分类研究,在该数据集中设置了丰富的分类场景与挑战项。具体而言,在多个下游 tasks中进行了实验验证:针对这些 downstream tasks(如目标检测与实例分割)均采用了MS COCO 2017数据集作为基准;对应的语义分割则采用ADE20K数据集。
所有模型均在NVIDIA RTX 3090显卡上得到训练,并且在 iPhone 12(iOS 16)、Macbook M1 Pro 等移动设备上执行延迟测试。其中的模型编译过程采用了 Core ML Tools 工具。



