Advertisement

Vision Transformer(vit)的主干

阅读量:

图解:

代码:

复制代码
 class VisionTransformer(nn.Module):

    
     def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
    
              embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
    
              qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
    
              attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
    
              act_layer=None):
    
     """
    
     Args:
    
         img_size (int, tuple): input image size
    
 #输入图像的大小,通常是 224 或其他标准尺寸
    
         patch_size (int, tuple): patch size
    
 #每个块(patch)的大小,例如 16x16
    
         in_c (int): number of input channels
    
 #输入图像的通道数,RGB 图像是 3
    
         num_classes (int): number of classes for classification head
    
 #最终分类的类别数,默认 1000 类
    
         embed_dim (int): embedding dimension
    
 #嵌入维度,即每个 patch 被映射到的向量的维度,默认是 768
    
         depth (int): depth of transformer
    
 #Transformer 的深度,即堆叠的块(Block)数量。
    
         num_heads (int): number of attention heads
    
 #注意力头的数量,默认设为 12
    
         mlp_ratio (int): ratio of mlp hidden dim to embedding dim
    
 # MLP 隐藏层的维度与嵌入维度的比例。
    
         qkv_bias (bool): enable bias for qkv if True
    
 #是否为 QKV(查询、键、值)矩阵添加偏置
    
         qk_scale (float): override default qk scale of head_dim ** -0.5 if set
    
 #如果设定,将会覆盖默认的 qk 缩放因子
    
         representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
    
 #如果设置了这个值,将会有一个表示层(pre-logits)
    
         distilled (bool): model includes a distillation token and head as in DeiT models
    
 #vit中可以不管这个参数
    
         drop_ratio (float): dropout rate
    
 # Dropout 的比例
    
         attn_drop_ratio (float): attention dropout rate
    
 #注意力层的 Dropout 比例
    
         drop_path_ratio (float): stochastic depth rate
    
 #droppath比例
    
         embed_layer (nn.Module): patch embedding layer
    
 #用于嵌入图像的层,默认使用 PatchEmbed
    
         norm_layer: (nn.Module): normalization layer
    
 #正则化层,通常是 LayerNorm
    
     """
    
     super(VisionTransformer, self).__init__()
    
     self.num_classes = num_classes
    
     self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
    
 # 与 embed_dim 保持一致,表示嵌入的维度。
    
     self.num_tokens = 2 if distilled else 1
    
 #不管distilled所以distilled=1
    
     norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
    
 #使用 LayerNorm作为默认的规范化层
    
     act_layer = act_layer or nn.GELU
    
 #默认使用 GELU 作为激活函数
    
  
    
     self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
    
 #Embedding层结构
    
     num_patches = self.patch_embed.num_patches
    
 #patches的个数
    
  
    
     self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    
 #这是用于分类的分类标记(Class Token),它是一个可学习的参数,初始值为零
    
     self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
    
 #不管distilled所以self.dist_token=None
    
     self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
    
 #位置编码(Position Embedding)
    
     self.pos_drop = nn.Dropout(p=drop_ratio)
    
 #位置编码后的 Dropout 操作
    
  
    
     dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
    
 #用于控制每个 Block 的 DropPath 比例
    
     self.blocks = nn.Sequential(*[
    
         Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
    
               drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
    
               norm_layer=norm_layer, act_layer=act_layer)
    
         for i in range(depth)
    
     ])
    
 #使用 Block 类构建了Transformer的主体部分,包括注意力和MLP层,并使用残差连接和 DropPath 
    
     self.norm = norm_layer(embed_dim)
    
 #最后的归一化层,用于 Transformer 输出的处理
    
  
    
     # Representation layer
    
     if representation_size and not distilled:
    
 #设置了 representation_size则会增加一个表示层 pre_logits,not distilled=true
    
         self.has_logits = True
    
         self.num_features = representation_size
    
         self.pre_logits = nn.Sequential(OrderedDict([
    
             ("fc", nn.Linear(embed_dim, representation_size)),
    
             ("act", nn.Tanh())
    
         ]))
    
 #pre_logits层结构一个全连接和tanh激活函数
    
     else:
    
         self.has_logits = False
    
         self.pre_logits = nn.Identity()
    
  
    
     # Classifier head(s)
    
     self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
    
     self.head_dist = None
    
     if distilled:
    
         self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
    
 #distilled为none不用管
    
  
    
     # Weight init
    
     nn.init.trunc_normal_(self.pos_embed, std=0.02)
    
     if self.dist_token is not None:
    
         nn.init.trunc_normal_(self.dist_token, std=0.02)
    
  
    
     nn.init.trunc_normal_(self.cls_token, std=0.02)
    
     self.apply(_init_vit_weights)
    
 #权重初始化
    
  
    
     def forward_features(self, x):
    
     # [B, C, H, W] -> [B, num_patches, embed_dim]
    
     x = self.patch_embed(x)  # [B, 196, 768]
    
 #将输入的图像 x 切分为多个 patch 并嵌入,通过Embedding层
    
     # [1, 1, 768] -> [B, 1, 768]
    
     cls_token = self.cls_token.expand(x.shape[0], -1, -1)
    
     if self.dist_token is None:
    
         x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
    
     else:
    
         x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
    
 #分类标记如果有将cls_token加入,因为dist_token为none,所以在维度1上拼接
    
  
    
     x = self.pos_drop(x + self.pos_embed)
    
 #添加位置编码并应用 Dropout
    
     x = self.blocks(x)
    
 #通过 Transformer 的 Block 堆叠进行处理
    
     x = self.norm(x)
    
 #进行归一化
    
 #vit中self.dist_token is None所以模型只有分类标记 (class token)。
    
     if self.dist_token is None:
    
         return self.pre_logits(x[:, 0])
    
 #x[:, 0]表示提取分类标记(class token) 的输出向量。这个向量是用于分类任务的主要特征表示。
    
     else:
    
         return x[:, 0], x[:, 1]
    
  
    
     def forward(self, x):
    
     x = self.forward_features(x)
    
 #首先获取 Transformer 的特征输出
    
     if self.head_dist is not None:
    
         x, x_dist = self.head(x[0]), self.head_dist(x[1])
    
         if self.training and not torch.jit.is_scripting():
    
             # during inference, return the average of both classifier predictions
    
             return x, x_dist
    
         else:
    
             return (x + x_dist) / 2
    
     else:
    
 #self.head_dist为none只看head层就是最后的全连接层输出为num_classes
    
         x = self.head(x)
    
     return x
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-19/Q0b1pTZERkKWPl4Ji3HMwOvC8NDo.png)

