Advertisement

《Distilling the Knowledge in a Neural Network》阅读

阅读量:

https://arxiv.org/abs/1503.02531
Hinton, J.Dean, NIPS 2015 引用量-3144

摘要

通过在相同的数据集上训练多个不同的模型并取其预测结果的平均值来实现对大多数机器学习算法的一种简单改进。不幸的是,在大规模部署时采用单一完整模型会导致冗长笨拙且计算资源消耗高。参考文献[1]提出了一种方法来将集成模型的知识浓缩为一个单一的易于部署的模式在此基础上我们进一步发展了这一概念采用了另一种压缩技术。我们在MNIST数据集上的实验结果表明通过将集成模式浓缩至单一模式的方式显著提升了传统笨重部署模式下的性能。此外我们还提出了一个新的集成框架这种由一个或多个完整模式以及多个专家模式组成的新型结构能够专门学习区分那些复杂混淆问题中的细粒度类别差异这些子级分类器能够高效地并行训练

1 介绍

在大规模机器学习场景中,我们通常采用相近的架构进行模型训练与部署。尽管这两个步骤存在显著差异:在训练阶段是从一个巨大的、高度冗余的数据集中提取特征;然而,在向众多用户提供服务时,则面临更为严格的资源与时间限制。为了提高数据特征提取的效果,则有必要构建一个笨重的大规模模型;这个大模型既可以由多个子模型集成而成,也可以采用一种带有强正则化的单一架构设计。完成这一过程后,则通过独特的训练方法(称为"蒸馏")将大模型的知识迁移到小规模部署版本上。这一概念阻碍了这一领域研究进程:我们倾向于依赖预训练参数来捕获训练模型的知识

对于大型模型而言,其主要任务是从多种类别中识别出数据所属的类型,并以最大化正确类别被预测的概率为目标。然而这会带来一个潜在的影响:错误类别被分类的概率。尽管对于正确类别来说远低于其他情况,在比较不同错误类别被分类的概率时仍存在显著差异。这些错误类别的预测概率不仅揭示了模型的行为模式,并且还能够提供关于数据分布的重要信息。例如,在处理一张宝马车图片时,其被误认为是‘垃圾车’的概率远低于将其误判为‘胡萝卜’的情况。

众所周知,在处理新数据时实现良好的泛化能力是我们训练的目标。然而,在具体实现这一目标的过程中却没有明确指导的信息可循。但在小规模模型的情况下,则可以通过将大模型的知识转移到其身上来实现这一目标。具体而言,在这种情况下意味着当大模型展现出良好的泛化能力——例如通过集成多个不同版本以增强其表现时——将其所掌握的知识被小型模型吸收后所形成的小型知识体系必定会比直接基于原始训练数据独立训练出的小型知识体系更为高效。

将大模型的知识传递给小模型的一种显而易见的方式在于利用大模型生成的概率分布作为小模型的学习目标。在迁移过程中既可以基于相同的训练集展开学习也可以建立独立的迁移集来进行优化。如果大模型是由多个基础模型组成则可通过对各子模型预测结果的不同聚合方式来生成soft targets例如通过将所有简单模型的预测结果进行算术平均或几何平均计算得到soft targets。由于soft targets具有较高的熵值在信息论中相比hard targets(即标签)能够提供更多关于数据分布的信息。此外在微分训练过程中 soft targets所带来的信息量更大且梯度变化较小从而使得小规模模型能够在较少的数据样本和较大的学习率下实现高效的收敛和性能提升。

对于MNIST这类常见任务而言,在这种情况下大模型几乎总能达到显著高的置信度得出正确的答案。

用于小模型的迁移数据集全部由未标记的数据构成[1]或者我们可以直接采用原始训练集。研究发现,在这种情况下效果显著。特别地,在这种情况下引入一个新的目标函数能够更好地促进小模型对真实标签(hard targets)的学习,并与大模型提供的软标签(soft targets)相协调。通常情况下,在这种情况下小模型难以准确学习软标签(soft targets),反而会偏离至正确答案上。

