论文研读系列——“TabDDPM: Modelling Tabular Data with Diffusion Models”
TabDDPM: Modelling Tabular Data with Diffusion Models
arxiv [Submitted on 30 Sep 2022]
代码:https://github.com/rotot0/tab-ddpm
摘要
本文阐述了TabdDPM作为一种新型扩散模型用于生成表格数据的技术。该技术在多个领域中展现出良好的应用前景,在计算机视觉与语音处理等领域受到广泛关注。其独特之处在于能够有效地处理连续型与离散型的不同特征类型;从而使得它能够适应多种表征形式的数据;因此,在面对各种表征形式的数据时表现出色。作者不仅在生成能力上表现出色;而且通过结合简单的MLP架构以及正弦时间嵌入方法实现了高效的反向过程建模;同时利用CatBoost优化超参数配置从而提升了整体性能。研究表明这种模型能够生成高质量的合成表征形式的数据;而这些合成表征形式的数据对于提升分类器与回归器的表现同样具有重要意义
1 INTRODUCTION(引言)
近年来,在生成模型领域内,去噪扩散概率模型(DDPM)受到了广泛的关注。由于它们在单个样本的真实性和多样性方面通常超过了替代方案,在自然图像领域中展现了最令人瞩目的成功。其优势被成功应用于诸如着色、修复、分割、超分辨率以及语义编辑等多个方面。此外,在如自然语言处理(NLP)、波形信号处理以及分子图等领域也进行了深入研究。这些研究不仅验证了扩散模型在其应用范围内的普遍性,并且进一步扩大了其潜在的应用前景。
作者的主要研究方向是探讨扩散判别过程(DDPM)的普遍适用性能否延伸至通用表格问题领域。这些通用表格问题广泛存在于多个工业领域,并可通过一组多元化的特征进行表征。受现行严格隐私法规(如《一般数据保护条例》GDP)的限制,在这种情况下公开真实用户数据变得困难;然而通过生成式模型产生的合成数据则可实现合法的数据共享,并且对生成高质量模型的需求愈发强烈。相较于计算机视觉和自然语言处理(NLP),训练高质量的表格数据分析模型更具挑战性:这主要归因于个体特征间的异构性和传统表征方式所具有的较小区域覆盖度。在此研究工作中,在面对这两个主要障碍:一是个体特征间的复杂差异;二是传统表征方式所带来的局限性时;我们发现尽管面临这两个主要障碍:一是个体特征间的复杂差异;二是传统表征方式所带来的局限性时;扩散模型仍可有效地模拟常规分布模式,并在多项基准测试中取得了最先进水平的表现。
更具体地说,作者工作的主要贡献如下:
作者阐述了TabDDPM这一方法——一种专门用于表格问题的基础型DDPM设计;该模型能够适用于所有表格任务,并支持处理数值型数据以及分类特征等多种数据类型。
作者证实了TabDDPM在多个数据集上展现出更高的性能水平,并超过了专为表格数据设计的替代方案;其中基于GAN和VAE的模型,并进一步解释道这些优势源自何处。
研究者发现,在面对隐私保护的需求时,基于TabDDPM生成的数据表现出了显著的优势。值得注意的是,在合成数据被用来替代不能共享的真实用户数据时
2 RELATED WORK(相关工作)
扩散模型 是一种生成框架,在计算机视觉领域具有重要应用价值。该框架旨在通过马尔可夫链的极限状态逼近目标分布,并被定义为从给定参数分布起始状态开始的一阶动态系统。每一步马尔可夫步骤由深度神经网络来处理,成功地学习了如何逆转利用已知高斯核的过程来实现扩散。研究表明,在这一过程中,Ho 等人展示了扩散模型与得分匹配理论之间的等价性,并将其视为逐步将简单已知分布转换为目标分布的不同视角之一。Nichol 等人与 Dhariwal 等人最近提出的创新架构进一步强化了这一领域的能力,使得基于扩散模型的方法在生成质量和多样性方面相较于GAN在计算机视觉领域中占据了明显优势。在作者的研究工作中,我们发现该方法同样适用于表格问题这一新兴领域,展现了其广泛的应用潜力
生成模型是机器学习社区当前活跃的研究领域。由于高质量的合成数据对于许多表格任务的需求极为巨大。首先,在大多数情况下, 表格数据集规模有限, 而视觉或自然语言处理(NLP)问题则能够在互联网上获得海量'额外'的数据资源.其次, 合成的数据集合不含真实用户信息, 因此无需受类似《通用数据保护条例》(GDPR)的规定限制, 在保证隐私的前提下即可公开分享.近期的研究开发了许多先进的模型, 包括表征生成对抗网络(VAE)以及基于生成对抗网络(GAN)的方法.研究者们通过在多个公共基准上进行广泛评估, 展示了他们的TabDDPM模型在多个基准测试中显著超过了现有替代方案.
表面层次合成生成过程与非结构化的图像或自然语言文本存在显著差异。由于表格数据具有明确的结构特征,并且其建模过程通常无需采用多层次深度架构以提高准确性。因此,在面对类别不平衡问题时,简单的插值方法如SMOTE仍可作为高效解决方案,在Camino等研究工作中已证明其在少数类过采样方面的优势不仅体现在性能上,在基于表格的生成对抗网络(GAN)模型中也表现优异。实验结果表明,在隐私保护的前提下,基于TabDDPM模型产生的合成数据显著优于传统插值技术所生成的数据集
3 BACKGROUND(背景)
论文深入阐述了扩散模型的基本原理及其在数据生成方面的应用方式。作为一种基于概率密度估计的概率生成模型D,扩散模型主要依靠正向马尔科夫链逐步推进数据分布的变化状态,并通过逆过程逐步推导出目标样本的条件概率分布。具体而言,在正向过程中,马尔可夫链从简单的初始分布逐步推进到复杂的数据分布;而在反向过程中,则实现了从复杂分布回归到简单分布的目标。
- 前向扩散过程 :该过程通过逐步添加噪声来逐渐改变初始数据样本,这些噪声是从预定义的分布中采样得到的。这个过程由一系列马尔可夫步骤组成,每一步都由一个深度神经网络执行,以学习如何逆转扩散过程。
- 反向扩散过程 :与前向过程相对应,反向过程逐步去噪一个潜在变量,并允许从数据分布中生成新的数据样本。反向过程中的分布通常是未知的,需要通过神经网络进行近似,网络参数通过优化变分下界来从数据中学习。
- 高斯扩散模型 :在连续空间中操作,其中前向和反向过程由高斯分布特征化。这部分介绍了如何使用高斯分布来定义数据的扩散和去噪步骤。
- 多项式扩散模型 :设计用于生成分类数据,其中数据通过多项式分布进行扩散,通过在类别上均匀添加噪声来破坏数据。
4 TABDDPM
在本节中,作者描述了TabDDPM的设计以及影响模型效果的主要超参数。
TabDDPM通过多项式扩散机制来模拟分类与二元属性,并通过高斯扩散模型处理数值属性。具体而言,在处理表格数据样本x = [x_{num}, x_{cat_1}, ..., x_{cat_C}]时(其中x_{num}表示数值属性集x_{num} ∈ R^{N_{num}}}),该方法将每个分类属性独立地进行正向扩散建模(即所有属性噪声部分独立采样)。在此过程中采用了基于scikit-learn库的分位数标准化方法对原始数据进行预处理。在逆向传播过程中,则设计了一种多层神经网络架构用于逆向传播过程:该网络输出与原始输入数据x_0维度相同的重建结果,在此过程中前N_{num}个维度用于预测正态噪声(对应于数值属性),其余维度则用于预测独热编码形式转换后的分类属性值。整个框架旨在实现对混合型表格数据的有效生成与修复操作。

