知识蒸馏(Knowledge Distillation)
知识蒸馏是一种通过预训练模型的知识迁移,提升小型模型性能的技术。以下是知识蒸馏的总结:
基本概念
知识蒸馏(Knowledge Distillation)是一种模型压缩技术,通过将大型预训练模型的知识迁移到小型模型中,以提高小模型的性能。预训练模型作为教师模型,小型模型作为学生模型,学生模型通过模仿教师模型的输出和特征来学习知识。
核心方法
混合损失函数:通常结合分类损失(如CE loss)和知识蒸馏损失(如KD loss),并进行加权平均。KD损失计算学生模型输出与教师模型软标签的交叉熵损失。
\text{Total loss} = \alpha \times \text{CE loss} + (1-\alpha) \times \text{KD loss}
其中,α是权重参数,通常在0.1到1之间。
基于中间层特征的知识蒸馏:通过迁移教师模型的中间层特征,构建学生模型的全连接层权重。这种方式能够保留教师模型的深层知识,但需要较高的计算资源。
基于注意力的知识蒸馏:利用教师模型的注意力机制,迁移知识。注意力机制能够捕捉样本之间的关系,适用于序列数据(如NLP)。
基于对比学习的知识蒸馏:通过对比学习,教师模型学习样本之间的相似度,学生模型模仿这种相似度。这种方法适用于无标注数据和多模态数据。
应用领域
Transformer模型压缩:通过蒸馏教师模型的注意力机制,迁移知识到学生模型,实现模型压缩。
量化模型压缩:结合量化技术,进一步优化模型性能。
多模态模型:如CLIP与BART结合,实现视觉与语言知识的迁移。
结合其他技术
自蒸馏:教师模型既是教师也是学生,通过蒸馏过程优化自身结构。
量化的知识蒸馏:结合量化技术,优化模型压缩效果。
目标检测与实例分割:通过蒸馏教师模型的分类和定位能力,提升目标检测和实例分割模型的性能。
挑战与优化
蒸馏过程中的噪声:蒸馏过程可能引入噪声,影响学生模型的性能。
模型结构限制:蒸馏效果可能受到学生模型结构的限制,需要选择合适的结构以适应蒸馏需求。
蒸馏效率:蒸馏过程需要高效的算法和计算资源,以确保在合理时间内完成。
总结
知识蒸馏是一种高效的知识迁移技术,能够通过蒸馏过程提升小型模型的性能。通过结合多种方法和优化技术,可以进一步提高蒸馏效果,适用于多种应用场景,如Transformer模型压缩、多模态模型优化等。尽管存在一些挑战,但其在模型压缩和效率提升方面具有重要意义。
本文重点介绍与知识蒸馏相关的算法及应用情况。但首先需要明确的是,教师网络或给定的预训练模型中包含哪些可迁移的知识?基于常见的深度学习任务,可迁移知识包括:
中间层特征:浅层特征强调纹理细节,深层特征则强调抽象语义。
任务相关知识:涵盖分类概率分布、目标检测中的实例语义、位置回归信息等。
表征相关知识:强调迁移能力,同时关注表征间的相关性,如相似度和关系。

此外,知识蒸馏的应用大致来说包括哪些?包括模型压缩技术、迁移学习方法以及多教师信息融合技术等。
1、Distilling the Knowledge in a Neural Network
Hinton在该文章中首次系统性地提出了一种基于教师网络的知识蒸馏方法。该方法通过设计一种与教师网络相关的软目标函数,以指导学生网络进行训练,从而实现知识的有效迁移。

如图所示,教师网络(左侧)的预测输出经过除法运算后,再进行Softmax处理,得到一个平滑的概率分布(即软目标或软标签),其数值范围在0到1之间,分布较为平缓。当Temperature值越大,分布越平缓;反之,当Temperature值减小,容易放大错误分类的概率,引入不必要的噪声。对于较为困难的分类或检测任务,Temperature通常设置为1,以确保教师网络中正确预测的样本贡献得到充分重视。硬目标则是样本的真实标注,通常采用One-hot向量进行表示。Total loss设计为软目标与硬目标对应的交叉熵损失的加权平均(即KD loss与CE loss),其中软目标交叉熵的加权系数越大,表明迁移学习更依赖于教师网络的贡献,这对学生网络的初期训练阶段非常必要,有助于其更轻松地区分简单样本,但到了后期需要适当减小软目标的比重,以利用真实标注帮助学生网络鉴别困难样本。此外,教师网络的预测精度通常优于学生网络,而模型容量并无具体限制,且教师网络的推理精度越高,对学生网络的学习越有利。
教师网络与学生网络可实现协同训练,此时教师网络的隐含知识及学习方式将通过影响学生网络的学习效果。其中,三个关键损失项分别为教师网络Softmax输出的交叉熵损失、学生网络Softmax输出的交叉熵损失,以及教师网络数值输出与学生网络Softmax输出的交叉熵损失。
联合训练的Paper地址:https://arxiv.org/abs/1711.05852


