2023 Curriculum Temperature for Knowledge Distillation
论文地址:https://arxiv.org/abs/2211.16231
代码地址:https://github.com/zhengli97/CTKD
1 研究动机与研究思路
研究动机 :大多数现有的蒸馏方法忽略了温度在损失函数中的灵活作用,将其固定为超参数。一般而言,温度控制着两种分布之间的差异,确定蒸馏任务的难易程度。保持一个恒定的温度,即固定的任务难度,在渐进学习阶段通常是次优的。
研究思路 :本文提出了一种简单的基于课程的技术,称为知识蒸馏课程温度( CTKD ),它是一个动态温度超参蒸馏新方法。具体来说,遵循由易到难的课程设置,随温度的变化逐渐增加蒸馏损失,以对抗的方式导致蒸馏难度的增加。
(在我的数据集上实验结果并不理想)
2 主要工作
本文的主要工作:
- 本文提出在学生的训练过程中使用反向梯度对抗学习动态温度超参数,以最大化师生之间的蒸馏损失。
- 本文引入了简单有效的课程,通过一个动态和可学习的温度参数,从易到难地组织蒸馏任务。
3 方法
3.1 知识蒸馏
传统的两段蒸馏过程通常以预先训练的繁琐的教师网络开始。然后在教师网络的监督下以soft预测或中间表示的形式训练一个紧凑的学生网络。采用带有温度超参的KL Divergence Loss散度损失最小化学生和教师模型的soft输出概率差异,从而在教师模 型和学生模型之间进行蒸馏, 公式如下:
L_{k d}\left(q^t, q^s, \tau\right)=\sum_{i=1}^I \tau^2 K L\left(\sigma\left(q_i^t / \tau\right), \sigma\left(q_i^s / \tau\right)\right)
其中,q^t, q^s分别表示教师和学生产生的logit,\sigma ( · )为softmax函数.温度超参 \tau 用来衡量两个分布 q^t 和 q^s 的平滑程度,决定了两个概率分布间的距离, \tau 越大( \tau>1) ,就会使得概率分布越平滑(soft), \tau 越小 (0<\tau<1) ,越接近0,会使得概率分布越尖锐(sharp)。 \tau 的大小影响着蒸馏中学生模型学习的难度,而现有工作普遍的方式都是采用固定的温度超参,一般会设定成4。
3.2 对抗性蒸馏
针对原始蒸馏任务,以最小化任务特定损失和蒸馏损失为目标,对学生进行优化。蒸馏过程的目标可以表述如下:
\begin{aligned} \min _{\theta_{s t u}} L\left(\theta_{s t u}\right) & =\min _{\theta_{s t u}} \sum_{x \in D} \alpha_1 L_{t a s k}\left(f^s\left(x ; \theta_{\text {stu }}\right), y\right) \\ & +\alpha_2 L_{k d}\left(f^l\left(x ; \theta_{\text {tea }}\right), f^s\left(x ; \theta_{\text {stu }}\right), \tau\right) .\end{aligned}
其中L_{t a s k}是图像分类任务的正则交叉熵损失,f^L(\cdot) 和 f^s(\cdot)是教师和学生的函数,\alpha_1和\alpha_2是平衡权重。
为了通过动态温度控制学生的学习难度,受GANs的启发,本文提出对抗学习一个动态温度模块\theta_{\text {temp }},该模块预测一个适合当前训练的温度值 \tau 。该模块在与学生相反的方向上进行优化,旨在最大化学生与教师之间的蒸馏损失。与原始蒸馏不同,学生\theta_{s t u}和温度模块\theta_{\text {temp }} 以如下价值函数L\left(\theta_{\text {stu }}, \theta_{\text {temp }}\right) 进行两人极大极小不等式博弈:
\begin{aligned} & \min _{\theta_{\text {stu }}} \max _{\theta_{\text {temp }}} L\left(\theta_{\text {stu }}, \theta_{\text {temp }}\right) \\ & =\min _{\theta_{s t u}} \max _{\text {temp }} \sum_{x \in D} \alpha_1 L_{\text {task }}\left(f^s\left(x ; \theta_{\text {stu }}\right), y\right) \\ & +\alpha_2 L_{k d}\left(f^t\left(x ; \theta_{\text {tea }}\right), f^s\left(x ; \theta_{s t u}\right), \theta_{\text {temp }}\right) . \end{aligned}
采用交替算法求解方程中的问题,固定一组变量,求解另一组变量:
\begin{aligned} \hat{\theta}_{\text {stu }} & =\arg \min _{\theta_{\text {stu }}} L\left(\theta_{\text {stu }}, \hat{\theta}_{\text {temp }}\right) \\ \hat{\theta}_{\text {temp }} & =\arg \max _{\theta_{\text {temp }}} L\left(\hat{\theta}_{\text {stu }}, \theta_{\text {temp }}\right)\end{aligned}
通过随机梯度下降( SGD )进行优化,学生\theta_{s t u}和温度模块\theta_{\text {temp }} 参数更新如下:
\begin{aligned} \theta_{s t u} & \leftarrow \theta_{s t u}-\mu \frac{\partial L}{\partial \theta_{s t u}} \\ \theta_{t e m p} & \leftarrow \theta_{t e m p}+\mu \frac{\partial L}{\partial \theta_{\text {temp }}}\end{aligned}
通过一个非参数化的梯度反转层( Gradient Reversal Layer,GRL )来实现上述对抗过程,在softmax层和可学习温度模块之间插入GRL,知识蒸馏课程温度( CTKD )如图所示。