表1: TabDDPM的主要超参数。

在分类问题中进行建模时(如图1所示),TabDDPM模型通过最小化高斯扩散项对应的均方误差L^{simple}_t以及每个多项式扩散块对应的KL散度L^i_t来进行学习。为了提高模型性能,在计算多项式扩散的整体损失值时(即\sum_{i=1}^{n} L^i_t),会将其进一步归一化处理并除以分类特征的数量。
L^{TabDDPM}_t = L^{simple}_t + \frac {∑_{i≤C} L^i_t} C
对于分类数据集而言,在研究中采用了基于条件概率的分类模型,并具体而言是学习p_θ(x_{t−1}|x_t, y)这一过程。而在回归问题中,则将目标值视为一个特殊的数值特征,并致力于构建相应的联合分布模型以捕捉潜在关系
研究者旨在模仿反向工程的过程,并采用了基于多层感知机(MLP)的简洁架构设计;这一架构源自Gorishniy等的研究者:
MLP(x) = Linear(MLPBlock(....(MLPBlock(x))))
MLPBlock(x) = Dropout(ReLU(Linear(x)))
根据Nichol等人(2021)以及Dhariwal与Nichol(2021)的研究表明,在表格输入x_{in}、时间步t和类别标签y方面有详细的说明
t\_emb = Linear(SiLU(Linear(SinTimeEmb(t))))
y\_emb = Embedding(y)
x = Linear(x_{in}) + t\_emb + y\_emb
其中SinTimeEmb代表具体来说是(Nichol, 2021; Dhariwal & Nichol, 2021)中所述的具体实现的sinusoidal positional embeddings,并且其维度被设定为128。在方程5中使用的每一个Linear层都具有固定投影维度128。

