知识蒸馏经典论文阅读
此篇Distilling the Knowledge in a Neural Network 被视为知识蒸馏领域的开创性论文,在该领域首次提出知识蒸馏的思想
整体的论文研究动机如下:
- 模型在工业部署中对实时性能和计算资源有较高要求,在移动终端设备等场景下需在部署成本最小化的同时实现快速准确的预测结果。
- 为了提高模型的预测精度,在现有研究中多采用集成学习的方法进行改进工作,在保证复杂度的前提下实现了多模型协同完成预测任务。
- 在训练过程中将概率值分配给所有可能的答案选项,并非只关注高置信度的答案。
- 尽管其他类别答案的概率值极其微小且接近于零,在交叉熵评估标准下其对整体评估指标的影响非常有限。
举例说明中,请考虑以下步骤:通过训练一个CNN模型(全连接神经网络),给定一张宝马车图片作为输入样本后,在图像处理模块中完成特征提取和分类任务;经过一系列参数优化后,在测试集上达到了85.2%的准确率。
| label | 概率 |
|---|---|
| 宝马车 | 0.90 |
| 垃圾车 | 0.09 |
| 胡萝卜 | 0.01 |
我们观察到,在CNN模型经过训练之后表现出高度确定性地将图片归类为宝马汽车。相比之下,在被误判为垃圾汽车的情况下虽然概率较低(尽管概率很低),但其可能性是将某物误判为胡萝卜的九倍。此时的模型能否推断出它不仅能够正确识别宝马汽车这一类别,并且还具备识别其他不同种类的能力?在这种情况下,在进行模型压缩时(即进行模型精简),如果小型模型能够学习到这种潜在的知识而非仅仅依赖于答案记忆,则该模型的泛化能力将会得到显著提升。这正是本文的核心观点。
论文要点
神经网络主要依赖于通过logit转换后的值作为输入到softmax后端处理单元中来计算各类标签的概率值。Teacher模型的输出在经过softmax函数进行处理后会乘以指数e从而放大各分类之间的距离这一操作使得最终生成的概率分布呈现出类似于one-hot编码的结果这种结构在一定程度上限制了student模型的学习能力
因此,Hinton基于蒸馏温度的概念进行了相关研究. 通过计算z_i后,将每个label的概率转换为q_i.
q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}
蒸馏温度通常设为1,在常规的softmax操作中;
当温度较低时会加剧错误分类的概率,并导致额外的干扰。
根据损失函数求导的结果可知,
\frac{\partial C}{\partial z_i} 的计算公式为 \frac{1}{T}(q_i - p_i),
其中 q_i 和 p_i 分别表示真实类别和预测类别的概率分布。
进一步展开可得:
q_i = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}},\quad p_i = \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}
基于一阶泰勒展开,
e^{z_i/T} 可近似表示为 1 + z_i/T,
因此上述损失函数的梯度计算式可简化为:
\frac{\partial C}{\partial z_i} ≈ \frac{1}{T}\left( \frac{1+z_i/T}{N+\sum_j(z_j/T)} - \frac{1+v_i/T}{N+\sum_j(v_j/T)} \right)
假设在每个迁移样例中,
logit输出均为零均值分布,
即 \sum_j z_j = \sum_j v_j = 0。
在此条件下,
梯度计算式进一步简化为:
\frac{\partial C}{\partial z_i} ≈ (NT^{-2})(z_i - v_i)
经过蒸馏处理后,相较于之前采用的softmax方法,在计算梯度时相当于被乘上了\frac{1}{T^2}这个因子。由此可知,在保持损失函数L_{hard}不变的前提下,L_{soft}理论上需要乘以大约一个数量级才能维持相同的优化效果。
论文框架
根据设定相应的损失函数,在教师模型(Big Model)与学生模型(Small Model)之间建立相应的知识转移机制。论文算法的整体框架图如下:

Loss = \lambda * L_{soft} + (1-\lambda)*L_{hard}
当λ值越大时,表明模型对teacher网络的知识依赖程度越高;
L_{soft}用于评估学生网络从教师网络中学习知识的效果;
L_{hard}则用于评估学生网络在真实标签指导下准确回答问题的能力。
在之前的讨论中提到过其他预测类别的概率数值较低的情况。经过计算发现交叉熵损失值相对较小,在这种情况下能够提供的信息量有限。因此作者提出了软标签(soft target)这一概念作为解决方案
soft target
采用各个预测分布的算术平均数或几何平均数作为软目标设定依据,在优化过程中使soft target具有更高的熵;从而使得仅需较少的数据即可实现较大的学习率,并且在梯度下降过程中具有较低的偏差。
soft target被定义为teacher模型最后一层隐藏层H_{t}与student模型最后一层隐藏层H_{s}之间的对比,并在蒸馏温度T作用下进行Softmax处理以计算出最终的损失值
hard target
one-hot ground truth 通过Softmax函数处理后与真实标签Y对应的损失
实验
数据集:MNIST
教师模型:对包含两个隐藏层且每个层具有1200个单元的大型网络进行训练,并采用Dropout和权重约束作为正则化手段。
student Model:具有两层800个单元隐藏层没有正则的网络
- 如果小型网络通过引入一个新增任务来模仿大型网络生成的软目标实现,则能提升模型性能;
- 软目标不仅能够将大量知识转移至提取的模型中,在处理来自翻译后训练数据的学习过程中还能够掌握泛化能力;即便在转换集上没有任何实际应用
总结
这篇文章可被视为知识蒸馏领域的入门读物。其核心在于通过调节蒸馏温度T使大模型对各类标签的预测概率尽可能地迁移至小规模模型。该框架由官方支持的两个知名实验室——哈工大和讯飞实验室提供支持,并可通过TB进行使用。