图1:知识蒸馏课程温度
( a )引入了一个可学习的温度模块来预测合适的蒸馏温度,使用梯度反转层来反转反向传播过程中温度模块的梯度。( b )遵循先易后难的课程设置,逐步增加参数λ,导致学生的学习难度
3.3 课程温度
保持恒定的学习难度对于一个成长中的学生来说是次优的。受课程学习的启发,本文进一步介绍了一个简单而有效的课程,它通过直接将损失L按\lambda大小温度来组织蒸馏任务,即L \rightarrow \lambda L。因此,\theta_{\text {temp }} 将被更新:
\theta_{t e m p} \leftarrow \theta_{t e m p}+\mu \frac{\partial(\lambda L)}{\partial \theta_{t e m p}}
将初始λ值设置为0,使得低年级学生可以专注于学习任务而不受任何约束。通过逐步提高λ,随着蒸馏难度的增加,学生学习到更高级的知识。具体而言,遵循课程学习的基本理念,本文提出的课程满足以下两个条件:
-
给定唯一变量 \tau ,蒸馏损失温度模块(简化为\left.L_{k d}(\tau)\right))逐渐增大,即
L_{k d}\left(\tau_{n+1}\right) \geq L_{k d}\left(\tau_n\right) -
λ值增大,即
\lambda_{n+1} \geq \lambda_n
式中:n表示第n步训练。当在En epoch处训练时以如下的余弦调度逐步增加λ:
\begin{aligned} \lambda_n & =\lambda_{\min } \\ & +\frac{1}{2}\left(\lambda_{\max }-\lambda_{\min }\right)\left(1+\cos \left(\left(1+\frac{\min \left(E_n, E_{\text {loops }}\right)}{E_{\text {loops }}}\right) \pi\right)\right.\end{aligned}
式中:\lambda_{\max }和\lambda_{\min }为\lambda的取值范围。E_ {loops}是难度尺度\lambda逐渐变化的超参数。本文默认设置\lambda_{\max },\lambda_{\min }和E_ {loops}分别为1、0和10。该过程表明参数\lambda在10个训练周期内从0增加到1,并一直保持1直到结束。详细的消融研究见表6和表8。
3.4可学习温度模块
可学习温度模块有两个版本,即Global - T和Instance - T。

全局版本只包含一个可学习的参数,对所有实例预测一个值T_{\text {pred }},如图2 ( a )所示。这种高效的版本不会给蒸馏过程带来额外的计算成本,因为它只涉及一个可学习的参数。
Instance-T.为了获得更好的蒸馏性能,一个全局温度对于所有实例都不够准确。我们进一步探索了实例变量Instance - T,它对所有实例单独预测一个温度值,例如,对于一批128个样本,本文预测128个对应的温度值。本文提出利用概率分布的统计信息来控制自身的平滑性。具体来说,本文引入了一个2层MLP,将两个预测作为输入,输出预测值T_{\text {pred }},如图2 ( b )所示。在训练过程中,该模块会自动学习原始分布和平滑分布之间的隐含关系。为了保证温度参数的非负性并使其值保持在合适的范围内,我们用下面的公式对预测的T_{\text {pred }}进行标度:
\tau=\tau_{\text {init }}+\tau_{\text {range }}\left(\delta\left(T_{\text {pred }}\right)\right)
其中 \tau_{\text {init }} 为初始值,\tau_{\text {range }}为\tau的取值范围,\delta(\cdot)为sigmoid函数,T_{\text {pred }}为预测值。默认设置 \tau_{\text {init }}和\tau_{\text {range }}为1和20,这样可以包含所有的正常值。
与Global - T相比,Instance - T由于具有更好的表示能力,可以获得更好的精馏性能。
4 算法伪代码

