Advertisement

大模型压缩方法之知识蒸馏

阅读量:

知识蒸馏 的训练过程是通过结合硬标签损失 (( L_{\text{hard}} ))和软标签损失 (( L_{\text{soft}} ))进行反向传播,更新学生模型的参数。

具体流程如下:

前向传播

复制代码
 * **教师模型** 和 **学生模型** 分别对相同的输入数据进行前向传播,计算它们各自的输出。
 * 教师模型的输出会生成“软标签”,即通过温度系数平滑过的类别概率分布。学生模型则输出它自己的类别概率分布。

计算损失

复制代码
 * **硬标签损失** (( L_{\text{hard}} )):这是学生模型的输出与真实标签之间的交叉熵损失,通常用于确保学生模型在最终的任务上取得好的性能。
 * **软标签损失** (( L_{\text{soft}} )):这是学生模型的输出概率分布与教师模型输出的“软标签”之间的差异,通常使用**KL散度** (Kullback-Leibler Divergence)来度量。通过软标签损失,学生模型能从教师模型的特征中学到更多细节信息。

损失函数通常是两者的加权和,公式如下:
[
L = \alpha \cdot L_{\text{hard}} + (1 - \alpha) \cdot L_{\text{soft}}
]

复制代码
 * ( \alpha ) 是一个超参数,用来控制硬标签损失和软标签损失的相对权重。
 * 温度系数 ( T ) 通常用于软化教师模型输出的概率分布,使其更加平滑,能提供更多类别之间的相关性信息。

反向传播

复制代码
 * 计算出的**总损失** ( L ) 会通过**反向传播** (Backpropagation)过程,更新学生模型的参数。
 * 在反向传播过程中,损失函数的梯度会通过链式法则从输出层传回到模型的每一层,逐步调整模型参数,最终提升学生模型的表现。

迭代训练

复制代码
 * 重复执行上述的前向传播、损失计算和反向传播,直到学生模型在训练集上达到期望的性能或者达到预设的训练轮数。

关键要点:

  • 硬标签损失 ( L_{\text{hard}} ) 强调学生模型能够正确地学习真实标签,确保学生模型在任务上的准确性。
  • 软标签损失 ( L_{\text{soft}} ) 则让学生模型学习教师模型的类别概率分布,使其能够捕捉更丰富的特征和类别之间的相关性。
  • 反向传播 是通过计算出的总损失来更新学生模型的参数,最终优化学生模型的性能。

总结

知识蒸馏的反向传播过程是基于总损失函数 ,即硬标签损失软标签损失 的加权和。这个损失函数通过反向传播来优化学生模型的参数,使学生模型不仅能学习真实标签,还能从教师模型中吸收更多深层次的知识。

全部评论 (0)

还没有任何评论哟~