
PyTorch Notes - Vision Transformer (ViT)


The Transformer consists of an encoder and a decoder. Its core components are multi-head self-attention (fusion across spatial positions) and the feed-forward network (fusion across feature channels).
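
A minimal sketch of these two components using PyTorch's nn.MultiheadAttention; the dimensions are arbitrary, and LayerNorm and dropout are omitted for brevity:

    import torch
    import torch.nn as nn

    # One encoder block in miniature: multi-head self-attention mixes information
    # across positions (spatial fusion); the feed-forward network mixes information
    # across feature channels (channel fusion). LayerNorm/dropout omitted.
    d_model, nhead, bs, seq_len = 8, 2, 1, 5
    self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)
    ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))

    x = torch.randn(bs, seq_len, d_model)  # (batch, sequence, channels)
    attn_out, _ = self_attn(x, x, x)       # every token attends to every other token
    x = x + attn_out                       # residual connection
    x = x + ffn(x)                         # per-token channel mixing
    print(x.shape)                         # torch.Size([1, 5, 8])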

Encoder-decoder interaction: memory-based multi-head cross-attention, in which the decoder's queries attend over the encoder output (the "memory") used as keys and values.
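
A minimal sketch of this cross-attention pattern, with illustrative sequence lengths: the decoder states act as queries, and the encoder memory supplies the keys and values:

    import torch
    import torch.nn as nn

    # Memory-based cross-attention: decoder queries attend over the encoder output.
    d_model, nhead = 8, 2
    cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)

    memory = torch.randn(1, 10, d_model)  # encoder output: (batch, src_len, d_model)
    tgt = torch.randn(1, 4, d_model)      # decoder states: (batch, tgt_len, d_model)
    out, attn_weights = cross_attn(query=tgt, key=memory, value=memory)
    print(out.shape)                      # torch.Size([1, 4, 8]): one fused vector per target position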

Position embeddings inject positional information, since attention by itself is permutation-invariant.
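
For reference, a sketch of the sinusoidal encoding used by the original Transformer; ViT instead learns a position embedding table, as in step 3 of the code below:

    import torch

    # Sinusoidal position encoding: each position gets a fixed pattern of sines and
    # cosines that is added to the token embeddings. Assumes d_model is even.
    def sinusoidal_position_encoding(max_len, d_model):
        pos = torch.arange(max_len).unsqueeze(1).float()  # (max_len, 1)
        i = torch.arange(0, d_model, 2).float()           # even feature indices
        angle = pos / torch.pow(10000.0, i / d_model)     # (max_len, d_model // 2)
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angle)
        pe[:, 1::2] = torch.cos(angle)
        return pe

    print(sinusoidal_position_encoding(16, 8).shape)  # torch.Size([16, 8])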

The amount of data a model needs is inversely related to how much inductive bias it builds in: the Transformer introduces little inductive bias, so its performance ceiling is higher, but it needs a correspondingly large amount of training data to reach it.

The concept of inductive bias comes from inductive and deductive reasoning, and it plays a key role in model design.

Typical use cases of the Transformer:

Encoder-only models: pre-trained language models such as BERT, supervised classification, and other non-streaming tasks.
Decoder-only (generative) models: autoregressive architectures such as the GPT series, used for language modeling.
Encoder-decoder models: machine translation, speech recognition, and other sequence-to-sequence tasks.

Vision Transformer (ViT):

  • DNN view: split the image into patches (Image-to-Patch) and project each patch to an embedding (Patch-to-Embedding)
  • CNN view: the same patch embedding implemented as a 2D convolution
  • Class token: a learnable placeholder prepended to the patch sequence
  • Position embedding, interpolated at inference time when the input resolution changes (see the sketch after this list)
  • Transformer Encoder
  • Classification head
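
As mentioned above, a pre-trained position embedding table can be resized when the inference resolution (and hence the number of patches) differs from training. A minimal sketch of the idea, assuming a table trained for a 14x14 patch grid adapted to a 16x16 grid; the class-token position is kept as-is, and all sizes are illustrative:

    import torch
    import torch.nn.functional as F

    # Interpolate a pre-trained position embedding table to a new patch-grid size.
    d_model, old_grid, new_grid = 8, 14, 16
    pos_table = torch.randn(1, 1 + old_grid * old_grid, d_model)  # [CLS] + 14*14 patch positions

    cls_pos = pos_table[:, :1]   # class-token position, kept unchanged
    patch_pos = pos_table[:, 1:]  # (1, 196, d_model)
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, d_model).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid), mode='bicubic', align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, d_model)

    new_pos_table = torch.cat([cls_pos, patch_pos], dim=1)
    print(new_pos_table.shape)  # torch.Size([1, 257, 8])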

Paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale


Classification token: acts as a query that aggregates information from all patch tokens; its output is used for classification.

Pipeline: split the 2D image into patches and project them to low-dimensional embeddings -> add position embeddings -> Transformer Encoder -> MLP classification head.

Patch + position embedding: patches are flattened into a sequence in raster order, left to right and then top to bottom.

Below, the image-to-embedding step is implemented from scratch; the Transformer Encoder uses PyTorch's built-in nn.TransformerEncoder.

ViT:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    # step1 convert image to embedding vector sequence
    def image2emb_naive(image, patch_size, weight):
        """
        Generate patches with F.unfold, then project each patch with a weight matrix.
        """
        # image shape: bs*channel*h*w
        # no overlap: stride=patch_size, so each sliding window is exactly one patch
        patch = F.unfold(image, kernel_size=patch_size, stride=patch_size)
        # unfold output: (bs, patch_depth(patch_size*patch_size*ic), num_patch)
        patch = patch.transpose(2, 1)  # (bs, num_patch, patch_depth)
        print(f'patch: {patch.shape}')
        patch_embedding = patch @ weight  # project to embeddings: (bs, num_patch, model_dim)
        print(f'patch_embedding: {patch_embedding.shape}')
        return patch_embedding
    
    
    def image2emb_conv(image, kernel, stride):
        """
        Generate patch embeddings with a 2D convolution (kernel size = stride = patch size).
        """
        conv_output = F.conv2d(image, kernel, stride=stride)  # bs*oc*oh*ow
        bs, oc, oh, ow = conv_output.shape
        patch_embedding = conv_output.reshape((bs, oc, oh*ow)).transpose(2, 1)  # (bs, num_patch, model_dim)
        print(f'patch_embedding: {patch_embedding.shape}')
        return patch_embedding
    
    
    # test code for image2emb
    bs, ic, image_h, image_w = 1, 3, 8, 8
    patch_size = 4
    model_dim = 8  # embedding dim
    max_num_token = 16 
    num_classes = 10
    label = torch.randint(10, (bs,))
    patch_depth = patch_size*patch_size*ic
    
    # embedding via the naive patching method
    torch.manual_seed(42)
    image = torch.randn((bs, ic, image_h, image_w))  # random test image
    weight = torch.randn((patch_depth, model_dim))  # projection matrix: patch_depth -> model_dim (model_dim = output channels)
    print(f'weight: {weight.shape}')
    patch_embedding_naive = image2emb_naive(image, patch_size, weight)
    print(f'patch_embedding_naive: \n{patch_embedding_naive}')
    
    # embedding via the 2D convolution method
    # kernel shape: oc*ic*k_h*k_w
    kernel = weight.transpose(1, 0).reshape((model_dim, ic, patch_size, patch_size))
    patch_embedding_conv = image2emb_conv(image, kernel, stride=patch_size)
    print(f'patch_embedding_conv: \n{patch_embedding_conv}')
    
    
    # step2 prepend CLS token embedding
    cls_token_embedding = torch.randn((bs, 1, model_dim), requires_grad=True)
    token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)
    print(f'token_embedding: {token_embedding.shape}')
    
    
    # step3 add position embedding
    position_embedding_table = torch.randn((max_num_token, model_dim), requires_grad=True)
    seq_len = token_embedding.shape[1]
    # replicate the position embeddings along the batch dimension
    position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])   
    token_embedding += position_embedding
    print(f'token_embedding: {token_embedding.shape}')
    
    
    # step4 pass embedding to Transformer Encoder
    encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)  # token_embedding is (bs, seq_len, model_dim)
    transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
    encoder_output = transformer_encoder(token_embedding)
    
    
    # step5 do classification
    cls_token_output = encoder_output[:, 0, :]  # classification uses only the CLS token output
    linear_layer = nn.Linear(model_dim, num_classes)
    logits = linear_layer(cls_token_output)
    loss_fn = nn.CrossEntropyLoss()
    loss = loss_fn(logits, label)
    print(f'loss: {loss}')
    