2 蒸馏

在存在真实标签的情况下,在进行蒸馏网络的训练时采用真实标签能够显著提升效果。一种常见的做法是通过调整soft targets来利用真实标签的信息;然而我们发现一种更为优化的方法即直接对软目标和硬目标进行加权平均。

通过预先训练一个teacher网络来实现,并将其输出结果q设为目标;随后通过训练student网络使其输出结果p尽可能接近q;从而得出损失函数的形式为
L=CE(y,p)+\alpha CE(q,p)

这里CE代表交叉熵(Cross-Entropy),其中y表示真实标签以one-hot编码的形式呈现,而qp分别代表教师型网络与学生型网络的输出结果。研究者指出,在这一框架下采用较小权重时效果更为理想。

但是q直接基于teacher网络的分类结果以及经过softmax处理后的输出进行计算,并不能很好地反映类别间的相似度信息。为了克服这一局限性,我们提出了一种改进的Softmax函数:
q_{i}=\frac{\exp \left(z_{i} / T\right)}{\sum_{j} \exp \left(z_{j} / T\right)}
其中q_i代表student网络的目标(soft标签),而z_i则是 teacher网络在 softmax 层之前(即logit层)的输出结果。
在这个改进版本中:

  • 当温度参数 T=1 时,默认采用标准Softmax函数;
  • 当温度参数趋近于 0 时(即 T→0^+),输出结果趋近于onehot编码形式;
  • 当温度参数增大时(即 T>1),输出分布趋于更加均匀;
  • 当温度参数趋向于无穷大(即 T→+\infty)时,默认采用均匀分布策略以平衡各类别间的相似度信息。

基于给定的Lq值进行student网络的参数更新。在训练过程中, student网络采用了与 teacher一致的 teacher consistency策略, 而完成后的 forward pass采用了 teacher一致 T=1。

2.1 Distillation在特殊情况下等于直接匹配logits

如何证明这一结论?当T值极大时对交叉熵进行求导运算,并对结果进行分析。q_i等于1除以T乘以括号内e的(z_i除以T)次方除以所有j对应的e的(z_j除以T)次方之和。

如何证明这一结论?当T值极大时对交叉熵进行求导运算,并对结果进行分析。q_i=\frac{1}{T}\left(\frac{e^{\frac{z_i}{T}}}{\sum_{j} e^{\frac{z_j}{T}}}\right)

p_i=\frac{1}{T}\left(\frac{e^{\frac{v_i}{T}}}{\sum_{j} e^{\frac{v_i}{T}}}\right)

其中vi是cumbersome model产生额logits,相当于真实分布。zi是distilled model.
\text { cross entropy }=-\sum_j p_{j} \times \log q_{j}
求导,得到
\frac{\partial C}{\partial z_{i}}=\frac{1}{T}\left(q_{i}-p_{i}\right)=\frac{1}{T}\left(\frac{e^{\frac{z_i}{T}}}{\sum_{j} e^{\frac{z_i}{T}}}\right)-\frac{1}{T}\left(\frac{e^{\frac{v_i}{T}}}{\sum_{j} e^{\frac{v_i}{T}}}\right)
T非常大的时候distillation优化的目标等价于Caruana提取的对logits的平方误差求最优化
\frac{\partial C}{\partial z_{i}} \simeq \frac{1}{T}\left(\frac{1+\frac{z_i}{T}}{N+\frac{\sum_j z_j}{T}}-\frac{1+\frac{v_i}{T}}{N+\frac{\sum_j v_{j}}{T}}\right)
当我们假设logits是zero-means,则
\sum_j z_j=\sum_j v_j
所以
\frac{\partial C}{\partial z_{i}} \approx \frac{1}{N T^{2}}\left(z_{i}-v_{i}\right)

实验

3 MNIST实验

基于MNIST数据集,在采用较大规模的网络架构进行训练时出现了67个测试集错误样本;而当转而使用较小规模的网络模型进行训练时,则导致测试集中误分类样本减少至146个。通过将软标签引入目标函数作为正则项,在这一过程中可观察到测试集误分类数量进一步减少至74个左右的结果出现。这一结果进一步验证了教师模型成功地将知识转移至学生模型中,并由此实现了分类性能的提升

4 speech recognition实验

第二个实验是在speech recognition领域,使用不同的参数训练了10个DNN,对这10个模型的预测结果求平均作为emsemble的结果,相比于单个模型有一定的提升。然后将这10个模型作为teacher网络,训练student网络。得到的Distilled Single model相比于直接的单个网络,也有一定的提升,结果见下表:

5 在大型数据集上训练专家集成模型

构建一套完整的集成模型可借助并行计算这一高效工具。然而,在测试阶段所需计算资源非常庞大;而蒸馏技术则可有效缓解这一问题。但集成模型仍存在另一个局限性:当单个神经网络规模庞大且所使用的数据集同样 colossal时;此时所需的训练时间也会显著增加;即使采用并行计算策略。

本节阐述了一个大数据集作为案例进行分析,并采用了系列专注于特定子数据集的专业化模型架构。这些方法成功地解释了如何有效地降低集成学习过程中的计算负担。通过引入软标签技术处理这一主要问题时态特征和潜在挑战时态特征之间的关系模式。

5.1 The JFT dataset

基于

Google的训练用了两个并行方法:

  1. 采用了大量复制的神经网络,并将其分别部署在各个不同的核心单元以及 mini-batch 中使用。
  2. 每个复制网络则会将其自身具有的神经元均匀分布于各自占据的核心单元中。
  3. 集成训练是并行训练中的第三个层次,在前两种方法的基础上。
  4. 只有当拥有大量核心资源时才可行。
  5. 但若核心资源较为匮乏,则要在数年内完成训练任务显得不切实际;
  6. 因此我们亟需一种更为高效的方法来提升现有基准模型的表现。

5.2 Specialist Models

通过并行训练多个模型并构建集成模型是一种有效的方法;然而这会消耗大量资源尤其是当单个模型较大时;研究者认为每个单独的模型应专注于一个特定类别;每个这样的专家型模型都具备特定的优势;但这种做法容易导致过拟合问题;为此他们采用了蒸馏技术结合软标签策略来防止过拟合现象的发生

  • 将容易混淆的数据样本分为不同的易混类别。
  • 未被关注到的所有其他类别被统一归为一个dustbin类别。
  • 其权重由通用模型(基于全部训练数据)进行初始化。
  • 其权重在训练过程中采用了一半来自自身领域知识的数据样本和另一半来自随机采样的其他领域数据。
  • 经过训练后,在调整偏移后的训练集中提升dustbin类别的logits值,并将其对应于 expert 类别 的抽样比例取其自然对数。

5.3 给专家模型分配类别子集

为了强调通用模型中容易出现混淆的部分, 我们特别关注那些经常被这类模型所误判的对象. 采用这种方法的原因是为了避免因过于依赖单一指标而导致分类结果不够准确, 而通过引入更加灵活且实用的替代方案来提升整体分析效果.

为通用模型的预测结果构建了协方差矩阵,并采用某种聚类方法对其进行分组。那些常见于同一组预测任务中的类别被归为一类,并命名为特定专家模型m所对应的类别集合S^m。通过对协方差矩阵各列的数据输入到k-means算法中进行迭代优化,在有限次迭代后获得较为理想的分类效果。

在这里插入图片描述

5.4 用专家集成模型进行推理

评估一下专家集成模型的表现如何。除此之外,在所有数据集上我们也建立了通用模型。这样一来,在没有专家可用的情况下我们也能完成相应的分类任务。

给定输入图像X:
首先利用通用模型确定n个可能性最高的类别k,在此设定中取n=1。
接下来确定涵盖对应类别k的所有专家模型集合Ak。
然后求解得到每个类别对应的全概率分布q_min(即预测的概率值)。

Eq.5
K L\left(\boldsymbol{p}_{\text {gen}}^{g}, q\right) + \sum_{m \in A_{k}} K L\left(\boldsymbol{p}_{\text {expert}}^{m}, q\right)
其中 K L 表示 KL 散度指标,在本研究中用于衡量两个概率分布之间的差异程度。\boldsymbol{p}{\text {gen}}^{g} 是全局生成器的概率密度函数估计值。\boldsymbol{p}_{\text {expert}}^{m} 是每个专家分类器对应的条件概率密度函数估计值。\boldsymbol{p}{\text {expert}}^{m} 的定义域包含了所有可能的目标类别以及一个额外的聚类中心点(即所谓的 "dustbin" 类)。因此,在计算该聚类中心点与 q 之间的 KL 散度时,则评估了 $q 在该聚类中心点下的各类别概率总和。

Equation 5 does not generally possess a closed-form solution, although when each model outputs a single probability for every class, the solution simplifies to either an arithmetic mean or a geometric mean, depending on whether we employ the KL divergence from p to q or from q to p. We parameterize q = \text{softmax}(z) (with T = 1) and optimize the logits z with respect to Equation 5 using gradient descent. Note that this optimization must be performed for each individual image.

5.5 结果

从基准线开始进行训练;专家模型的学习效率极高;仅需几天即可替代原本耗时几周的过程;且均为完全独立的学习过程。

在这里插入图片描述

在上图中展示的结果表格3与之相对应的是对比实验结果。条件测试仅涉及专家类别,在预测过程中也同样局限于这一类别,并且其准确性表现出了较高的水平。

在一项名为JFT的专家实验项目中, 我们开发了多达百余家专业领域的专家模型, 每个模型都包含300个类别(其中包括dustbin类别)。由于各 expert 所涉及的 categories 存在交叉, 因此, 在实际应用中, 一个特定的 image category 通常会被多个 expert 所涵盖。

表4展示了测试集中的样本数量及其在位置1上的正确率变化情况,并按涵盖不同类别的专家细分分析了JFT数据集在top1分类任务上的准确率较之前的提升幅度。

例如,在JFT中包含着共计350,037条数据样本。这些样本中未被任何 expert model 覆盖的共有141,993条样本;仅被一个 expert model 所涵盖的约有9,082条样本;而约9,082条样本已被至少十个 expert model 所涵盖。第③/④列则体现了通过使用 expert model 达成的准确率与精确度较之前 top 1 的提升程度。

当我们在一个特定的类上部署更多专家时,在这个领域内我们通常会看到更高的精度提升。这带来了积极的影响,并且由于它们易于并行训练且实现相对简单

6 用soft targets当正则项

基于软目标而非硬目标的核心观点认为,在软目标中能携带大量有用信息,并且软目标无法由单一硬目标编码。在本节中,我们通过仅使用少量数据拟合前面所述的基础线性模型(拥有85 million参数)展示了这一显著的效果。

在这里插入图片描述

如表5所示, 该方法仅包含约2千万个实例的数据样本, 即使如此, 使用hard targets进行训练也会导致明显的过拟合现象(我们在达到44.5%后立即终止)。然而, 使用soft targets进行训练则能够恢复近似全部的信息量。值得注意的是, soft targets系统最终收敛至约为57%, 这些结果表明, 采用soft targets的方法能够显著提高性能。

6.1 soft targets 防止专家模型过拟合

在我们的实验研究中,在JFT数据集上,这些被归类为non-expert class的样本被集中到同一个dustbin类别中。如果我们让这些专家能够在每个classes上都具备完整的softmax机制,则或许可以找到一种更为有效的策略来避免发生overfitting现象的发生。

一个专家模型接收的数据在其 expert 类别中高度丰富。这表明其 training set 的有效规模较小,并且容易严重地 overfit 到其 expert 类别。问题无法通过缩小该 model 来解决,因为这样做会导致我们无法建模所有非 expert 类所能提供的宝贵信息。

基于仅3%语音数据开展的实验结果表明,在采用了通用权值初始化的情况下,并不仅限于采用硬标签方式进行训练之外

7 Relationship to Mixtures of Experts

8 讨论

全部评论 (0)

还没有任何评论哟~