Advertisement

【论文阅读】Vision Transformer

阅读量:

Vision Transformer

1. 模型介绍

在计算机视觉领域中,在大多数现有方法都基于CNN的固定架构的基础上,在CNN内部增加注意力机制模块或者将注意力机制模块用于替代其某些关键组件。研究表明,在深度学习模型设计过程中过度依赖仅基于卷积神经网络的架构并非最优选择。基于此理念,在本研究中我们提出了一种新的视觉 Transformer 模型(ViT),该模型仅依赖于Transformer架构同样可在图像分类等核心任务上展现出卓越性能。

受现代自然语言处理(NLP)领域中Transformer的成功应用启发,在ViT算法中尝试将传统的Transformer架构直接应用于图像处理,并仅对图像分类流程进行必要的优化与调整。具体而言,在ViT算法的设计过程中,首先将整个输入图像划分为均匀分布的小块区域;然后将这些小块区域经过线性嵌入处理后得到的序列作为输入传递给Transformer模型;最后通过监督学习方法对各个区域进行分类训练。

该算法在中型规模(例如ILSVRC 2012 Challenge)以及大规模(例如ImageNet 21k-way partition和JFT-300M)数据集上经过系统性的实验验证,并发现在这些场景下表现出了显著的效果。

  • Transformer相比传统的CNN结构,在平移不变性和局部判别能力方面存在一定缺陷。这导致在面对小规模的数据集时,在性能上往往难以与CNN结构达到相同的效果水平。具体而言,在采用中等规模的ImageNet进行Transformer模型训练时,在精度指标上通常会比ResNet架构低出大约5-10个百分点。
  • 在训练数据量较大时(即拥有充足的数据支撑),结果则会发生显著变化。通过利用大规模的数据集进行预训练,并结合迁移学习方法应用于其他相关数据集上,则可以在性能指标方面实现或超越当前研究领域的最佳水平(SOTA)。

2. 模型结构与实现

ViT算法的整体结构如 图1 所示。

图1 ViT算法结构示意图

图1 ViT算法结构示意图

2.1. 图像分块嵌入

值得注意的是,在Transformer架构中,输入通常被表示为二维矩阵的形式 (N,D) 。在这里面 N 代表sequence长度而 D 则代表每个向量所处的空间维度。为了实现这一目标,在ViT算法中需要采取一系列步骤:首先需要设法将 H \times W \times C 的三维图像空间转换为一个长度为 N 的一维向量,并通过嵌入层将其映射至高维空间形式 (N,D)

ViT中的具体实现方式如下:将输入的一幅H \times W \times C的空间图象通过特定变换转换为一个形状为N \times (P^2 \times C)的空间序列。这表示该序列可被视为一系列展开的小块。根据上述公式可知该序列包含由h=w=H/W= W/P个二维小块构成的小块集合,并且每个小块的空间维度为P×P。其中P代表每个小块的空间尺寸参数而C则代表通道数量参数。

然而,在当前情况下(此处指某个场景),每个图像块的空间维度为 (P^2 \times C);而为了满足需求(即达到所需的向量空间维数 D),我们需要对该图像块进行进一步处理——嵌入操作)。在该处理过程中(嵌入操作),我们仅需对单个空间大小为 (P^2 \times C) 的图像块执行一次线性变换即可完成这一目标。

上述对图像进行分块以及 Embedding 的具体方式如 图2 所示。

图2 图像分块嵌入示意图

图2 图像分块嵌入示意图

具体代码实现如下所示:

复制代码
    class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
    
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
    
    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
    
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x

2.2. 多头注意力

经将图像转换为 N \times (P^2 \times C) 序列后, 从而就可以将其被输入至 Transformer 结构以实现特征提取. 如 图3 所示

图3 多头注意力示意图

图3 多头注意力示意图

