《Multi-Scale Aligned Distillation for Low-Resolution Detection》论文笔记
一、介绍
在实例检测阶段利用深度学习进行推理时会消耗大量计算资源
因此我们希望通过高分辨率图片训练的模型(称为low-res student)来提升性能。具体方法是通过将高分辨率图片训练的模型(称为high-res teacher)进行知识蒸馏。然而存在一个问题,在相同的网络阶段中(high-stage),high-res teacher和low-res student输出的质量与细节存在显著差异。例如,在经过第一个卷积层后输入为8xx x8xx gray图像会生成4xx x4xx大小的大规模特征图;而相对应地,在经过第一个卷积层后输入为4xx x4xx gray图像会生成2xx x2xx大小的小规模特征图;一般的方法是对high-res教师产生的大尺度特征图像进行下采样处理以匹配目标尺寸;但这种做法会严重破坏high-res教师预测所获得的信息量
作者采用了一种方法来实现特征图间的对齐以解决该问题。在构建特征金字塔时,每个网络层级所对应的特征图尺寸为上一层级的一半,从而使得学生网络接收输入图像时其分辨率同样为教师的一半,这一设计有助于不同层级之间的特征图实现有效对齐


原本特征图不对齐****现在对齐了
特征图的配准旨在为高分辨率教师与低分辨率学生之间的知识蒸馏过程提供基础

直接方法是仅训练一个高分辨率教师模型以指导学生模型;然而该教师模型所学习的特征对多分辨率输入不兼容(poorly compatible),导致其不太适用学生模型;为此作者期望训练一个具备多分辨率知识的老师模型,并在此基础上将其知识蒸馏给学生模组。为此提出了对齐多尺度学习(Aligned Multi-Scale Training, AMST)以及交叉特征级融合模块(Crossing Feature-Level Fusion Module, CFHFm)的方法。研究者采用FCOS作为检测器,并发表名为《FCOS: Fully Convolutional One-Stage Object Detection》的研究论文
二、教师模型训练
1.对齐多尺度学习(Aligned multi-scale training)

在训练过程中,在教师端被用来训练两组模型:一组处理高分辨率图像(左侧接收),另一组处理低分辨率图像(右侧接收)。其中右侧接收的低分辨率图像尺寸仅为左侧原始高分辨图像的一半,在结构图中可清晰观察到这一比例关系。每次迭代时都会应用一个随机缩放因子α以调整原始图像尺寸(范围限定于[0.8, 1]之间)。
特征金字塔(FPN)结构通过主动学习机制提取出宽度为HPs、高度为WPs并带有256通道数的特征图谱(其中Ps代表第s层特征金字塔生成的具体特征图谱)。这些特征图谱之间仅相差一个层级:例如P3与P2'特征图谱具有相同的尺寸,P4与P3'特征图谱同样尺寸一致,依此类推。经过FPN完成两个分辨率模型提取后,将这些特征传递给检测头模块进行后续处理

其中

其中,
AMML\_loss = \alpha Lcls + \beta Lreg + \gamma Lctr。
我们因此定义对齐多尺度学习损失(AMML loss)为:
其中,
AMML\_loss = \alpha Lcls + \beta Lreg + \gamma Lctr。

2.交叉特征级融合( Crossing feature-level fusion)
相比而言,在进行小物体检测时

该模块接收两个具有相同空间尺寸的feature map Ps 和 P's-1(假设为 HPs×WPs×C),随后对其进行拼接操作(concatenation),导致通道数从原来的 C 增加至 2C。接着执行全局池化过程:将每个 channel 上的 H×W 个像素值相加,并除以 H×W 得到一个数值结果。这样一来,2C 个 channel 将分别生成 2C 个数值结果。因此 Global pooling 的输出结果为 1×1×2C 的三维数组。随后经过全连接层处理以完成维度转换,并应用 ReLu 激活函数。最终生成一个 1×1×2 维的空间尺寸特征图,并通过 softmax 函数将其映射为概率权重分布(即置信度)。最后将原始 feature maps Ps 和 P's-1 分别与其对应的权重相乘后相加,从而得到最终输出 PsT。
PsT 损失计算公式为

λ是loss weight,(H&H') × (W&W')表示融合后的特征图,

的计算方法跟

是一样的。
于是教师模型的损失函数为:

三、知识蒸馏
在得到multi-resolution 特征图 PsT后,就可以蒸馏给学生模型了。

其中

,m=k/2,k是低分辨率图像的缩放倍数,这里k=2。τ 是超参数,代表loss weight。
学生模型所受的损失包含知识蒸馏损失以及original detection_loss两个部分_其计算公式如下:Loss = L_{KD} + L_{OD}

γ是loss weight
