Advertisement

【AAAI 2021】零样本知识蒸馏:Data-Free Knowledge Distillation with Soft Targeted Transfer Set Synthesis

阅读量:

AAAI 2021

AAAI 2021

  • 论文地址:

  • 主要问题:

  • 主要思路:

  • 主要贡献:

  • 具体实现:

    • 基本符号:
    • 具有多元正态分布的特征空间建模:
    • 软目标标签生成:
    • 样本生成:
    • 基于伪样本的知识蒸馏:
  • 实验结果:

  • 联系作者:

  • 我的公众号:

论文地址:

https://arxiv.org/abs/2104.04868

主要问题:

通过生成模拟数据来模仿原始数据集的特性是实现无数据 KD的核心思路。

主要思路:

这篇文章深入探讨了一种基于无数据条件的KD方法,在研究过程中发现该方法能够有效实现教师中间特征空间的建模,并通过多元正态分布模型构建教师中间特征空间;随后通过该分布生成的软目标标签创建伪样本集作为转移集的数据来源。

此方法与传统的硬最大空间相比,在较低层次上的A类分布被建模后可被归类为更具包容性的A类目标。

主要贡献:

  1. 研究者基于多元正态分布构建教师中间层特征空间模型,并通过从该分布中采样生成优化伪样本以提升合成样本质量。
  2. 研究者通过较低层次输出分布模型而非直接构建模糊最大空间模型来获得更广义的软目标以提高性能指标。
  3. 通过多组基准网络架构与数据集测试与对比分析发现所提出算法相较于现有方法具有显著优势。

具体实现:

基本符号:

假设知识蒸馏中,教师模型和学生模型的输出分别是:

The probability distribution P_T^{\tau} equals the transformation function T applied to x with weight matrix W_T and parameter \tau, calculated as \operatorname{softmax}(a_T/\tau). Similarly, the probability distribution P_S^{\tau} equals the transformation function T applied to x with weight matrix W_S and parameter \tau, calculated as \operatorname{softmax}(a_S/\tau).

这样我们就可以将蒸馏损失写作:

\mathcal{L}_{KD}被定义为两个交叉熵损失之和:一个是基于时间τ的条件概率分布之间的对比度损失项\mathcal{L}_{CE}(P_T^\tau, P_S^\tau);另一个是基于源域数据的分类器预测结果与真实标签之间的对比度损失项\lambda_c\mathcal{L}_{CE}(P_S,y)

当缺乏原始训练数据集时,我们不得不生成伪样本作为信息传递媒介(通常称为传输集)。

在本研究工作中, 研究者构建了基于多元正态分布的方法来表征教师模型的中间特征. 我们从数据集中采用该方法生成样本用于特征提取.

然后我们基于固定教师模型的反向传播过程来优化噪声输入,并在此过程中生成伪样本。该方法旨在最小化要优化的输入与其对应的软最大输出和生成软目标之间的差距。

最后通过一个标准的KD过程,用伪样本训练学生模型

具有多元正态分布的特征空间建模:

定义 s^l = \\left\\{ s_1^l, s_2^l, \\cdots, s_k^l \\right\\} 服从概率分布 p(s^l) 并被建模为多元正态分布 s^l \\sim \\mathcal{N}_k(\\bm{\\mu}, \\bm{\\Sigma}) ,其中均值向量 \\bm{\\mu}\\in\\mathbb{R^k} 和协方差矩阵 \\bm{\\Sigma}\\in\\mathbb{R^{k\times k}}

从统计学角度来看,总体协方差矩阵\mathbf{\Sigma}=\left(\sigma_{i j}\right)_{k \times k} 可以表示为相关矩阵\boldsymbol{R}=\left(\rho_{i j}\right)_{k \times k} 的函数。其中假设总体协方差矩阵\boldsymbol{D}=\operatorname{diag}\left[\sqrt{\sigma_{11}}, \sqrt{\sigma_{22}}, \cdots, \sqrt{\sigma_{k k}}\right]包含主对角线元素的平方根向量。由此可得:当计算出协方差矩阵后,则可以通过该公式进一步推导。

\Sigma=\boldsymbol{D} \times \boldsymbol{R} \times \boldsymbol{D}