在Transformer架构中占据核心地位的是多头注意力机制,在这种机制中通常会设置多个头以增强模型的表示能力(如图4所示)。具体而言,在具有两个头的Multi-head Attention架构中(如图4所示),输入序列ai首先通过预处理层展开为一组特征向量,并被划分为查询块q{(i,1)}, q{(i,2)}和对应的键块k{(j,i)}, k{(j,i)}以及值块v{(j,i)}, v{(j,i)}}(其中j=1或2)。随后每个查询块q{^ (j,i}与对应的键块k{^ (j,i}之间进行注意力计算以获得注意力权重α{^ (j,i})。随后计算出注意力权重后将各权重与对应的值块v{^ (j,i}进行加权叠加以生成最终输出b{^ (j,i}(其中j=1或2);最后将所有头的结果沿着特征维度拼接起来并通过全连接层进一步变换得到最终输出结果b_i(其中 i=1,…N)

图4 多头注意力

图4 多头注意力

其中,使用 q^{(i,j)}k^{(i,j)}v^{(i,j)} 计算 b^{(i,j)}(i=1,2,…,N) 的方法是缩放点积注意力 (Scaled Dot-Product Attention)。 结构如 图5 所示。首先使用每个 q^{(i,j)} 去与 k^{(i,j)} 做 attention,这里说的 attention 就是匹配这两个向量有多接近,具体的方式就是计算向量的加权内积,得到 \alpha_{(i,j)}。这里的加权内积计算方式如下所示:

\alpha_{(1,i)} = q^1 * k^i / \sqrt{d}

具体来说,在这里d代表了qk这两个变量的维度数量。由于q*k的乘积数值会在维度增长时显著增加,在计算过程中将其值除以\sqrt{d}的结果就类似于对数据进行了归一化处理。

接下来,把计算得到的 \alpha_{(i,j)} 取 softmax 操作,再将其与 v^{(i,j)} 相乘。

图5 缩放点积注意力

图5 缩放点积注意力

具体代码实现如下所示:

复制代码
    class Attention(nn.Module):
    def __init__(self,
                 dim,   # 输入token的dim
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)
    
    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape
    
        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
    
        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
    
        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

2.3. 多层感知机(MLP)

在 Transformer 结构中存在一个关键模块即是多层感知机(MLP),它是一种能够通过逐层线性变换与非线性激活函数处理信息的网络模型。如图6所示

图6 MLP多层感知机的结构

具体代码实现如下所示:

复制代码
    class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

2.4. DropPath

除了上述几个关键模块之外,在代码实现过程中我们采用了DropPath(Stochastic Depth)作为替代方案。该方法的作用如下:给定输入张量x具有通道维度[B,C,H,W]的情况下,在一个Batch_size中随机有drop_prob的概率使得样本不会经过主干网络而直接由分支网络进行恒等映射操作。

具体实现如下:

复制代码
    def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output
    
    
    class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
    
    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

2.5. 基础模块

基于前面所实现的Attention机制、MLP层以及DropPath组件就能够整合成Vision Transformer架构中的一个基本单元结构,请参见图8

图8 基础模块示意图

图8 基础模块示意图

基础模块的具体实现如下:

复制代码
    class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
    
    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

2.6. 定义ViT网络

基础模块搭建完成后,在系统性架构整个ViT架构之前,则需要逐一完成各个关键组件的搭建工作:包括但不限于多模态融合层、时空注意力机制以及特征抽取器等核心功能单元的配置与调优。

  • Class Token

假设我们将原始图像按照 3 \times 3 分割成9个小图像块,在这种情况下最终得到的输入序列长度为10。实际上我们为了提高输入序列的长度而人为添加了一个额外的向量作为输入数据通常我们将这一额外引入的数据项称为 Class Token。这会带来怎样的影响呢?

在缺乏这一机制的情况下,在图像分类任务中我们仍然面临一个问题:即如何确定从这9个编码结果中选取一个来进行分类预测。ViT算法通过引入 learnable 的 Class Token 来解决这一问题:它能够 learnably 将自身嵌入并整合至原有的9个编码结果之中,并通过这一额外的过程生成10个最终的编码表示。随后,在后续的模型推导过程中我们就可以直接利用这个 learnable 的 Class Token 进行分类预测

其实在这种情况下可以认为是 ViT 实际上仅涉及了 Transformer 中的 Encoder 部分 而它并未触及 Decoder 结构 Class Token 的主要功能在于识别其他 9 个输入向量所属的类别

  • Positional Encoding

遵循 Transformer 模型中位置编码的一般做法,在该研究中同样采用了位置编码机制。值得注意的是,在该研究中采用了一种不同于传统 Transformer 模型中 sincos 基于的方法,并非直接继承自原版模型。具体而言,在 ViT 的实现过程中,默认采用了具有学习能力的 Positional Encoding 表示而非固定的 sincos 编码方案。

  • MLP Head

在得到输出结果后,ViT模型中采用了MLP Head结构用于对输出结果进行分类处理。该结构主要由LayerNorm层以及两个全连接层构成,并应用了GELU激活函数以增强非线性表征能力。

具体代码如下所示:

复制代码
    class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
    
        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
    
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)
    
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
    
        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
    
    
        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)
    
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)
    
    
    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
    
        x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
    
    
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        return x[:, 0]
    
    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
    
    def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)

3. 模型特点

  • ViT被视为计算机视觉领域最具影响力的一种Transformer架构,在传统CNN架构的基础上实现了对标准Transformer模型的直接应用。
  • 为适应Transformer模型的要求,在不大幅修改原有流程的前提下,在整个图像中均匀分割出若干个独立的小块,并将这些小块对应的编码序列输入到网络中运行;同时通过类元标记辅助实现分类预测。

全部评论 (0)

还没有任何评论哟~