Advertisement

【KD】What Makes a “Good“ Data Augmentation in Knowledge Distillation -- A Statistical Perspective

阅读量:

论文链接信息:

**一、**研究背景与动机

知识蒸馏(knowledge distillation,KD)是一种通用神经网络训练方法,通过教师模型KD的方式,帮助学生模型KD,在各类AI任务中展现出广泛的应用前景。数据增强技术DA在神经网络训练中被广泛采用,成为不可或缺的基础方法。

知识蒸馏按照蒸馏的位置通常分为

基于网络中间特征图的蒸馏

基于网络输出的蒸馏

在后者方面,近年来KD在分类任务上的发展主要集中在新型损失函数的引入。例如,ICLR’20年的CRD和ECCV’20年的SSKD则通过将对比学习技术融入损失函数中,从而从teacher模型中提取了更为丰富的信息,为student模型的学习提供了更有力的支持,最终实现了该领域的当前最佳表现(SOTA)。

本文未涉及损失函数、蒸馏位置等传统研究问题,我们沿用最原始版本的KD loss(即Hinton等人在NIPS’14 workshop上提出KD时所采用的Cross-Entropy + KL divergence)。我们重点关注网络的输入端:如何评估不同数据增强方法在KD中的优劣?(相比之下,以往的KD论文通常更关注网络的中间特征或输出端)。系统流程图如下所示,本文的主要目标是提出一种指标,以评估图中“Stronger DA”这一概念的强弱程度。

一切源于一个意外的实验发现:在**KD中延长迭代次数,通常可以显著提升KD的性能。例如,KD实验中常用的ResNet34/ResNet18 pair,在ImageNet-1K上,将迭代次数从100 epochs提升到200 epochs,可以使top1/top5准确率分别从70.66/89.88提升至71.38/90.59,达到当时CRD方法的性能水平(71.38/90.49)。令人费解的是,将最基础的方法训练时间稍作延长即可达到SOTA?通过大量实验分析,我们最终发现,数据增强是其背后的关键因素。

**直观上而言,**每次迭代,数据增强是随机的,导致每次得到的样本都不相同。当迭代次数增加时,student会遇到越来越多的不相同的样本,这有助于teacher模型从中提取更为丰富的信息(与对比学习loss具有相似的作用),从而有助于student模型进行学习。

自然地,我们可以进一步推想:不同数据增强方法引入的"多样性"应当是各具特色的。例如,我们期待基于强化学习搜索出来的AutoAugment相较于简单的随机翻转而言,其多样性应当更为显著。具体而言,该论文旨在探讨:如何量化这种多样性,并在实际应用中如何有效利用这一度量结果。

为什么这个问题重要?

1. 理论意义:帮助我们更深地理解KD和DA。

实验结果表明,在KD框架中采用更强的DA策略能够显著提升性能。若我们能够掌握影响这种"强弱"关系的关键因素,便能够创造更优的DA设计,从而受益于KD框架性能的提升。

**二、**主要贡献和内容

文章的主要贡献是三点:

我们阐述了一个定理,以明确说明什么样的数据增强是有效的。该定理指出,好的数据增强方法应降低teacher-student交叉熵的协方差。

定理的核心内容是探讨不同数据增强方法对训练样本相关性的影响,相关性越高表明样本越相似,从而导致多样性降低,这会直接影响student性能的表现。这一直观结论在文中得到严谨证明,构成了理论上的重要贡献。值得注意的是,相关性并非直接基于原始样本进行计算,而是基于teacher层的logits输出进行评估,换言之,在teacher看来,原始数据层的样本相关性并不重要,关键在于这些样本在teacher视角下是否足够相似,越不相似越好。

(2)基于该定理,构建了一个实用的评估指标(基于教师平均概率的标准差,记为T. stddev),该指标可以对每种数据增强方法计算出一个数值,通过将这些数值进行排序,可以确定哪种数据增强方法表现最优。在实验中,我们测试了7种现有的数据增强方法,发现CutMix表现最为出色。

(3)该定理指导我们提出了一种新的数据增强方法,命名为CutMixPick。该方法在原有CutMix方法的基础上,通过选择具有最大熵的样本(其高熵值反映了样本信息量丰富且多样性充足)来提升训练效果。实验结果表明,即使采用最基础的KD loss函数,其性能也能够与当前最优KD方法(如CRD)相比,达到相同的训练目标。该方法在实现KD目标的同时,显著提高了训练效率。

**三、**实验效果

文中最重要的实验是考察提出的指标(T. Stddev)是否能有效表征不同数据增强策略下student模型(S. test loss)的表现优劣,即两者之间的关联程度如何。实验结果表明,这种关联性具有显著性。

本研究对9种数据增强技术进行了系统性测试。在CIFAR-100、Tiny ImageNet和ImageNet-100等标准数据集上进行了实验验证。实验结果表明,各数据增强方法与目标任务的相关性均显著强,且计算所得的p-value普遍低于5%的显著性水平。

值得注意的一点是,纵轴反映了student的性能水平,而横轴指标完全由teacher计算得出,与student没有任何直接信息联系,但二者之间却展现出显著的相关性。这表明,KD方法对DA优劣的评价可能与student的性能无关。此外,对于不同teacher和数据集,DA的相对排序关系相对稳定。例如,CutMix方法在多个数据集上均显示出比Cutout更好的性能。这些发现表明,在特定网络和数据集下优化得到的DA方法,很可能在其他网络和数据集中同样表现出良好的效果,从而显著提升了方法的实际应用价值。

**四、**总结

本文关注数据增强在知识蒸馏中的影响,在理论和实际算法方面均有贡献,主要有三点:

我们对"评估知识蒸馏中不同数据增强方法的优劣"这一问题展开了深入的理论探讨(好的数据增强方法应通过最小化teacher-student交叉熵的协方差来实现其有效性);

该理论构建了一个实际可计算的度量指标(stddev of teacher’s mean probability)

该研究提出了一种新型数据增强方法(CutMixPick),通过进一步提高CutMix性能,在知识蒸馏过程中实现了新的最优性能SOTA

局限性:本研究在基于ImageNet-1K的数据集上,所提出的指标(T. Stddev)与student性能的相关性不显著。目前,原因尚不明确,作为未来研究的重点,我们将深入探讨这一问题。

参考

基于统计视角,评估知识蒸馏技术中不同数据增强策略的优劣性

https://arxiv.org/abs/2012.02909

全部评论 (0)

还没有任何评论哟~