深入研究深度神经网络知识提取在高效硬件实现中的应用
GitHub平台提供了名为kknowledge-distillation的PyTorch框架项目
这篇文章将Total loss重新定义如下:

PyTorch实现的Total loss框架引入了模型输出的精简版本与 teacher network 的输出之间的KL散度项。在 teacher training 的过程中,首先将 teacher network 的预测输出缓存至 CPU 内存中,从而减轻了GPU显存占用。
def loss_fn_kd(outputs, labels, teacher_outputs, params):
"""
Compute the knowledge-distillation (KD) loss given outputs, labels.
"Hyperparameters": temperature and alpha
NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
and student expects the input tensor to be log probabilities! See Issue #2
"""
alpha = params.alpha
T = params.temperature
KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
F.cross_entropy(outputs, labels) * (1. - alpha)
return KD_loss
3、Ensemble of Multiple Teachers
论文链接:Efficient Knowledge Distillation from an Ensemble of Teachers
第一种算法首先通过加权平均的方式整合多个教师网络的输出结果,然后将整合后的Soft label用于指导学生网络的训练过程。


针对第二种算法而言,由于加权平均方法可能会削弱、平滑多个教师网络的预测结果,因此,可以选择某个教师网络的Soft label作为指导信息。

第三种算法:

4、Hint-based Knowledge Transfer
Paper地址:https://arxiv.org/abs/1412.6550
GitHub地址:https://github.com/adri-romsor/FitNets
为了诱导训练更深、更紧凑的结构(Deeper and thinner FitNet),需要考虑教师网络中间层的特征图(Feature Maps)作为提示(Hint layer),以指导学生网络中相应的指导层(Guided layer)。在此过程中,需要引入L2损失(L2 loss)来指导训练过程,该损失计算了教师网络提示层与学生网络指导层输出特征图之间的差异。当教师网络的提示层与学生网络的指导层输出特征图形状不一致时,需要在指导层上增加一个回归层(regression layer)来解决这一问题。


具体训练过程划分为两个阶段:第一阶段通过Hint-based loss诱导学生网络达到一个合理的初始状态,仅更新W_Guided和W_r。第二阶段通过教师网络的soft label进行知识蒸馏,且在Total loss中Soft target部分的比重逐渐降低,使学生网络能够全面区分简单样本与困难样本。在这一过程中,教师网络能够有效识别简单样本,而困难样本则需要依赖真实标注,即采用Hard target进行辅助训练。

5、Attention to Attention Transfer
Paper地址:https://arxiv.org/abs/1612.03928
GitHub地址:https://github.com/szagoruyko/attention-transfer
通过网络中间层模块的注意力机制,实现Teacher网络与Student网络之间的知识迁移。对于给定的Tensor A,基于激活值的注意力矩阵可以表示为以下三种形式:

随着网络层次的增加,关键区域的注意力水平显著提升。文章末尾采用了第二种形式的注意力可视化方法,其中p值设定为2。与梯度式注意力机制相比,激活式注意力机制在知识迁移方面表现更优。其损失函数定义及迁移流程如下:



6、Flow of the Solution Procedure
Paper地址:
超出预期的性能水平,该方法相较于asio表现出色,这归功于其简洁的设计和简明的代码结构。
暗知识也可视为训练过程中的求解途径。教师网络或学生网络的FSP矩阵定义如下:基于Gram矩阵的形式,该矩阵具体表示为...


训练的第一阶段旨在最小化teacher network与student network的FSP矩阵间L2损失,并初始化student network的可训练参数:

在目标任务的数据集上进行学生网络的微调训练,以实现促进知识迁移、加快收敛速度,同时实现迁移学习的目的。
Knowledge Distillation is enhanced through the integration of Adversarial Samples in shaping the Decision Boundary.
Paper地址:https://arxiv.org/abs/1805.05532
从分类任务的决策边界视角分析,在知识迁移过程中,可以理解为教师网络通过指导学生网络有效区分不同决策边界的过程。当模型在区分不同决策边界的能力越强时,其泛化能力就越强。

文章首先采用对抗攻击策略,将基准类样本数据集(Base class sample set)转换为目标类样本,这些样本位于决策边界附近(BSS: boundary supporting sample)。随后,文章采用迭代方法生成对抗样本,具体需要在Loss function(基准类得分与目标类得分之差)的梯度负方向上调整样本参数,直至满足预设的终止条件。

Loss function定义如下:

沿Loss function的梯度负方向调整样本:

停止条件(只要满足三者之一):

基于对抗生成的数据样本,通过教师网络对学生的网络进行训练,所需的总损失函数包括交叉熵损失、知识蒸馏损失以及边界支持损失(BS loss):

8、Label Refinery:Improving ImageNet Classification through Label Progression
GitHub地址:https://github.com/hessamb/label-refinery
该文章提出了一种基于迭代学习的诱导训练策略,其核心任务是解决样本的裁剪与标签不匹配的问题,从而有效提升模型的泛化能力。

在诱导过程中,Total loss被定义为本次迭代(t>1)网络预测输出(概率分布)与上一轮迭代输出(Label Refiner,类似于教师网络的角色)之间的KL散度。

文章实验部分表明,不仅能够使用训练网络作为Label Refinery Network,而且支持采用其他高质量网络(如Resnet50)作为Label Refinery Network。在对抗生成样本的诱导过程中,实现了数据增强的效果。
9、Meal V2 KD (Ensemble of Multi-Teachers)
Paper地址:https://arxiv.org/abs/2009.08453
GitHub:https://github.com/szq0214/MEAL-V2
从核心理念来看,MEAL V2通过知识蒸馏技术,将多个Teacher模型的输出融合并迁移至一个Student模型中,具体而言,包括Teacher模型集成、KL散度损失函数以及判别器网络。
多个教师的预测概率取平均值,以获得更稳定的估计;
模型仅依赖于教师的软标签信息,而未引入其他外部数据;
判别器通过生成对抗训练机制,对生成器的输出进行有效约束;
学生网络从预训练模型出发,通过蒸馏技术显著降低了训练成本。

10、KD for Lightweight Face Detector
Paper地址:
该方法在知识蒸馏方面表现优异,得益于其简洁的设计和高效的代码实现。
人脸检测模型的分类预测输出,通常被定义为二分类问题(背景为0,人脸为1)。在教师网络与学生网络之间,通常观察到Classification map的差异显著大于Regression map的差异。此外,Classification map提供的Soft label,更容易作为监督信息。另外,需要基于教师网络与学生网络输出得分的差异,筛选出简单样本并实现在线难例的挖掘,以有效监督学生网络的学习过程。
Loss function的实现如下:


def kd_loss(teacher_output, student_output, alpha=50.0):
teacher_output = F.softmax(teacher_output, dim=-1)
student_output = F.softmax(student_output, dim=-1)
scale = 16.2 # when beta=6.4 and gamma=3.2
beta = 6.4
threshold = scale * torch.pow(torch.abs(teacher_output[:, :, 1] - 0.5), beta)
mask = teacher_output[:, :, 1] > threshold
t_feat = teacher_output[mask]
f_feat = student_output[mask]
loss = torch.nn.functional.mse_loss(t_feat, f_feat)
return loss * alpha
11、Relational Knowledge Distillation
Paper地址:https://arxiv.org/abs/1904.05068v2
[mdistiller/RKD.py at master · megvii-research/mdistiller · GitHub](https://github.com/megvii-research/mdistiller/blob/master/mdistiller/distillers/RKD.py "mdistiller/RKD.py at master · megvii-research/mdistiller · GitHub)
RKD通过提取样本表示特征之间的Relation用于表示知识,并通过从Teacher到Student的知识迁移实现知识共享,其对应的损失函数定义如下:

其中t和s分别代表Teacher和Student的样本特征表示(如表征方法、Backbone网络提取语义特征等);

表示特征之间的Relation;l表示Relation之间的距离函数。
Relation的定义越高级,其迁移效果越好。文章主要针对Distance-wise与Angle-wise两种Relation进行定义。
- Distance-wise Distillation Loss:




- Angle-wise Distillation Loss:




12、Knowledge Distillation meets Self-supervision
Paper地址:https://arxiv.org/pdf/2006.07114.pdf
SSKD项目
SSKD仓库
个人SSKD项目
SSKD代码仓库
SSKD开源项目
SSKD项目仓库
SSKD代码库
SSKD项目资源
SSKD开源项目
SSKD代码仓库
SSKD项目资源
SSKD(采用自监督学习作为辅助任务,用于知识蒸馏)中,学生网络在经过数据变换后的数据集上,结合自监督辅助任务,实现了更为丰富的结构化知识迁移。传统知识蒸馏(KD)中,学生网络模仿教师网络在任务层的预测输出(如分类、位置回归等);而在SSKD中,在变换后的数据集和自监督辅助任务上,能够实现更为丰富的结构化知识迁移。对比学习因其在自监督学习中的卓越表现,成为SSKD中选择的自监督辅助任务。对比学习通过使网络区分正负样本,最大化每个样本变换前后的相似度(基于Contrastive loss),使得模型学习到具有变换不变性的表征能力。文章使用余弦函数以衡量不同表征之间的相似度,并构造相似度矩阵;然后通过Cross entropy loss,以衡量变换后样本与某一原样本是否为正样本对:
# 4 means one original sample and three augmented samples
batch = int(x.size(0) / 4)
nor_index = (torch.arange(4*batch) % 4 == 0).cuda()
aug_index = (torch.arange(4*batch) % 4 != 0).cuda()
# rep is the representation features of all samples
nor_rep = rep[nor_index]
aug_rep = rep[aug_index]
nor_rep = nor_rep.unsqueeze(2).expand(-1, -1, 3*batch).transpose(0, 2)
aug_rep = aug_rep.unsqueeze(2).expand(-1, -1, 1*batch)
# cosine similarity is used for similarity matrix
simi = F.cosine_similarity(aug_rep, nor_rep, dim=1)
target = torch.arange(batch).unsqueeze(1).expand(-1, 3).contiguous().view(-1).long().cuda()
loss = F.cross_entropy(simi, target)
SSKD一方面要求学生掌握与任务相关的知识(包括正样本和负样本的任务预测),另一方面,要求学生模仿教师的特征表征能力,以便有效区分正负样本(教师网络需先进行对比学习,获取表征能力以供迁移)。总的损失函数表示如下:



其中Lce代表基于硬标签的损失函数,Lkd代表原样本的Hinton KD损失,Lss代表表征知识的迁移能力,LT代表变换样本的Hinton KD损失;B表示通过Softmax计算得到的相似度矩阵,其反映了学生网络对教师网络预测样本对相似度的概率分布。在训练过程中,通过OHEM算法筛选高质量的变换样本用于计算LT与Lss的损失,排序依据为教师网络的Soft-label和相似度矩阵。其中LT的OHEM过程如下:
aug_target = target.unsqueeze(1).expand(-1, 3).contiguous().view(-1).long().cuda()
rank = torch.argsort(aug_knowledge, dim=1, descending=True)
rank = torch.argmax(torch.eq(rank, aug_target.unsqueeze(1)).long(), dim=1) # groundtruth label's rank
index = torch.argsort(rank)
tmp = torch.nonzero(rank, as_tuple=True)[0]
wrong_num = tmp.numel()
correct_num = 3 * batch - wrong_num
wrong_keep = int(wrong_num * args.ratio_tf)
index = index[:correct_num + wrong_keep]
distill_index_tf = torch.sort(index)[0]
SSKD方法特别适合标注不充分的场景(对于无标注场景,可以去掉Lce模块)、同时支持Few-shot学习(依赖对比学习实现表征知识的迁移)。此外,由于SSKD主要依赖于Final layer的迁移知识,因此也适用于异质网络的诱导训练。SSKD的具体应用框架如图所示,教师网络与学生网络均由三部分组成:用于特征提取的Backbone模块、用于主任务的分类器以及用于辅助任务的自监督模块。

13、Contrastive Pruning
Paper地址:https://arxiv.org/abs/2112.07198
与其相比,该方法在性能上表现出色,主要得益于其简单的设计和简洁的代码。
在下游任务的微调过程中,执行BERT剪枝与知识蒸馏;KD 核心思想包括:
Teacher models不仅包括预训练模型(这些模型主要承载了与任务无关的表征知识)、微调过的模型(这些模型主要承载了与任务相关的知识)以及剪枝过程中保存的模型(这些模型保留了历史剪枝过程中的信息);处于剪枝状态的模型则作为Student model。知识蒸馏涵盖任务相关的KD(如软标签)以及与任务无关的KD(如自监督和对比损失)。


14、Few Sample Knowledge Distillation
Paper地址:https://arxiv.org/abs/1812.01839
GitHub地址:GitHub - LTH14/FSKD
在数据安全性和隐私保护的考虑下,用户通常仅提供少量无标注数据集(如少量无标注数据集),此时对压缩模型的应用面临严峻挑战。FSKD通过基于蒸馏的低成本方法(耗时短、数据利用率高)实现,从而为模型压缩精度恢复提供了可靠保障。
- 通过剪枝、张量分解等技术实现模型压缩;
- 以压缩模型作为Student,在每个Block-level之后添加1x1 conv、以对齐Teacher的Block-level特征输出(实现Teacher知识迁移);并以Least square regression为优化目标,通过Block coordinate descent方式逐层、渐进式优化求解1x1 conv的参数,优化目标如下:

- 将1x1 conv合并到前一个卷积层, 获得推理模型,即

;


15、Vision-language Knowledge Distillation
Paper地址:https://arxiv.org/abs/2203.06386
为了优化多模型生成任务的表现,基于CLIP构建Teacher模型,基于BART (Encoder-decoder结构)构建Student模型,实现了多模态表征知识的跨模态迁移:
为了优化多模型生成任务的表现,基于CLIP构建Teacher模型,基于BART (Encoder-decoder结构)构建Student模型,实现了多模态表征知识的跨模态迁移:
- CLIP包含Image encoder与Text encoder,具备统一、共享的多模态表征空间 ,能够实现视觉表征与文本表征的对齐;
- 为了实现CLIP表征知识的迁移 ,引入了TTDL (Text-Text Distance Minimization)、ITCL (Image-Text Contrastive Learning)与ICTI (Image-Conditioned Text Infilling)三个迁移任务以对齐多模态表征 ;迁移训练期间,CLIP主干参数冻结;
- TTDL (L2 distance loss):

* **ITCL (InforNCE loss, used in contrastive learning):**



* **ICTI (Sum of log-softmax loss, used in regressive decoding):**

* **Total loss:**


16、Decoupled Knowledge Distillation
Decentralized Knowledge Transfer: Its Core Mechanism Is the Realization of Target Distribution and Non-Target Distribution Decoupling
17、Miscellaneous
知识蒸馏可以与量化技术相结合,考虑到各层特征图之间的关系,参考文献:
结合量化的知识蒸馏(Quantization Mimic)_AI Flash-博客
-------- 知识蒸馏与Hint Learning相结合,可以训练精简的Faster-RCNN,可参考:
令人惊讶地超越了OpenCV 性能的效率,这得益于其简洁的设计理念和高效的代码结构。
焦点与全局知识蒸馏——目标检测模型的知识蒸馏
网络结构搜索(NAS)同样可以采用蒸馏操作,以提升搜索性能。参考Cream NAS的蒸馏方法:
自蒸馏One-shot NAS在性能上令人惊喜地超越了asiotop,这得益于其简洁的设计和高效的代码实现。
知识蒸馏在模型压缩任务中,主要基于Self-attention Knowledge Distillation方法,具体参考:
这一发现令人出乎意料,表明该方法在性能上超越了asiot。其优势主要归因于简洁的设计理念和高效的代码实现。
出乎意料地超越了asio的性能表现,我想这得益于其设计上的简单性和代码的简洁性。
-------- 模型压缩方面,更为详细的讨论,请参考:
该文详细介绍了深度学习模型压缩与优化加速的相关技术与方法,深入探讨了压缩策略及其对性能提升的关键作用。
