Advertisement

深度学习模型压缩算法综述(三):知识蒸馏算法+实战

阅读量:

深度学习模型压缩算法综述(三):知识蒸馏算法

  • 本文严禁转载
  • 项目位置:
  • 联系人:
  • 知识蒸馏技术:
    • 分类模型的训练目标:

    • 两阶段检测模型的训练目标:

      • 基本思路:

      • 损失函数的定义:

      • one-stage检测模型训练目标:

        • 基本思路:
        • 损失函数:
        • 基于特征的NMS算法:
      • Deep Mutual Learning:

        • 基本思路:
        • 基本原理:
        • 迭代训练:
    • 目标检测模型知识蒸馏实验一:

      • 环境参数:
      • 项目地址:
      • 实验结果:
    • 目标检测模型知识蒸馏实验二:

      • 环境参数
      • 项目地址:
      • 实验结果:
    • 关注我的公众号:

    • 联系作者:

在这里插入图片描述

本文禁止转载

项目地址:

https://github.com/Sharpiless/yolov5-distillation-5.0

联系作者:

B站:https://space.bilibili.com/470550823

:<>

AI Studio:https://aistudio.baidu.com/aistudio/personalcenter/thirdview/67156

Github:https://github.com/Sharpiless

知识蒸馏:

知识蒸馏的主要目的是让新模型(通常是一个参数量更少的模型)近似于原模型(模型即函数)。值得注意的是,在机器学习中,我们通常假设输入与输出之间存在潜在的函数关系,该函数本身是未知的:从零开始训练新模型的过程,相当于从有限的数据中近似一个未知的函数。如果让新模型近似于原模型,由于原模型的函数是已知的,我们可以利用大量非训练集内的伪数据来训练新模型。

分类模型训练目标:

过去需要新模型的输出概率分布与真实标签完全吻合,但现在我们只需确保新模型在给定输入下的输出概率分布与原模型的匹配即可。然而,由于softmax函数本质上只是对arg max的一种近似,它所承载的知识(即输出的概率信息)十分有限。因此,常用的方法是直接让新旧模型在logits输出层达到一致,即最小化:

在这里插入图片描述

该文则探讨了一个扩展版的softmax函数:

在这里插入图片描述

其中T是一个类似于温度的参数,当温度T趋近于0时,softmax函数的输出将趋于一个one-hot向量;当温度T值增大时,softmax输出的分布会变得更加平缓。因此,在蒸馏方法中,我们通常采用较高的T值进行训练,而使用T=1进行推理,这正是蒸馏方法之所以得名的原因。

此时的损失函数(最小化两个分布的交叉熵)变为:

在这里插入图片描述

该损失函数对于某个logit zi的梯度为:

在这里插入图片描述

已知当x趋向于0时,ex-1与x是等价无穷小,则当T足够大是,有:

在这里插入图片描述

如果假设所有logits对每个样本都是零均值化,即:

在这里插入图片描述

则有:

在这里插入图片描述

当迁移数据集上的真实标签已知时,可以通过同时提升新模型输出分布与真实标签分布的交叉熵来优化性能。在损失加权处理中,建议将软标签的损失项乘以权重T2,以确保各梯度量级保持一致。

two-stage检测模型训练目标:

基本思路:

该研究或文章可能是第一篇系统性探讨利用知识蒸馏方法提升目标检测网络效率的论文。研究者指出,尽管知识蒸馏在提升基础分类模型性能方面表现出色,但将其应用于目标检测模型时,会遇到回归问题、区域建议问题以及类别标注不一致的问题。为了解决这些问题,研究者采用了加权交叉熵损失函数来解决类别分布不均衡的问题,使用教师边界损失函数来解决回归问题,并通过引入可调节层来更有效地从教师网络的中间层特征中获取知识。

在这里插入图片描述

损失函数:

基本的损失函数如下:

在这里插入图片描述

其中:

N和M分别是RCN和RPN的batch-size

Lcls由两个部分组成,其中一部分是学生网络输出与真实标定之间的损失,另一部分是教师网络输出与学生网络之间的损失:

在这里插入图片描述

其中:

在这里插入图片描述

在分类模型中,权重wc代表不同类别(包括two-stage架构中的背景类)的分类权重,其作用在于平衡教师网络对背景类样本输出的重视程度。例如,在VOC数据集上,通常将wbg设定为1.5。此外,在分类模型蒸馏过程中,可引入温度参数T进行辅助优化。

由平滑损失和提出的教师边界损失组成:

在这里插入图片描述

其中m被定义为阈值,即表示当学生网络的效能超过教师网络的特定数值时,该学生网络的损失不再被计算。

Lhint是启发式的损失函数,鼓励学生模仿老师的特征表示:

在这里插入图片描述

V被定义为学生网络中的guided layer,而Z则代表教师网络中的hint layer。即通过引入特征向量进行匹配,以期使学生网络能够学习到更多的特征表示。当hint layer与guided layer在结构上不匹配时,需要引入一个适应层来进行转换。具体而言,当hint layer和guided layer均为全连接层(FC)时,可以直接使用FC进行转换;若均为卷积层(conv layer),则采用1x1卷积层来进行匹配。此外,作者发现即使在通道数相同的情况下,引入一个额外的适应层也有助于提升知识迁移的效率。