超参数 在TabDDPM框架中扮演着至关重要的角色,在多次实验验证下发现它们对于提升模型性能具有显著作用。参考表1可以看出,在这项研究中所涉及的主要超参数及其对应的搜索区间被系统地进行了阐述。这些设置建议是由论文作者所提出的,并且已经被证明是有效的选择;有关如何实施这些设置的具体步骤和细节,在论文实验部分进行了全面阐述。
5 EXPERIMENTS(实验)
在本节中, 作者对TabDDPM展开了全面的评估, 为了与其现有的替代方案进行对比分析.
数据集:为了系统地研究表格生成模型性能的目的,在这项研究中作者选取了15个不同来源的真实世界公共数据集作为研究对象。这些数据集在规模、属性、特征数量以及分布等方面存在显著差异,并且其中大部分曾在(Zhao等人, 2021; Gorishniy等人, 2021)的研究中被用于评估表格生成模型的效果。详细信息及其属性在表2中列出。
基线:针对大量用于表格数据的生成模型而言,在实际应用中很难穷尽所有可能的生成方法。因此,在研究过程中作者主要关注并评估了每个生成模型范式中的最佳实践,并结合开源实现进一步验证其适用性
- TVAE (Xu等人, 2019) — 基于表格数据生成的最新变分自编码器。据我们的了解,在性能上除了TVAE之外,并未发现其他VAE类模型同时具备超越其性能且提供公开源代码。
- CTABGAN (Zhao等人, 2021) — 一个在多样化的基准测试上超越现有表格GAN近期提出的基于GAN的模型。该方法不具备处理回归任务的能力。
- CTABGAN+ (Zhao等人, 2022) — CTABGAN增强版模型在最新的预印本中正式发布。我们目前还不清楚后续提出的基于GAN的表格数据生成方法是否提供了公开源代码。
- SMOTE (Chawla等人, 2002) — 这种"浅层"插值方法将合成点生成为真实数据点与其k近邻的凸组合。最初这种方法被提出用于少数类过采样应用,并在此基础上被泛化应用于合成数据生成作为一个简单的参考。
评价标准:作者的核心评价标准是机器学习(ML)效率(或效用)。具体而言, ML效率量化了基于合成数据训练并在真实测试集上进行评估的分类器或回归器的表现.通过使用高质量合成数据训练出的有效分类器或回归器,在实际应用中往往能与基于真实数据训练出的有效模型相媲美甚至超越.其中一种方案因其广泛的引用背景而更为常见(Xu等人, 2019; Zhao等人, 2021; Kim等人, 2022).研究团队分别计算了不同主流机器学习算法(如决策树、随机森林、逻辑回归等)的有效性.另一种方案则专注于当前最先进算法的表现,即分别针对CatBoost (Prokhorenkova等人, 2018)和(Gorishniy等人, 2021)中的深度学习架构进行了详细评估.通过系统化的参数优化流程确保最佳性能.研究者认为这一采用最新技术的研究方法更能反映出合成数据的真实价值,因为在大多数实际应用场景下,从业者更倾向于采用性能优越的技术而非较为落后的方法.

调整过程 。研究者采用了Optuna库来进行超参数优化。研究者依据在保留验证数据集上的合成数据生成情况评估了ML效率(针对Catboost),并且计算所得分数是在五个不同的采样种子上平均得出的结果。(具体设置可在表1中找到详细说明)此外,结果显示使用CatBoost指导调参不会产生任何‘CatBoost偏见’的影响,并且经过该方法调参后的TabDDPM生成效果显著优于其他模型如MLP。
5.1 定性比较
在研究过程中发现,在对比分析TabDDPM与TVAE以及CTABGAN+等模型时发现了一种显著的优势模式:具体来说,在每个实验设置下作者均生成了一个与特定数据集真实训练规模相当的虚拟样本集合;其中对于分类属性而言,在生成虚拟样本时遵循了真实数据分布的比例;随后通过图2展示了真实样本与虚拟样本典型个体特征的表现形式;进一步地则详细阐述了各类别特征的表现差异;结果表明:无论是在数值属性还是类别属性中TabDDPM均展现出更强的真实分布表现;而相较于CTABGAN+这种表现优势更为明显(1)在数值均匀分布的情形下(2)在类别属性数量较多的情况下以及(3)在同时包含连续与离散特性的混合型属性情况下

