《Distilling the Knowledge in a Neural Network》论文阅读笔记
《Distilling the Knowledge in a Neural Network》
Abstract:
对于很多大型的机器学习任务来说,集成是一个非常简便就能提高表现的方法,但是集成方法计算量太大。Caruana等人说明了将一个大型网络(集成的模型)中的知识压缩到一个便于计算的小模型中是可取的,Hinton等人提出了一种不同的知识压缩的技巧—知识蒸馏。作者在MNIST上将一个商用网络的知识蒸馏出来应用于一个小模型后,效果有了显著的提高。
Introduction:
基于幼虫理论—“许多昆虫小时候学习是通过吸收环境中的养分,而长大后则是基于小时候学到的知识,通过一次次的实践来调整”。大型神经网络在训练阶段可以(需要)在“超算中心”建立一个很Robust的模型,输入很多很多的数据来进行拟合。但是在部署阶段,不同的场景下都有着各种各样的需求,算力都没有那么大。作者提出了一种能够将已经训练好的大模型的知识迁移到一个容易部署的模型中去的训练方法—“知识蒸馏”。
那么到底什么是原来模型中所学到的知识呢 ,人们在这之前一般的理解就是模型中的每个权重和偏差的数值以及模型整体的架构,我们怎么去改变模型的整体架构却要同时保持着相同的知识呢,作者在原文中提出了这么一段话表达了读者应该产生的一个疑问:
We tend to identify the knowledge in a trained model with the learned parameter values and this makes it hard to see how we can change the form of the model but keep the same knowledge.
这个问题解答是,作者将一个神经网络其中所蕴含的知识表达为输入向量和输出向量之间的映射关系 。一般的大型神经网络的训练目标—最小化损失函数,其实也就是等价于极大正确答案的似然函数(MLE is equal to minimize the loss function),但同时这个操作会导致模型将一些极小的概率赋值给其他错误的答案,其中有一些也会比一些类别的概率很大。因此,**这种概率分布就显示了一个很大的模型是如何generalize分配所有类别的概率的。**换句话说,模型对于所有类别(也许特别是错误的类别)所分配的概率是十分重要的。
文章中提到,如果我们所采用的训练数据能够很好的归纳客观世界的信息,那么理论上训练模型就可以很好的去归纳 出客观世界的信息。在蒸馏过程中,我们需要让小模型采用和大模型一样的归纳 方式。
if the cumbersome model generalizes well , then we can train the small model to generalize in the same way as the large model
那么问题又来了,如何让小模型拥有和大模型一样的归纳方式呢?做法是将大型模型产生的类别概率分布作为小模型的"soft targets"(如果大型模型是集成的模型的话,我们可以取所有集成模型的算术均值或者几何均值)。关于数据我们可以仍然使用之前大模型的训练集或者单独划分出来一个 迁移集。这里还提到了,如果产生的"soft targets"拥有很大的熵的话(即接近均匀分布),我们训练小模型的时候就可以使用很大的学习率以及更少的数据。从信息论的角度看软标签,那些对于一些很简单的模型所产生的"soft target"来说,其大量的信息存在于那些对于错误类别分到的很小的概率值的比率 。所以这里再次强调了,真正有用的信息不是对于正确的类别有多确定,而是其对错误类别极小概率值的分配情况,但是有一个问题就是这些很小的概率值可能对于交叉熵而言并不会影响很大,如果只是单纯的使用Softmax的话提取的到信息可能不太理想。
因此这里又产生了新的问题,怎么才能让模型取学到这些有用的信息呢?一个解决方法是不用Softmax所产生的概率值,而是用网络倒数第二层输入进Softmax的 logits。第二解决方法就是使用蒸馏,即在最后Softmax的时候提高温度T直到产生一个合适的概率值,之后训练小模型的时候也用相同大小的温度值。这里第一种方式其实是第二种方法的特例。
Distillation
q_i = \frac {exp(z_i/T)} {\sum_j{exp(z_j/T)}}
最简单的知识蒸馏过程是在训练大模型进行Softmax的时候对每一项除一个温度T,这里就相当于把网络之中的信息进行了一个“升温”以及便达到蒸馏的条件。To be specific, T越大,产生的Soft Target就越"Soft"。之后通过这些Soft target训练小模型,同时也使用同样温度的T进行“升温”。但在测试阶段T=1。
除了Soft Target之外,加上正确的one hot标签也会很大的提高模型的准确性。总结下来,就是将Loss Function 分为两部分,一部分是Soft Target的损失(温度为T),一部分是Hard Target的损失(温度为1)。有了两项损失,很自然的就想到他们的权重应该怎么分配,作者发现给第二项损失(即Hard Target)一个相对来说比较小的权重的话效果会比较好。此外,为了在改变蒸馏温度T时还能保持数据的原有分布,对所有的梯度 应该乘上T^2。
这里再细微的讨论一下T的取值情况。取T是为了让交叉熵损失也可以注意到那些具有很小概率的分类以让知识转移到小模型时有用。我们可以想象T较小时,数据相对来说还是偏向于之前的分布,蒸馏过程对于那些概率值较低的类的注意力就不会太大;相反,T如果取得很大,数据之间的差异性就可能减少。所以综合下来T要去一个相对来说比较中间的值。
Experiments on MNIST
作者在MNIST上分别训练了两个网络,其中一个为Teacher 网络,包含两层隐含层,采用了DropOut强力正则化,在测试集上有67个 错误;另外一个为Student网络,在测试集上有146个 错误;使用知识蒸馏提取大型网络的知识后,学生网络在测试集上的错误降低到了74个。
This shows that about how to generalize that is learned from translated training data even though the transfer set does not contain any translations.
这里有一个有趣的现象是,作者将迁移数据(Transfer Set)中所有是3的样本去掉之后训练知识蒸馏网络。学生模型仅仅取得了206个错误,总共有1010个“3”的数据只有133个错误,这一表现就证明了对于”负例“分配概率的重要性。
Most of the errors are caused by the fact that the learned bias for the 3 class is much too low
如果把bias增加到3.5,模型整体分错109例,其中只有14例在3上面,这就表现了只要bias是正确的,尽管模型从来没有见过”3“,其表现的效果也不错。同时,如果transfer set 只包含7 和 8 ,从上面的结论上来看网络对于7和8的bias肯定会很大,错误率也在47.3%左右,但人工减少了Bias之后,错误率就降低到了13.2%。 which is quite amazing。
Experiments on speech recognition
作者还在当时State-of-the-art的语音识别任务上做了实验,证明了应用知识蒸馏这种方法将集成模型中的知识转移到一个小模型中比比单单只训练一个小模型效果要好的多。
原来的模型是通过DNN将声波信号转换为一个可供HMM隐状态使用的概率分布,然后通过HMM解析出文字,这里HMM充当的是语言模型的角色。作者使用了一个8层隐藏层的神经网络作为Baseline,在此基础上训练了10个模型然后集成在一起,之后再从集成的模型之中进行知识蒸馏。最终结果如下:

可以看到最终结果和预期的一致,即—集成大模型>蒸馏出来的模型>Baseline。
Training ensembles of specialists on very big datasets
训练集成模型是一种常用的增加正确率以及从平行计算中获得好处的做法,其在测试阶段面临的大量计算量的问题也可以通过蒸馏来解决。但若遇到原本单个模型就很复杂,数据集也很大,这是单个训练的过程也会很慢。对于这种情况,作者这里采用了针对每一个比较Confuse的子类独立的去训练特定的模型。
如果一个数据集很多个类别,我们一般会采取训练一个”全局的模型+Several针对每个类别的模型“,针对每个类别的模型只需要在Softmax时候将所有他不关心的类别归为一类就可。此外作者针对过拟合采取了一一系列手段,如初始化时的参数与全局模型一样、训练集采用一半特定的数据+一半随机的数据、训练后调整一下负类(不关心的类)的比例以防止过采样。
对于一个数据图片,作者会做一个自顶向下的分类,具体步骤如下:
首先在全局的模型上找到最可能的n类,将这n类称为k。在实验中n=1
然后取与k这个类别有着非空交集的所有特定模型的活跃子集A^k,然后找一个全局的概率分布q:
Min\ KL(p^g, q)+\sum_{m\in A_k }KL(p^m,q)
p^g,p^m分别代表着全局模型和特定模型上面的概率分布
最终的结果部分:

Soft targets as Regularizers

我们都知道数据量过少会导致过拟合。从上图中可以看到,第二行只采用3%数据的时候测试误差下降了很多(这里作者还是使用了Early Stopping的)。但第三行只采用3%数据但同时采用了Soft Targets训练出来的网络测试误差竟能达到57%(without early stopping)。
This shows that soft targets are a very effective way of communicating the regularities discovered by a model trained on all of the data to another model
Conclusion
这篇论文的核心思想就是采用Soft target的知识蒸馏,去把一个大的集成模型上学习到的东西高效快速的转移到一个易于实施和部署的小模型上面去。同时个人印象比较深刻的一点就是作者指出错误标签概率分布的重要性,其也是解决过拟合的一个重要途径(从一个角度来说网络会产生过拟合的原因就是原始数据的标签是独热的,只包含着0&1这类信息导致数据小时网络会”妖魔化“的倾向于此而不考虑整体效应,而Soft target作为标签则是一个熵更大的标签)。
如有错误地方请指正。