也就是:

\sigma_{i j}=\rho_{i j} \sqrt{\sigma_{i i} \sigma_{j j}}, \quad i, j=1,2, \cdots, k

相较于协方差矩阵\mathbf{\Sigma},系数矩阵\boldsymbol{R}更加直观地作为统计量,在衡量随机向量中各分量之间相关程度方面具有显著优势,并且能够通过教师模型的权值参数有效提取相关信息。

为了便于分析的过程,在此处作者进行了构建,并在教师模型中构建了全连接层的特征空间表示

假设我们关注于构建教师模型第 l 数层特征空间,则可推出该权重矩阵即为教师模型在该数层所学习到的基础模板。

该模板通过将第(l-1)层神经元与其后续第l层神经元进行配准,并阐述了两层特征图间的关联:当特征空间中的某个元素达到峰值时,则表明对应的权重会激发其响应;相反地,在关系不协调的情况下,则可能导致某些特征单元数值降低。

据分析可知,在该层的权值中蕴含着特征空间元素间的内在关联被隐式地编码。通过计算这些权重矩阵中的数值变化来推断特征间的关系。

The formula R(i,j)= \rho_{ij}= \frac{\boldsymbol{w}_i^{T} \cdot \boldsymbol{w}_j}{||\boldsymbol{w}_i|| \cdot ||\boldsymbol{w}_j||} determines the cosine similarity between vectors \boldsymbol{w}_i and \boldsymbol{w}_j, which quantifies their directional correspondence.

在无标注数据的 KD 场景下,在给定 \boldsymbol{D} 中每个变量 \sigma_{i i}(i=1,2,...,k) 的方差判定较为复杂的情况下,则视之为一个超参数。

而且 \sigma 可以被视作衡量特征空间密度的关键参数,在这种情况下较大的 \sigma 会引导样本集中在特征表示的一个或多个组件上;相反地,在 \sigma 减小时的情形下,则会导致样本趋于呈现较为均匀的分布状态

软目标标签生成:

如果建模最后一层(软最大空间算法),那么软目标标签可以写作:

对于每一个样本s^l和温度参数\tau来说,在经过计算后得到的结果为\mathbf{y}^{\texttt{soft}} = \texttt{softmax}\left( \frac{\mathbf{s}^{l}}{\tau} \right)

而对于前面的层(也就是特征层),其软标签通常写作:

\boldsymbol{y}^{\mathrm{soft}}=T\left(\boldsymbol{s}^{l}, W_{T}^{l+}, \tau\right)

样本生成:

随后我们从候选列表中随机选择 n 个软目标 \boldsymbol{y}_{\boldsymbol{i}}^{\text {soft }}(i=1,2, \cdots, n) 并生成一批具有代表性的噪声样本 \hat{x_i} 这些输入信号会被送入固定 teacher 模型中进行推断运算 最后系统自动生成预测结果向量 \hat{y}_i

然后通过递归反向传播机制对随机干扰信号进行优化处理,并降低 y_{i}^{\text {soft }}\hat{y_i} 之间的差异程度:

损失函数d被定义为KL散度(软标签向量与预测概率向量之间的差异)。

当输入训练样本时,在训练良好深度神经网络中通常会产生较高的激活值,并且为了提高图像合成效果,在卷积层 x_{\text{conv-1}} 的输出中引入了额外的激活损失项。

\mathcal{L}_{a}=-\frac{1}{n} \sum_{i=1}^{n}\left\|x_{\mathrm{conv}-1}^{i}\right\|_{1}

基于伪样本的知识蒸馏:

最后,我们使用生成的伪样本作为转移集,用如下损失函数训练学生模型:

该损失函数被定义为交叉熵损失操作被定义为T操作和S操作的组合。

整个算法的伪代码如下:

在这里插入图片描述

实验结果:

在这里插入图片描述

联系作者:

微信号:Sharpiless

作者的其他主页:

B站:https://space.bilibili.com/470550823

:<>

AI Studio:https://aistudio.baidu.com/aistudio/personalcenter/thirdview/67156

Github:https://github.com/Sharpiless

我的公众号:

全部评论 (0)

还没有任何评论哟~