Advertisement

2018 Dataset Distillation

阅读量:

数据集蒸馏

作者:Tongzhou Wang、Jun-Yan Zhu、Antonio Torralba、Alexei A. Efros

机构:Facebook、MIT CSAIL、UC Berkeley

目录

数据集蒸馏

背景

目前存在的问题

贡献

方法介绍

3.1 数据集的蒸馏方法:单步迭代蒸馏

3.2 随机初始化蒸馏

3.3 简单线性情况的分析

3.4 单步迭代蒸馏拓展到多步迭代蒸馏

3.5 不同初始化方式对比

3.6 不同目标对象的蒸馏

实验结果

蒸馏图像效果

四种不同初始化比较

快速微调结果

总结


背景

训练最先进的神经网络模型需要的数据集规模越来越庞大,这对于内存以及训练资源的需求越来越高,所以,将庞大的数据集高效地进行压缩是一个十分重要的研究方向。

传统的数据集压缩方法是将原数据集去掉不重要的部分,提炼为子集,这一方法的实际效果欠佳。

作者从知识蒸馏中得到启发,传统的知识蒸馏目的是从复杂模型中蒸馏出知识让简单模型进行学习,从而让简单模型的性能能够接近复杂模型,类比提出了数据集蒸馏。数据集蒸馏定义为:固定训练的模型,从大型训练数据集中蒸馏知识让小型训练集学习,从而让小型数据集训练的模型性能能够接近在大型数据集上训练的模型。

目前存在的问题

为什么数据集蒸馏是有效的呢?/是否能将一个数据集压缩成一小组合成数据样本?

传统观点认为数据集蒸馏是不行的,因为合成训练数据不遵循真实数据的分布,理论上认为合成数据样本无法训练一个好的分类器。

贡献

1.提出了数据集的蒸馏方法

2.推导了在线性网络下,达到与完整数据集训练相同性能所需的蒸馏数据大小的下界;

3.在MNIST、CIFAR10上验证了合成数据训练分类器的有效性;

4.通过蒸馏数据集完成预训练模型的快速微调工作;

5.完成有害数据攻击应用,通过蒸馏图像快速攻击训练好的分类器对某一个类的识别准确率。

方法介绍

3.1 数据集的蒸馏方法:单步迭代蒸馏

真实数据集 期望能得到一个合成数据集****

随机初始化模型参数**** ,在蒸馏数据上迭代一次,

设置损失函数 L,目标函数如下,使用L进行反向传播更新蒸馏数据

3.2 随机初始化蒸馏

在训练过程中编码了真实数据 和一个固定的模型参数 ,故泛化能力弱。所以将模型参数推广到一个特定分布,目标函数改变为下,其余步骤如3.1所示

3.3 简单线性情况的分析

结论:对于一个二元损失的线性模型来说,在相同的一个梯度下降步骤,蒸馏数据的数量至少要大于真实数据的向量维度值才能达到相同的性能。证明过程如下:

对于一个真实数据集***,训练一个二元线性模型,损失函数如下,d代表N个数据,每个数据维度为D,t代表D个数据标签,维度为1,权重矩阵为θ(D1)

蒸馏数据为 ,训练一个梯度之后

此时,希望对于任意 都能满足在训练集上相同的测试性能,假设为 ,满足上式,真实数据和标签d,t应该满足下式

代入得:

对于任意 都成立,故应满足dd满秩且M>=D

3.4 单步迭代蒸馏拓展到多步迭代蒸馏

单步迭代蒸馏:

多步迭代蒸馏:

其余步骤如3.2**,** 使用反向梯度计算优化策略加快梯度计算,反向梯度优化将必要的二阶项表述为有效的Hessian-vector积

3.5 不同初始化方式对比

使用了四种初始化方式:随机初始化、固定初始化、随机预训练权重、固定预训练权重

3.6 不同目标对象的蒸馏

蒸馏用于恶意数据中毒:通过蒸馏数据对训练好的分类器再次训练一个梯度下降步骤,破坏分类器对于某个类的分类准确率。

目标函数如下:

总的算法如下:

实验结果

蒸馏图像效果

固定网络:

随机初始化网络:

四种不同初始化比较

快速微调结果

总结

提出了数据的蒸馏方法,使用真实数据在蒸馏模型上进行测试作为损失函数,反向传播更新蒸馏数据,缩小测试损失;证明了合成数据对于训练模型的有效性,但准确率还有较大提升空间,可以尝试推广到更高分辨率的数据集上。

全部评论 (0)

还没有任何评论哟~