随后,在不同数据集上分析真实与生成数据间的协方差矩阵差异时
5.2 机器学习效率
在本节中,作者对比了TabDDPM与其他生成模型在机器学习效率方面的性能表现。对于每一个生成模型,在保持与真实训练集规模相当的情况下,作者生成了一个合成数据集,并将其用于训练一个分类或回归模型。随后通过真实测试集的评估指标来验证该方法的效果。具体而言,在实验过程中采用了两种不同的协议:一种是基于F1分数的分类性能评价方法;另一种是基于R2分数的回归性能评价标准。
首先,我们对多种不同机器学习模型进行了平均机器学习效率评估(如之前的研究中采用的方法一致)。该集合包括决策树、随机森林、逻辑回归(或岭回归)以及scikit-learn库中的MLP神经网络模型(使用默认参数),其中最大深度参数设置为28(适用于决策树和随机森林),最大迭代次数参数设为500(适用于逻辑回归和岭回归),而针对MLP神经网络则设为100。
随后


本研究的核心研究发现表明,在不同协议下计算所得的人工智能效率评估指标分别列示于图1至图6中。具体而言,在模型性能评价指标方面采用了基于双重采样的方法,并结合多次交叉验证技术进行了系统性分析。其中:
- 在模型性能评价指标方面采用了基于双重采样的方法,并结合多次交叉验证技术进行了系统性分析。
- 关注的重点包括模型准确率、召回率以及F1分数等关键性能指标。
- 在两种评估协议中,在大多数数据集上TabDDPM超越了TVAE和CTABGAN+这一事实充分展现了扩散模型在表格数据分析中的优势,并非之前工作中简单重复的结果。
- 基于插值法设计的SMOTE方法其性能与TabDDPM不相上下并且普遍表现更为出色有趣的是目前关于表格数据分析的研究并未对SMOTE进行系统性对比分析这一看似简单的基准却难以被超越。
- 尽管诸多先前研究采用了第一种评估指标来衡量机器学习效率但我们认为采用第二种评估指标更为合理(即以CatBoost等前沿算法为基准)。表3和表4的数据表明当采用第一种指标时分类回归任务的表现数值明显低于基于CatBoost等先进技术的标准这让人难以忽视生成模型输出数据较真实数据更具优势的现象然而当采用调整后的机器学习模型时上述情况并不普遍。
- 因此在第一种指标下相对于真实数据训练得到的合成数据往往具有更强的优势这种现象令人信服地表明生成的数据可能比真实的数据更加有价值。
总体而言,TabDDPM模型展现出卓越的生成性能,并可被视为高质量合成数据的重要来源。值得注意的是,在机器学习效率方面,“shallow”渗透方法与该模型展开竞争引发了疑问:是否有必要采用复杂的深度生成模型?在下一节中将对此作出明确回答。
5.3 隐私
在此处,则作者在不公开个人或敏感信息的情况下探讨了共享数据的问题。 tabddpm相对而言则表现得更为出色。 为了评估合成数据的隐私性程度,则作者引入了一个关键指标distance to closest record(dc r)。 具体而言,在计算过程中对于每个生成样本x来说,则计算其与所有真实样本x中的最小距离d(x,x),然后取所有这些距离值的中位数作为dc r指标。 表5详细比较了smote和tabddpm的dc r值,并明确突显了tabddpm的优势所在。 此外,在图4中展示了从合成样本到最近真实样本的距离分布情况。 这一实验结果表明,在保障隐私的同时 tabddpm生成的数据不仅提升了机器学习效率而且能够更好地满足特定场景下的需求综上所述 在不违反隐私的前提下 tabddpm展现出显著的优势


6 CONCLUSION(结论)
本文研究者深入分析了扩散模型框架在表格数据处理领域的潜力。特别地,在介绍一种能够有效处理包含数值型、有序数以及分类特征的数据组合的DDPM架构时,则重点阐述了一种创新性的设计方法。此外,在阐述模型超参数的关键作用及其调节策略方面也进行了深入探讨。特别强调,在保证数据隐私的前提下进行评估时,则展现了所提出方法显著的优势:生成的数据质量持续优于基于GAN/VAE的传统方法以及插值技术等比较基准方案。
最后感谢您抽出时间阅读并停留在此。
这些观点均由我独自对论文内容进行深入分析所得。
仅限于个人学习目的使用。
如有任何疑问或侵权行为,请随时与我联系。
祝你天天开心,多笑笑。
