Advertisement

Self-Attention、Multi-head Self-Attention

阅读量:

个人笔记

讲解得很透彻啊!一看就会啊!视频链接

一、 理论

1. Self-Attention、Multi-head Self-Attention最终效果:

输入:X1 X2 ------self attention------ 输出 Y1 Y2

四者shape相同;

Y1是X1 X2不同权重的加权和;

Y2是X1 X2不同权重的加权和;

2. 计算过程

a1 a2 向量 WQ WK WV 矩阵

shape 1,dmodel dmodel,dk

计算公式如下:

  • 第一步:求取q k v

多个a向量拼接成矩阵;矩阵相乘并行运算速度快

  • 第二步: 求取权重系数
  • 第三步:加权相加

3. Multi-head Self-Attention

n个头,就有n组 WQ WK WV 矩阵;

相较于一个头WQ WK WV 行数不变,列缩减为原来的n分之一

最终得到n组 q k v ;

同理,相较于1个头,q k v 行数不变,列缩减为原来的n分之一

假设对a1 a2进行 Multi-head Self-Attention,头数n=2

  • 第一步:求取每一组的q k v
  • 第二步:

对每个组单独进行Self-Attention(两组互不影响)

  • 第三步:拼接
  • 第四步 :融合

二 、代码实现:

复制代码
 class Attention(nn.Module):   # 多头注意力机制

    
     def __init__(self,
    
              dim,                   # 输入token的dim     如512
    
              num_heads=8,           # 8个头
    
              qkv_bias=False,        #偏置
    
              qk_scale=None,         # VIT中为None 不用管
    
              attn_drop_ratio=0.,
    
              proj_drop_ratio=0.):
    
     super(Attention, self).__init__()
    
     self.num_heads = num_heads
    
     head_dim = dim // num_heads     # 多头其实就是分组计算再合并
    
     self.scale = qk_scale or head_dim ** -0.5     # 也就是公式中的  根号K
    
     self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)   # qkv三个矩阵都是512,512 合并在一起512,512*3
    
                                                         # qkv计算通过全连接实现的
    
     self.attn_drop = nn.Dropout(attn_drop_ratio)
    
     self.proj = nn.Linear(dim, dim)               # 完成自注意计算之后得到两个向量,   还要经过一层全连接映射
    
     self.proj_drop = nn.Dropout(proj_drop_ratio)
    
  
    
     def forward(self, x):    # 牢记输入: (批量,单词数,维度)===(B, N, C)
    
     # [batch_size, num_patches + 1, total_embed_dim]
    
     B, N, C = x.shape      # num_patches相当于输入小图片的个数,也就是单词的个数, 因为最开始要加一个标签分类,所以是 num_patches + 1
    
                            # total_embed_dim也就是dim,也就是单词的维度,eg:512
    
     # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
    
     # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]  # -3维度代表 qkv  -2维度代表不同的头
    
     # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]  # 2维度代表不同的头  0维度代表 qkv
    
     qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)   #求取Q K V 多维矩阵相乘,只需要最后两个维度匹配即可
    
     #
    
     q, k, v = qkv[0], qkv[1], qkv[2]  # 取出 q k v
    
     # q, k, v维度: [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
    
  
    
     attn = (q @ k.transpose(-2, -1)) * self.scale  # K要转置,交换最后两个维度,Q K才能相乘
    
     # 只保证矩阵最后两个维度满足矩阵乘法要求即可  前两个维度不会变[batch_size, num_heads, ... , ...]
    
     # 这个就是公式 Q * K转置 / 根号d
    
     # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]-------K转置后的维度
    
     # [batch_size, num_heads, num_patches + 1, num_patches + 1]---------(Q * K转置 / 根号K)结果的维度
    
  
    
     attn = attn.softmax(dim=-1)   #就是最后一个维度num_patches + 1做softmax
    
     attn = self.attn_drop(attn)   #
    
  
    
     #  multiply后的shape:  [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
    
     # transpose后的shape:  [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
    
     #  reshape后的shape:   [batch_size, num_patches + 1, total_embed_dim]
    
     x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    
  
    
     x = self.proj(x)  # [batch_size, num_patches + 1, total_embed_dim] * [total_embed_dim, total_embed_dim]
    
                       #  最终shape:  [batch_size, num_patches + 1, total_embed_dim]
    
     x = self.proj_drop(x)
    
     return x

上述代码说明:

num_heads=1就是Self-Attention

num_heads>1就是Multi-head Self-Attention

输出:(batch,seq_len,dim)---------------输出:(batch,seq_len,dim)

其实就是全连接,夹杂着做各种shape变换。

全部评论 (0)

还没有任何评论哟~