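PyTorch Notes - Vision Transformer (ViT)
The Transformer consists of an encoder and a decoder. Its two core building blocks are multi-head self-attention (mixing information across positions, i.e. spatial mixing) and the feed-forward network (mixing information across channels).
Interaction between encoder and decoder: memory-based multi-head cross-attention
Injecting position information: position embedding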
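Why position information has to be injected can be checked numerically: self-attention on its own does not know the order of its inputs. The sketch below is only an illustration; the sizes, the fixed permutation, and the random `pos` tensor standing in for a learned position embedding are all assumptions, not part of the original notes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for this illustration.
seq_len, model_dim, num_heads = 5, 8, 2

attn = nn.MultiheadAttention(model_dim, num_heads, batch_first=True)
attn.eval()

x = torch.randn(1, seq_len, model_dim)
perm = torch.tensor([4, 0, 3, 1, 2])  # a fixed, non-identity reordering of the tokens
x_perm = x[:, perm]

with torch.no_grad():
    out, _ = attn(x, x, x)
    out_perm, _ = attn(x_perm, x_perm, x_perm)

# Without position information, shuffling the input tokens
# only shuffles the output tokens in the same way.
is_equivariant = torch.allclose(out[:, perm], out_perm, atol=1e-5)

# Stand-in for a learned position embedding (random here, an assumption).
pos = torch.randn(1, seq_len, model_dim)
with torch.no_grad():
    out_pos, _ = attn(x + pos, x + pos, x + pos)
    out_pos_perm, _ = attn(x_perm + pos, x_perm + pos, x_perm + pos)

# Once positions are added, the same shuffle changes the result.
order_matters = not torch.allclose(out_pos[:, perm], out_pos_perm, atol=1e-5)

print(is_equivariant, order_matters)
```

Both flags should come out true: the attention layer treats its input as a set until a position embedding is added to it.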
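The amount of data a model needs is inversely related to the inductive bias built into it. The Transformer builds in very little inductive bias, so its performance ceiling is high, but it also needs a large amount of training data to reach it.
Inductive bias: the term grows out of the distinction between inductive and deductive reasoning. It denotes the prior assumptions built into a model, and it plays a key role in model design.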
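Transformer use cases:
Encoder-only models: pretrained language models such as BERT, and non-streaming tasks such as supervised classification.
Decoder-only models: autoregressive generative models such as the GPT series, and language modeling in general.
Encoder-decoder models: machine translation and speech recognition.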
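In practice, one visible difference between the encoder-only and decoder-only settings above is the attention mask: an encoder attends in both directions, while an autoregressive decoder hides future positions. The following is a minimal sketch of that idea using `nn.MultiheadAttention`; the helper name `causal_mask` and the tensor sizes are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def causal_mask(seq_len: int) -> torch.Tensor:
    """Boolean mask; True marks positions that must NOT be attended to."""
    return torch.ones(seq_len, seq_len).triu(diagonal=1).bool()


seq_len, model_dim = 4, 8  # hypothetical sizes
mask = causal_mask(seq_len)

attn = nn.MultiheadAttention(embed_dim=model_dim, num_heads=2, batch_first=True)
attn.eval()

x = torch.randn(1, seq_len, model_dim)
x_changed = x.clone()
x_changed[:, -1] += 1.0  # perturb only the last token

with torch.no_grad():
    # Decoder-style (causal) attention: earlier outputs cannot see the last token.
    causal_a, _ = attn(x, x, x, attn_mask=mask)
    causal_b, _ = attn(x_changed, x_changed, x_changed, attn_mask=mask)
    # Encoder-style (bidirectional) attention: every output sees every token.
    full_a, _ = attn(x, x, x)
    full_b, _ = attn(x_changed, x_changed, x_changed)

past_unchanged = torch.allclose(causal_a[:, :-1], causal_b[:, :-1], atol=1e-6)
full_changed = not torch.allclose(full_a[:, :-1], full_b[:, :-1], atol=1e-6)
print(past_unchanged, full_changed)
```

With the causal mask, changing the last token leaves all earlier outputs untouched, which is what makes left-to-right generation possible. ViT uses the unmasked, encoder-style variant.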
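Vision Transformer (ViT):
- DNN view: image-to-patch (Image2Patch) followed by patch-to-embedding (Patch2Embedding)
- CNN view: the same embedding computed with a 2D convolution over the image
- Classification (CLS) token embedding, used as a placeholder token
- Position embedding, which can be interpolated at inference time
- Transformer encoder
- Classification head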
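The list above mentions interpolating the position embedding at inference time. One common way to do this, sketched here under assumptions, is to reshape the patch positions back into their 2D grid and resize that grid. The function name `interpolate_pos_embedding`, its parameters, and the choice of bicubic mode are illustrative and do not come from the original notes.

```python
import torch
import torch.nn.functional as F


def interpolate_pos_embedding(pos_emb, old_grid, new_grid):
    """Resize a position embedding table to a new patch grid.

    pos_emb: (1 + old_grid * old_grid, dim), with the CLS position in row 0.
    Returns a tensor of shape (1 + new_grid * new_grid, dim).
    """
    cls_pos, patch_pos = pos_emb[:1], pos_emb[1:]
    dim = patch_pos.shape[-1]
    # (N, dim) -> (1, dim, old_grid, old_grid) so F.interpolate sees a 2D grid
    patch_pos = patch_pos.reshape(old_grid, old_grid, dim).permute(2, 0, 1).unsqueeze(0)
    patch_pos = F.interpolate(
        patch_pos, size=(new_grid, new_grid), mode="bicubic", align_corners=False
    )
    # back to (new_grid * new_grid, dim), in row-major (raster) order
    patch_pos = patch_pos.squeeze(0).permute(1, 2, 0).reshape(new_grid * new_grid, dim)
    # the CLS position has no spatial location, so it is kept unchanged
    return torch.cat([cls_pos, patch_pos], dim=0)


torch.manual_seed(0)
model_dim = 8
table = torch.randn(1 + 2 * 2, model_dim)  # trained on a hypothetical 2x2 patch grid
resized = interpolate_pos_embedding(table, old_grid=2, new_grid=4)
print(resized.shape)
```

This lets a model trained at one image resolution accept a larger image, whose patch grid (and therefore sequence length) is different.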
Paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

Classification token: plays the role of a query
Split the 2D image into patches and project them to embeddings -> add position embeddings -> Transformer encoder -> MLP head
Patch + position embedding: the patches are flattened into a sequence in raster order, first left to right, then top to bottom
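The raster order described above is easy to see on a tiny example. This sketch uses a 4x4 single-channel "image" whose pixel values are simply 0..15 (an assumption made for readability) and cuts it into 2x2 patches with `F.unfold`.

```python
import torch
import torch.nn.functional as F

# A 4x4 single-channel image whose pixel values are their own raster index:
#  0  1  2  3
#  4  5  6  7
#  8  9 10 11
# 12 13 14 15
image = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
patch_size = 2

# unfold returns (bs, patch_depth, num_patch); transpose to (bs, num_patch, patch_depth)
patches = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(1, 2)

for i, p in enumerate(patches[0].tolist()):
    print(i, p)
# 0 [0.0, 1.0, 4.0, 5.0]      top-left
# 1 [2.0, 3.0, 6.0, 7.0]      top-right
# 2 [8.0, 9.0, 12.0, 13.0]    bottom-left
# 3 [10.0, 11.0, 14.0, 15.0]  bottom-right
```

Patch 1 is the top-right block rather than the bottom-left one, which confirms the left-to-right, then top-to-bottom ordering of the sequence.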
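Image2Embedding is implemented by hand; the Transformer encoder is the one provided by PyTorch:
- torch.nn.functional.unfold: extracts the sliding local blocks (patches) that a convolution would operate on
- torch.nn.TransformerEncoder: a stack of Transformer encoder layers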
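One detail of `torch.nn.TransformerEncoder` that is easy to miss is the input layout: by default the layers expect `(seq_len, bs, dim)`, and `batch_first=True` switches this to `(bs, seq_len, dim)`. The sketch below, with hypothetical sizes, checks that with `batch_first=True` the samples in a batch are processed independently of each other.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, seq_len, model_dim = 2, 5, 8  # hypothetical sizes

layer = nn.TransformerEncoderLayer(
    d_model=model_dim, nhead=2, dim_feedforward=16, batch_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=2)
encoder.eval()  # disable dropout so that outputs are deterministic

x = torch.randn(bs, seq_len, model_dim)
x_other = x.clone()
x_other[1] = torch.randn(seq_len, model_dim)  # change only the second sample

with torch.no_grad():
    out = encoder(x)
    out_other = encoder(x_other)

# The first sample's output must not depend on the second sample.
samples_independent = torch.allclose(out[0], out_other[0], atol=1e-5)
print(out.shape, samples_independent)
```

If a `(bs, seq_len, dim)` tensor is passed to a layer built with the default `batch_first=False`, the batch axis is interpreted as the sequence axis, and attention mixes different samples instead of different tokens.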
ViT:
import torch
import torch.nn as nn
import torch.nn.functional as F


# step1: convert the image to a sequence of embedding vectors
def image2emb_naive(image, patch_size, weight):
    """Build patch embeddings with unfold followed by a matrix multiplication."""
    # image shape: (bs, channel, h, w)
    # No overlap: stride == patch_size, so unfold yields the patches directly.
    # unfold output: (bs, patch_depth, num_patch), patch_depth = patch_size*patch_size*ic
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size)
    patch = patch.transpose(2, 1)  # (bs, num_patch, patch_depth)
    print(f'patch: {patch.shape}')
    patch_embedding = patch @ weight  # output embedding: (bs, num_patch, model_dim)
    print(f'patch_embedding: {patch_embedding.shape}')
    return patch_embedding


def image2emb_conv(image, kernel, stride):
    """Build patch embeddings with a strided 2D convolution."""
    conv_output = F.conv2d(image, kernel, stride=stride)  # (bs, oc, oh, ow)
    bs, oc, oh, ow = conv_output.shape
    # Flatten the spatial grid into a sequence: (bs, oh*ow, oc)
    patch_embedding = conv_output.reshape((bs, oc, oh*ow)).transpose(2, 1)
    print(f'patch_embedding: {patch_embedding.shape}')
    return patch_embedding


# test code for image2emb
bs, ic, image_h, image_w = 1, 3, 8, 8
patch_size = 4
model_dim = 8  # embedding dim
max_num_token = 16
num_classes = 10
label = torch.randint(num_classes, (bs,))
patch_depth = patch_size*patch_size*ic
# Embedding via explicit patch extraction (unfold)
torch.manual_seed(42)
image = torch.randn((bs, ic, image_h, image_w))  # random test image
weight = torch.randn((patch_depth, model_dim))  # patch_depth -> model_dim; model_dim is the number of output channels
print(f'weight: {weight.shape}')
patch_embedding_naive = image2emb_naive(image, patch_size, weight)
print(f'patch_embedding_naive: \n{patch_embedding_naive}')
# Embedding via 2D convolution
# kernel shape: (oc, ic, k_h, k_w)
kernel = weight.transpose(1, 0).reshape((model_dim, ic, patch_size, patch_size))
patch_embedding_conv = image2emb_conv(image, kernel, stride=patch_size)
print(f'patch_embedding_conv: \n{patch_embedding_conv}')
# step2: prepend the CLS token embedding
cls_token_embedding = torch.randn((bs, 1, model_dim), requires_grad=True)
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)
print(f'token_embedding: {token_embedding.shape}')
# step3: add the position embedding
position_embedding_table = torch.randn((max_num_token, model_dim), requires_grad=True)
seq_len = token_embedding.shape[1]
# Repeat the first seq_len rows of the table along the batch dimension: (bs, seq_len, model_dim)
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
token_embedding = token_embedding + position_embedding  # out-of-place add keeps autograd simple
print(f'token_embedding: {token_embedding.shape}')
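# step4: pass the embeddings through the Transformer encoder
# token_embedding is (bs, seq_len, model_dim), so batch_first=True is needed;
# the default (batch_first=False) would treat dim 0 as the sequence dimension.
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)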
# step5 do classification
cls_token_output = encoder_output[:, 0, :]
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(f'loss: {loss}')
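
The script prints both the unfold-based and the convolution-based embeddings, but it does not compare them. A short standalone check along the following lines can confirm that the two routes agree; the sizes and variable names here are chosen for illustration and re-create the setup rather than reuse the variables above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, ic, h, w = 2, 3, 8, 8
patch_size, model_dim = 4, 8
patch_depth = ic * patch_size * patch_size

image = torch.randn(bs, ic, h, w)
weight = torch.randn(patch_depth, model_dim)

# Route 1: cut out the patches with unfold, then project them with a matmul.
patches = F.unfold(image, kernel_size=patch_size, stride=patch_size)  # (bs, patch_depth, num_patch)
emb_unfold = patches.transpose(1, 2) @ weight  # (bs, num_patch, model_dim)

# Route 2: reuse the same weights as a conv kernel of shape (oc, ic, k_h, k_w).
kernel = weight.t().reshape(model_dim, ic, patch_size, patch_size)
emb_conv = F.conv2d(image, kernel, stride=patch_size).flatten(2).transpose(1, 2)

max_abs_diff = (emb_unfold - emb_conv).abs().max().item()
same = torch.allclose(emb_unfold, emb_conv, atol=1e-4)
print(emb_unfold.shape, same, max_abs_diff)
```

The match relies on `F.unfold` laying out each patch in (channel, row, column) order, which is the same order the reshape assigns to the kernel. Any remaining difference should be at the level of floating-point rounding.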