one-stage检测模型训练目标:

基本思路:

One-stage目标检测任务的训练难度较高,主要原因在于teacher网络能够预测出更多的背景边界框。直接将teacher网络的预测结果作为student进行学习的soft label,容易导致类别分布失衡的问题。为了解决这一问题,需要引入新的解决方案。

该文章主要针对前景检测问题提出了一种解决方案,具体而言,为解决该问题而提出了一种解决方案,即对YOLOv3模型的三个不同head分别适配了不同的蒸馏损失函数,并对其中的分类和回归损失函数应用objectness分值进行抑制,以有效解决前景与背景类别不均衡的问题。

此外,该研究主要采用未标注数据构成蒸馏损失,并将检测损失(基于标注数据)进行加权求和,最终形成总损失函数。

在这里插入图片描述

损失函数:

假设o、p、b分别代表objectness、类别概率以及坐标框的位置,其中,YOLO的目标检测损失函数通常表示为:

在这里插入图片描述

本文针对这几个损失函数分别进行优化,其核心思路在于仅当teacher网络的objectness value较高时,才对bounding box的坐标和class的概率进行学习。

  1. object loss:
在这里插入图片描述
  1. classification loss:
在这里插入图片描述
  1. regression loss:
在这里插入图片描述

基于特征的NMS算法:

如果不进行NMS,直接将teacher network的预测框未经处理地传递给student network,student network接收到来自object的极大值损失,经过几轮训练后会对其过拟合。为了应对上述问题,这里采用了借鉴NMS算法的feature map NMS方法,用于去除student网络中因重复预测框而产生的多余计算。

在这里插入图片描述

该文章基于以下假设:在网格单元中,若相邻的grid cell对目标进行预测时,其预测的边界框类别一致,则这些预测的边界框很可能对应同一个物体。基于此假设,该算法的具体实现步骤如下:首先,遍历所有3×3的grid cell组,对每组grid cell中类别相同的边界框,计算其预测得分,并按得分从高到低排序,最终选取得分最高的边界框进行传播。

Deep Mutual Learning:

基本思路:

在现有技术中,蒸馏模型通常是从功能强大的大型网络或集成网络向结构简单且运行快速的小型网络转移。本文提出了一种创新性方法,突破了传统固定教师学生关系的限制,提出了一种深度相互学习策略(deep mutual learning, DML)。该策略中,一组学生网络在整体训练过程中协同进化,实现相互学习与指导,而非传统的静态预定义教师与学生之间的单向知识传递通道。通过协同训练,该方法能够显著提升模型的泛化能力。

在这里插入图片描述

基本原理:

假设我们有这\theta_0,...,\theta_K几个分类网络,他们对应的logits输出为z^m_j

对于M个类别的N个样本:

在这里插入图片描述

我们可以计算\theta_j(下式j=1)对于输入x_i预测为类别m的概率:

其交叉熵损失函数为:

在这里插入图片描述

为了增强模型θ₁的泛化能力,我们可以借鉴另一个网络θ₂的经验作为参考。K-L散度是衡量两种概率分布P和Q之间差异的一种指标,又被称作相对熵。在概率学和统计学领域,我们经常采用一种更简单的、近似的分布来替代观察数据或过于复杂的分布。K-L散度有助于评估用一个分布近似另一个分布时所损失的信息量:通过使用K-L散度,我们可以更加高效地筛选出更优的近似分布。

在这里插入图片描述

此时两个模型的损失函数就可以更新为:

在这里插入图片描述

迭代训练:

DML在每次训练迭代过程中,首先生成两个模型的预测结果,随后利用另一个模型的预测结果来更新两个网络的参数,持续迭代直至达到收敛状态。

在这里插入图片描述

目标检测模型知识蒸馏实验一:

环境参数:

数据集名称:VOC2012;计算设备:V100×1;批量大小:8;训练步数:70,000;基础模型:Yolov3-MobileNetV1;教师模型:Yolov3-Resnet34

项目地址:

https://aistudio.baidu.com/aistudio/projectdetail/1878840

实验结果:

在这里插入图片描述

目标检测模型知识蒸馏实验二:

环境参数

数据集:VOC2007;批量大小:8;训练周期数:100;基线模型:Yolov5s;教师模型:Yolov5l/m(mAP50值为0.80/0.81);温度:4

项目地址:

https://github.com/Sharpiless/yolov5-distillation-5.0

实验结果:

在借鉴《Object detection at 200 Frames Per Second》论文中的损失函数框架时,除了在蒸馏训练过程中采用logits项的蒸馏损失外,其余蒸馏损失均采用了L2损失。此外,所采用的teacher model在准确性上具有相近的水平,但其参数规模存在差异。

在这里插入图片描述
在这里插入图片描述

关注我的公众号:

感兴趣的同学关注我的公众号——可达鸭的深度学习教程:

在这里插入图片描述

联系作者:

B站:https://space.bilibili.com/470550823

:<>

AI Studio:https://aistudio.baidu.com/aistudio/personalcenter/thirdview/67156

Github:https://github.com/Sharpiless

全部评论 (0)

还没有任何评论哟~