操作:

代码:

[B, C, H, W] -> [B, num_patches, embed_dim]

x = self.patch_embed(x) # [B, 196, 768]
#将输入的图像 x 切分为多个 patch 并嵌入,通过Embedding层

操作:

代码:

[1, 1, 768] -> [B, 1, 768]

cls_token = self.cls_token.expand(x.shape[0], -1, -1)
if self.dist_token is None:
x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
else:
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
#分类标记如果有将cls_token加入,因为dist_token为none,所以在维度1上拼接

操作:

代码:

x = self.pos_drop(x + self.pos_embed)
#添加位置编码并应用 Dropout

操作:

代码:

x = self.blocks(x)
#通过 Transformer 的 Block 堆叠进行处理
x = self.norm(x)
#进行归一化

操作:

代码:

#vit中self.dist_token is None所以模型只有分类标记 (class token)。
if self.dist_token is None:
return self.pre_logits(x[:, 0])
#x[:, 0]表示提取分类标记(class token) 的输出向量。这个向量是用于分类任务的主要特征表示。
else:
return x[:, 0], x[:, 1]

操作:

代码:

#self.head_dist为none只看head层就是最后的全连接层输出为num_classes
x = self.head(x)

分类标记 (Class Token):

是一种特殊的 输入 token ,在 Transformer 模型中被用来聚合全局特征。

它在模型中起到了类似于 CNN 中全局池化 (Global Pooling) 的作用,负责从所有 patch 的信息中提取一个全局表示。

这个 token 的输出向量被用作分类任务的特征输入,之后会被送入分类头 (classifier head) 进行最终的类别预测。

embedding层:

[Vision Transformer(vit)的Embedding层结构-博客]( "Vision Transformer(vit)的Embedding层结构-博客")

Multi-Head Self-Attention:

[Vision Transformer(vit)的Multi-Head Self-Attention(多头注意力机制)结构-博客]( "Vision Transformer(vit)的Multi-Head Self-Attention(多头注意力机制)结构-博客")

MLP模块:

[Vision Transformer(vit)的MLP模块-博客]( "Vision Transformer(vit)的MLP模块-博客")

Encoder block:

[Vision Transformer(vit)的Encoder层结构-博客]( "Vision Transformer(vit)的Encoder层结构-博客")

详解:[Vision Transformer详解-博客]( "Vision Transformer详解-博客")

全部评论 (0)

还没有任何评论哟~