Large-Scale Generative Data-Free Distillation
Large-Scale Generative Data-Free Distillation
我们提出了一种创新的方法,并通过训练教师网络的内在normalization层统计数据来训练生成式图像模型。这种策略使得我们可以构建一个无需依赖训练数据使用的生成器集合,并从而有效地生成后续蒸馏过程中的替代输入。实验结果表明,在无数据环境下进行蒸馏时,在CIFAR-10数据集上的准确率达到95.02%,而在CIFAR-100上则达到77.02%。此外,请注意我们在ImageNet数据集上的相关研究也取得了显著成果:据我们所知,在无数据环境中未曾使用过生成模型以实现类似效果的研究报告中提到的相关技术路径与之不同。

图1. 提出了一种无需生成数据提取的方法。我们提出了基于生成式的无数据提取方法,在缺乏真实图像的情况下,默认情况下使用以下方式进行训练:首先使其目标标签被预先训练的 teacher 最大化预测概率;其次通过匹配 batchnorm 层的均值 μ 和方差 σ² 的统计量(见公式 (7))来实现这一目标。随后,在利用生成器产生的合成图像后,则可进行知识提取过程。更多合成图像的具体实例可见图4展示。
在本研究中,我们基于生成式图像建模的方法实现了高效的数据显示,同时探究其在大规模数据集上的扩展可能性.为此,我们开发出一种称为"无数据提炼"的方法,该方法可在无需原始训练数据的前提下开发一个自动生成器,并在虚拟样本中提取知识.该自动生成器通过最小化以下两个优化目标得以优化:第一部分为目标匹配损失(Moment Matching Loss),第二部分为提升判别器梯度激活量的最大化损失(Inceptionism Loss).
针对 [21-55 ] 范围内的无数据图像合成方法,在 moment-matching loss 的不同变体方面已有研究关注其不同变体。此外值得注意的是,在训练批归一化层的设计过程中会采用这些信息;它们普遍应用于当前主流架构设计中,并包括如 resnet [
基于Deep Dream Style [42]的方法框架中引入了Inceptionism损失函数作为指导原则。我们的核心目标在于寻找一种能够最大可能地提升预训练模型识别特定类型图像能力的输入样本,并将其转化为一个最小化交叉熵损失函数的问题。通过将这一目标与Moment Matching损失相结合使用,在仅依赖于预训练教师模型的前提下无需真实样本即可训练生成器网络,并从而能够系统性地提取出高质量的合成图像。
为了评估该方法的效果,我们在三个图像分类数据集上进行了实证研究。随后,在不使用真实图像的情况下进行无监督蒸馏实验,我们发现生成的图像质量及真实性较之前的方法更为突出。
这些图像也可以有效地辅助下面的知识提炼。
学习的学生在与之前的方法相比表现出明显的优势,并取得了新的SOTA结果;同时,在监督训练的基础上表现更为出色。
随后,我们研究并构建了一个包含多个生成器的集合,在CIFAR-100和ImageNet数据集上进行实验验证,并对该集合的有效性进行了进一步验证
我们的主要贡献总结如下:
我们开发了一种创新性方法用于指导预训练教师网络模型高效生成合成数据
我们在CIFAR-10和CIFAR-100数据集上深化无数据提炼技术,在准确率方面均取得了显著成果:95.02%与77.02%的成功率均超越基于有监督训练的同类方法。
我们基于多组生成器的发展性研究,在ImageNet数据库上实现了非生成式提取技术的应用。据记录显示,在现有研究中,“仅凭无标签数据实现蒸馏过程”的情况尚属首次。
3. Generative Distillation in Data-Free Setting
在本节中,我们首先复习了核心的知识提炼方法,然后介绍了经过预先训练的教师如何构建生成模型的方法.
3.2. Knowledge Distillation
知识蒸馏的主要目标是将教师网络T(x;θt)T(x;\theta_t)所包含的知识转移到学生模型S(x;θs)S(x;\theta_s),实现教学过程中的经验传递。在分类场景下,这两个神经网络通常会在K个预设类别上生成概率分布表示。为了使学生模型能够有效模仿教师的行为,在训练过程中需要确保其能够捕获并复制教师网络在训练数据集上的统计特性。从形式上看这种知识提取过程可被视为优化问题中的目标函数最小化问题。

pdatap_{data}代表训练数据的分布。
3.3. Generative Image Modeling
计算公式(1)中的损失函数需要明确了解数据分布pdatap_{data}的信息,在这种情况下,在无数据环境下难以确定该分布。

然后,在未在pdata的访问情况下进行生成器的训练,并仅依赖预先训练完成的教师模型T后裔的基础上,在当前的关键问题上寻求合适的指导目标以驱动生成器的学习过程。这些指导目标将在本节后续内容中进行介绍。
该损失函数基于Inceptionism风格(Inceptionism-style42)的图像合成技术,并亦称作DeepDream技术的一种变体。它是一种用于可视化输入图像的方法,在训练好的预设模型中激发特定神经网络响应。具体而言,在我们希望确定模型预测'狗'类别的输入图像时,激励方法会从随机噪声生成的可训练图像x出发,在经过优化以最大化模型输出概率的过程中逐步调整x为最'类狗'的图像形式。在数学形式上,在给定预设标签y^\hat y以及一个经过训练的教师模型T的情况下,则需要找到使分类分布交叉熵最小化的x值(即p=T(x)p=T(x)中最接近预设类别分布\hat p=OneHot(y^)\hat p= OneHot(\hat y))。

在实际工作中我们通常不会仅进行目标优化而是在施加一个必要条件的情况下使合成图像复制自然图像的统计特性例如相邻像素间的特定相关性这可以通过在式(2)中添加正则化项来实现

其中,在本文中我们采用[21,55] (Dream to distill)中的total variation loss and l2-norm作为正则器。

匹配矩损失. 与之相比,在inceptionism损失中仅对输入(图像)和输出(概率)施加约束,并未对中间层的激活进行限制。现有研究表明[18, 34, 38]:深度卷积神经网络的不同层级往往承担着不同的功能——如底层倾向于识别边缘及曲线等初级特征;而高层则擅长学习编码更为抽象的高级特征。此外,在文献[21]中指出:基于传统inceptionism方法训练得到的图像可能引发中间层异常激活现象,并偏离真实数据分布观察结果;这些发现均表明应当引入一个正则化项以约束教师模型在中间层统计上的偏差
批归一化层是大多数神经网络中常用的组件 能够提供这样的统计信息 [21 55] 。其主要目的是通过利用训练阶段计算出的移动平均值和方差对手动激活进行重新中心和缩放处理以实现归一化 这实际上意味着该方法假设生成图像所使用的统计特性与真实数据中的统计特性是一致的 因此 从而使得生成图像所提取的关键统计特征与真实样本中的特征相匹配 [10 48 53] 。
在教师模型的BN层中找到移动均值\mu_{\text{hat}}和方差\sigma_{\text{hat}}^2;我们的目标是使生成数据与真实数据之间的均值\mu(x)及方差\sigma^2(x)达到最优状态。\ 在基于isotropic高斯分布假设的情况下,则可通过最小化它们之间的Kullback-Leibler散度来实现:

其中\mathcal N(\cdot,\cdot)被定义为高斯分布。在本文中,我们采用后者,并通过将所有batch-norm layers的其惩罚值相加来构建moment matching loss。

Generator training objective. 通过融合Inceptionism损失与矩匹配损失相结合的方式进行训练, 能够实现生成器的优化目标

我们旨在基于这些损失来训练一个核心生成模型。通过将生成器G(z|y)纳入公式(8)中的变量x部分进行应用, 从而确定生成器的训练目标。

采用多组generator进行训练时会遇到模式崩溃的问题[16, 40, 47]。每个generator并不负责创造多样化的图像而是专注于产生单一类型的图像或仅限于少数变化形式的分布其输出与潜在变量之间关系较弱。我们假设当我们的generator在训练过程中产生与高置信度教师预测一致的图像则交叉熵损失会消失即使其他损失项(如LM\mathcal L_M)尚未完全优化generator仍能学习仅输出该特定结果。As shown in Figure 2 our generator experiences a pattern collapse scenario where it can successfully produce realistic images for the 'car' class but generates all instances in red.

基于先前的研究[13, 35]所指出的那样,在处理此问题时可以通过多组生成器来实现一种相对简单而有效的解决方案
对于我们的方法来说,在采用包含k个生成器的方式下,并将各类均匀地分布在各个生成器上以确保每个类别仅由一个特定的生成器负责。此外,在这种配置下,每一个生成器都会专注于其所分配到某一特定类别的一致性目标即最大化其inceptionism损失函数值
在处理moment matching时,在这种情况下我们并未从BN层中存储的moments出发
4.2. CIFAR100

Single generator. 如表2所示, 使用单一生成器的知识蒸馏获得了相应的测试精度数据. 通过该方法, 我们的模型在测试集上达到了76.42%的分类准确率, 显著优于以往的研究成果. 然而, 该结果与具有监督学习能力的ResNet-18模型相比仍存在一定差距, 其中ResNet-18模型基于教师网络直接进行监督学习, 或者是从教师网络的训练数据中提取特征进行推理.
Multiple generators. 我们采用两种方法来采集每个类别的统计数据。一种直接的方式是。(a)从训练数据中每类选取100张图像 (b)将这些图像回传至预训练模型以估算每层所需的矩 ©在生成器训练过程中将它们作为元数据输入模型训练器。尽管我们只需要少量图像来获取这样的统计信息但它并不能被视为完全无数据的方法在最严格的情况下必须无数据的情况下我们还有另一种选择即利用方程(8)作为优化目标学习若干可训练的图像[21 55]。按照这种方法我们可以获得少量合成的数据用于计算各类统计量的具体数值并用于生成器的训练过程在蒸馏过程中我们只需从所有生成器中均匀随机抽取样本进行特征提取
表2的最后两行呈现了通过集合生成器进行知识提取的结果。
在实验过程中, 我们对不同架构的学生进行了一系列对比分析(见表5)。其中, 老师采用的是ResNet-50模型, 其最高一级分类准确率达到75.45%. 我们采用了统一的生成器组别来处理所有参与比较的学生数据。与基于有监督学习的传统方法相比, ResNet-50模型上的蒸馏性能表现最优, 但其分类精度较之下降了约5.7%. 然而, 在使用ResNet-18和MobileNetV2[49]等结构时, 表现相对较为薄弱, 与其他学生的差距较为显著. 这些结果表明, 学生与教师结构之间可能存在某些制约因素, 导致基于ResNet-50教师模型的学习所获得的生成器在MobileNetV2和ResNet-18等结构上表现出较差的效果. 关于进一步提升泛化能力的技术研究将继续作为未来研究的重点方向
