【AAAI 2021】基于Attention的知识蒸馏:Knowledge Distillation via Attention-based Feature Matching
AAAI 2021
AAAI 2021
AAAI 2021
AAAI 2021
- 论文地址:
- 代码地址:
- 主要问题:
- 主要思路:
- 具体实现:
- 实验结果:
论文地址:
https://arxiv.org/abs/2102.02973
代码地址:
GitHub CLOVA AI - Attention Feature Distillation
主要问题:
大多数采用基于特征的知识蒸馏方法都是通过人工构建教师和学生之间的中间特征,并采用固定形式的连接方式来实现知识传递。然而,这种人工选择往往容易导致构建出无效的连接,从而影响知识蒸馏的性能。
主要思路:
该文详细阐述了一种高效的蒸馏方法,通过注意机制实现特征连接。该方法能够充分利用教师模型的所有特征层,无需人工选择连接方式。
具体而言,作者提出了一种基于元网络的方法,该方法以注意力机制为基础,学习特征间的相对相似性,并将识别出的相似性用于调节所有教师-学生特征对的蒸馏强度。

具体实现:
定义:由教师模型生成的一组候选特征向量 \mathbf{h}^{\mathrm{T}}=\left\{h_{1}^{\mathrm{T}}, \ldots, h_{T}^{\mathrm{T}}\right\},以及由学生模型生成的一组候选特征向量 \mathbf{h}^{\mathrm{S}}=\left\{h_{1}^{\mathrm{S}}, \ldots, h_{S}^{\mathrm{S}}\right\}。每个候选特征的特征图大小和通道数量均有所差异,例如,特征向量 h 属于 \mathbb{R}^{H \times W \times d} 空间。
AFD 的主要目标是检测所有可能的组合,这些组合共有 S×T 对。通过这种方式,教师模型候选特征的知识被转移给能够识别相似性的学生模型的对应特征。
通过分析上图,我们可以发现,作者将候选特征分别采用两种池化方法,在两个不同方向上进行了对比分析。具体而言,这些方法包括全局平均池化操作和通道池化技术。此外,作者以两个全局集合特征的相似性程度作为衡量标准,对这两种池化方法进行了综合评估。
为了识别教师和学生的特征间相似性,AFD方法基于注意力机制的查询-键概念进行设计,其中教师特征生成对应的\mathbf{q}^t,学生特征则对应识别出\mathbf{k}^s。
\begin{array}{l} \mathbf{q}_{t}由f_{Q}生成,其值为权重矩阵W_{t}^{\mathrm{Q}}与ϕ^{HW}函数作用于h_{t}^{\mathrm{T}}后的乘积。 \\ \mathbf{k}_{s}由f_{K}生成,其值为权重矩阵W_{s}^{\mathrm{K}}与ϕ^{HW}函数作用于h_{s}^{\mathrm{S}}后的乘积。 \end{array}
其中,\phi^{H W}(\cdot) 是一种全局平均池化操作,f_Q 和 f_K 是激活函数,W^Q_t 和 W^K_s 分别属于 \mathbb{R}^{R \times d^T_t} 和 \mathbb{R}^{R \times d^S_s},它们是线性转换参数。
需要注意的是,这些特征各自具有独特的转换权重,具体而言,它们通过不同的层级具有各自的属性特征。低层次的视觉特征主要表征基本的几何特性,而高层次的视觉特征则关联复杂的语义关联。
然后我们就可以计算 attention 值:
\begin{aligned} \alpha_{t}=& \text{softmax}\left(\left[\left(\mathbf{q}_{t}^{\top} W_{1}^{\mathrm{Q}-\mathrm{K}} \mathbf{k}_{t, 1}+\left(\mathbf{p}_{t}^{\mathrm{T}}\right)^{\top} \mathbf{p}_{1}^{\mathrm{S}}\right) \div \sqrt{d}\right.\right.\\ &\left.\left.\cdots,\left(\mathbf{q}_{t}^{\top} W_{S}^{\mathrm{Q}-\mathrm{K}} \mathbf{k}_{t, S}+\left(\mathbf{p}_{t}^{\mathrm{T}}\right)^{\top} \mathbf{p}_{S}^{\mathrm{S}}\right) \div \sqrt{d}\right]\right) \end{aligned}
该方法采用了额外的双线性参数权重矩阵 W_{t}^{\mathrm{Q}-\mathrm{K}} \in \mathbb{R}^{d \times d},并设置了两个位置编码向量 \mathbf{p}_{t}^{\mathrm{T}} \in \mathbb{R}^{d} 和 \mathbf{p}_{s}^{\mathrm{S}} \in \mathbb{R}^{d}。
因为query-key是从不同的维特征中提取出来的,作者通过双线性权重机制生成了不同来源等级的注意力值,编码信息则用于共享不同实例的共同特征。
\alpha_t 用于表示第t个教师特征与整个学生特征之间关系的注意向量,从而使得教师特征 h^T_t 能够选择性地将知识转移到学生特征中。
最终的蒸馏损失就可以写作:
该损失函数\mathcal{L}_{\mathrm{AFD}}的定义可表示为:\mathcal{L}_{\mathrm{AFD}}=\Sigma_{t} \Sigma_{s} \alpha_{t, s}\left\|\tilde{\phi}^{C}\left(h_{t}^{\mathrm{T}}\right)-\tilde{\phi}^{C}\left(\hat{h}_{s}^{\mathrm{S}}\right)\right\|_{2},其中\Sigma_{t} \Sigma_{s}表示对所有时间步t和序列位置s的双重求和。
其中,\tilde{\phi}^{C}表示经过L2归一化的通道平均池化层的组合函数,\hat{h}_{s}^{\mathrm{S}}是h^S_s经过上采样或下采样处理后得到的特征,以使特征映射的尺寸与教师特征的尺寸一致。
实验结果:


