Advertisement

【FREEMATCH: SELF-ADAPTIVE THRESHOLDING FOR SEMI-SUPERVISED LEARNING】

阅读量:

FREEMATCH: SELF-ADAPTIVE THRESHOLDING FOR SEMI-SUPERVISED LEARNING

    • 摘要

    • 引言

    • 2 动机

    • 3 准备工作

    • 4 freematch

      • 4.1自适应阈值
      • 4.2 自适应公平
    • 5 实验

      • 5.1 设置
      • 5.2定量结果
      • 5.3定性分析
      • 5.4消融研究
    • 6 相关工作

    • 7 结论

      • A “双月”数据集的实验细节
      • 附录B:定理2.1的证明
      • C 算法
      • 附录D:超参数设置
      • 附录E:广泛的实验细节和结果
        • E.1 显著性测试
    • E.2 CIFAR-10 (10) 标记数据

    • E.3 详细结果

      • E.4 FixMatch 和 FlexMatch 上预定义阈值的消融实验
      • E.5 CIFAR-10 (40) 上 EMA 衰减的消融实验
      • E.6 FlexMatch 和 FreeMatch 上 SAF 的消融实验
      • E.7 不平衡 SSL 的消融实验
      • E.8 STL-10 (40) 上的 T-SNE 可视化
      • E.9 CIFAR-10 (10) 上的伪标签准确率
      • E.10 CIFAR-10 (10) 混淆矩阵

SEMI-SUPERVISED LEARNING)

摘要

半监督学习(SSL)由于伪标签和一致性正则化等各种方法带来的出色表现,取得了巨大成功。然而,我们认为现有方法可能未能更有效地利用未标记数据,因为它们要么使用预定义的/固定的阈值,要么采用临时的阈值调整方案,导致性能下降和收敛速度缓慢。我们首先分析了一个激励示例,以直观理解理想阈值与模型学习状态之间的关系。基于此分析,我们提出了FreeMatch,通过根据模型的学习状态自适应地调整置信度阈值。我们进一步引入了自适应类别公平正则化惩罚,以在早期训练阶段鼓励模型进行多样化的预测。大量实验表明,FreeMatch在标记数据极其稀少的情况下具有明显的优势。在CIFAR-10(每类1个标签)、STL-10(每类4个标签)和ImageNet(每类100个标签)上,FreeMatch相较于最新的最先进方法FlexMatch分别实现了5.78%、13.59%和1.28%的错误率减少。此外,FreeMatch还能提升不平衡半监督学习的表现。代码可在https://github.com/microsoft/Semi-supervised-learning获取。

引言

深度学习的卓越性能在很大程度上依赖于使用充足标注数据的监督训练 (He et al., 2016; Vaswani et al., 2017; Dong et al., 2018)。然而,获取大量标注数据仍然是一项费时且昂贵的任务。为减轻这种依赖,半监督学习(SSL)(Zhu, 2005; Zhu & Goldberg, 2009; Sohn et al., 2020; Rosenberg et al., 2005; Gong et al., 2016; Kervadec et al., 2019; Dai et al., 2017) 被开发出来,通过利用大量未标记数据来提高模型的泛化性能。伪标签 (Lee et al., 2013; Xie et al., 2020b; McLachlan, 1975; Rizve et al., 2020) 和一致性正则化 (Bachman et al., 2014; Samuli & Timo, 2017; Sajjadi et al., 2016) 是现代SSL中的两大主流范式。最近,二者的结合显示出了有希望的结果 (Xie et al., 2020a; Sohn et al., 2020; Pham et al., 2021; Xu et al., 2021; Zhang et al., 2021)。其核心思想是,模型应当在不同的扰动下为相同的未标记数据产生相似的预测或相同的伪标签,这符合SSL中的平滑性和低密度假设 (Chapelle et al., 2006)。

这些基于阈值的方法的一个潜在局限性在于,它们要么需要一个固定阈值 (Xie et al., 2020a; Sohn et al., 2020; Zhang et al., 2021; Guo & Li, 2022),要么需要一个临时的阈值调整方案 (Xu et al., 2021) 来仅使用高置信度的未标记样本计算损失。具体而言,UDA (Xie et al., 2020a) 和 FixMatch (Sohn et al., 2020) 采用了一个固定的高阈值来确保伪标签的质量。然而,固定的高阈值(0.95)在训练早期阶段可能导致数据利用率低,并忽略不同类别的学习难度差异。Dash (Xu et al., 2021) 和 AdaMatch (Berthelot et al., 2022) 提出随着训练的进行逐步提升固定的全局(数据集特定)阈值。虽然未标记数据的利用率有所提高,但它们的临时阈值调整方案受超参数的任意控制,与模型的学习过程脱节FlexMatch (Zhang et al., 2021) 表明,不同类别应有不同的局部(类别特定)阈值。尽管局部阈值考虑了不同类别的学习难度,它们仍然是从预定义的固定全局阈值映射得来的 。Adsh (Guo & Li, 2022) 通过优化每个类别的伪标签数量,从预定义阈值中获得自适应阈值,以应对不平衡的半监督学习。总的来说,这些方法可能无法或不足以根据模型的学习进展调整阈值,从而阻碍了训练过程,特别是在标记数据过于稀少,无法提供充分监督的情况下。
在这里插入图片描述
图1:演示FreeMatch如何在“两个月”数据集上工作。(a)FreeMatch和其他SSL方法的决策边界。(b)每类两个标记样本的自适应公平性(SAF)的决策边界改进。©类平均置信阈值。(d)训练期间FreeMatch的Classaware采样率。实验详情见附录A。

例如,如图1(a)所示,在“two-moon”数据集上,每个类别只有一个标记样本时,先前方法获得的决策边界未能满足低密度假设。因此,两个问题自然产生:1)是否有必要基于模型的学习状态来确定阈值?2)如何自适应地调整阈值以获得最佳训练效率?

在本文中,我们首先通过一个激励示例展示了不同数据集和类别应基于模型的学习状态来确定它们的全局(数据集特定)和局部(类别特定)阈值。直观地,我们在训练早期阶段需要一个较低的全局阈值,以利用更多未标记数据并加速收敛。当预测置信度增加时,需要一个更高的全局阈值来过滤错误的伪标签,以减轻确认偏差 (Arazo et al., 2020)。此外,每个类别的局部阈值应根据模型对其预测的置信度来定义 。图1(a)中的“two-moon”示例显示,当根据模型的学习状态调整阈值时,决策边界更为合理。

接下来,我们提出了FreeMatch,根据每个类别的学习状态以自适应方式调整阈值(Guo et al., 2017)。具体而言,FreeMatch 使用自适应阈值调整(SAT)技术,通过未标记数据置信度的指数移动平均(EMA)来估算全局(数据集特定)和局部阈值(类别特定)。为了更有效地处理极少监督的场景(Sohn et al., 2020),我们进一步提出了一个类别公平性目标,鼓励模型在所有类别之间产生公平(即多样化)的预测,如图1(b)所示。FreeMatch的整体训练目标是最大化模型输入与输出之间的互信息(John Bridle, 1991),在未标记数据上产生高置信度且多样化的预测。基准测试结果验证了其有效性。总之,我们的贡献是:

通过一个激励示例,我们讨论了为什么阈值应反映模型的学习状态,并为设计阈值调整方案提供了一些直观思路。

我们提出了一种新方法——FreeMatch,该方法包含自适应阈值调整(SAT)和自适应类别公平正则化(SAF)。SAT 是一种无需手动设置阈值的阈值调整方案,SAF 则鼓励多样化的预测。

大量结果表明,FreeMatch 在各种半监督学习基准测试中表现出卓越的性能,尤其是在标记数量非常有限的情况下(例如,在CIFAR-10数据集上,每类只有1个标记样本的情况下,错误率减少了5.78%)。

2 动机

在本节中,我们介绍了一个二分类示例,以激发我们对阈值调整方案的讨论。尽管这个示例简化了实际的模型和训练过程,但分析产生了一些有趣的启示,并为如何设置阈值提供了见解。

我们旨在展示自适应性和在SSL(半监督学习)中使用更精细的置信度阈值的重要性。受(Yang & Xu, 2020)的启发,我们考虑一个二分类问题,其中真实分布是两个高斯分布的均匀混合(即标签Y为正类(+1)或负类(-1)的概率相等)。输入X具有以下条件分布:
在这里插入图片描述
在这里插入图片描述
接下来,我们推导以下定理来展示自适应阈值的必要性:

定理 2.1 对于上述提到的二分类问题,伪标签 (Y_p) 的概率分布如下:
在这里插入图片描述
在这里插入图片描述
在附录B中提供了定理2.1的证明。定理2.1有以下几项含义或解释:
在这里插入图片描述

定理2.1的直观解释是,在训练初期,(\tau) 应该较低,以鼓励产生多样的伪标签,提高无标签数据的利用率并加快收敛。然而,随着训练的进行和 (\beta) 的增大,持续使用低阈值会导致不可接受的确认偏差。理想情况下,阈值 (\tau) 应该随着 (\beta) 的增加而升高,以保持稳定的采样率。由于不同类别具有不同的类内多样性(不同的 (\sigma))且某些类别比其他类别更难分类((\mu_2 - \mu_1) 较小),因此需要精细的类别特定阈值来鼓励对不同类别的伪标签进行公平分配。挑战在于如何设计一个能够考虑到所有这些因素的阈值调整方案,这是本文的主要贡献。

我们通过绘制训练过程中平均阈值趋势和边际伪标签概率(即采样率)来展示我们的算法,如图1©和1(d)所示。总结来说,我们应该通过模型的预测来估计学习状态,以确定全局(数据集特定)和局部(类别特定)的阈值。接下来,我们将详细介绍FreeMatch。
在这里插入图片描述在这里插入图片描述

3 准备工作

在半监督学习(SSL)中,训练数据由标注数据和未标注数据组成。设DL = {(xb, yb) : b ∈ [NL]} 和 DU = {ub : b ∈ [NU]}为标注数据和未标注数据,其中NL和NU分别表示它们的样本数量。标注数据的监督损失为:
在这里插入图片描述
其中B是批大小,H(·, ·)指的是交叉熵损失,ω(·)表示随机数据增强函数,pm(·)是模型的输出概率。对于未标注数据,我们专注于使用置信度阈值的交叉熵损失进行伪标签,通过熵最小化。此外,我们还采用了UDA(Xie等,2020a)介绍的“弱和强增强”策略。正式地,未标注数据的无监督训练目标是:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

4 freematch

4.1自适应阈值

我们主张,确定半监督学习(SSL)阈值的关键在于阈值应反映学习状态。学习效果可以通过经过良好校准模型的预测置信度来估计(Guo et al., 2017)。因此,我们提出自适应阈值调整(SAT),该方法通过利用训练期间模型的预测自动定义和自适应调整每个类别的置信度阈值。SAT 首先估计一个全局阈值,作为模型置信度的指数移动平均(EMA)。然后,SAT 通过估计每个类别的概率的 EMA 调节全局阈值,以得到局部的类别特定阈值。当训练开始时,阈值较低,以接受更多可能正确的样本参与训练。随着模型信心的增强,阈值自适应增加,以过滤掉可能不正确的样本,从而减少确认偏差。因此,如图 2 所示,我们将 SAT 定义为 Tt(C),表示第 t 次迭代中类别 c 的阈值。

自适应全局阈值

我们根据以下两个原则设计全局阈值。首先,SAT中的全局阈值应与模型对未标记数据的置信度相关,反映整体学习状态。此外,全局阈值在训练过程中应稳定增加,以确保丢弃不正确的伪标签。我们将全局阈值 Tt 设置为模型对未标记数据的平均置信度,其中 t 表示第 t 次时间步(迭代)。然而,由于未标记数据的数量庞大,在每个时间步或甚至每个训练周期计算所有未标记数据的置信度将非常耗时。相反,我们将全局置信度估计为每个训练时间步置信度的指数移动平均(EMA)。我们将 Tt初始化为1/C,其中C表示类别的数量。全局阈值 Tt 定义和调整为:
在这里插入图片描述
其中λ ∈(0,1)是EMA的动量衰减。

自适应局部阈值

局部阈值的目的是以类别特定的方式调节全局阈值,以考虑类别内的多样性和可能的类别邻接性。我们计算模型对每个类别 c 的预测期望,以估计该类别特定的学习状态:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

4.2 自适应公平

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

5 实验

5.1 设置

我们在常见的基准数据集上评估 FreeMatch:CIFAR-10/100 (Krizhevsky et al., 2009)、SVHN (Netzer et al., 2011)、STL-10 (Coates et al., 2011) 和 ImageNet (Deng et al., 2009)。遵循之前的工作 (Sohn et al., 2020; Xu et al., 2021; Zhang et al., 2021; Oliver et al., 2018),我们在不同数量的标记数据上进行实验。除了常用的标记数量外,参照 (Sohn et al., 2020),我们还包括 CIFAR-10 的最具挑战性的情况:每个类别仅有一个标记样本。

为了公平比较,我们使用统一的代码库 TorchSSL (Zhang et al., 2021) 训练和评估所有方法,使用相同的骨干网络和超参数。具体而言,我们对 CIFAR-10 使用 Wide ResNet-28-2 (Zagoruyko & Komodakis, 2016),对 CIFAR-100 使用 Wide ResNet-28-8,对 STL-10 使用 Wide ResNet-37-2 (Zhou et al., 2020),对 ImageNet 使用 ResNet-50 (He et al., 2016)。我们使用带有 0.9 动量的 SGD 作为优化器。初始学习率为 0.03,使用余弦学习率衰减计划,即 η = η₀ cos( 7πk / 16K),其中 η₀ 是初始学习率,k(K) 是当前(总)训练步数,我们为所有数据集设置 K = 220。在测试阶段,我们使用训练模型的动量为 0.999 的指数移动平均来对所有算法进行推理。标记数据的批量大小为 64,ImageNet 为 128。我们使用相同的权重衰减值、预定义阈值 τ、未标记批量比率 µ 和为伪标签 (Lee et al., 2013)、Π 模型 (Rasmus et al., 2015)、均值教师 (Tarvainen & Valpola, 2017)、VAT (Miyato et al., 2018)、MixMatch (Berthelot et al., 2019b)、ReMixMatch (Berthelot et al., 2019a)、UDA (Xie et al., 2020a)、FixMatch (Sohn et al., 2020) 和 FlexMatch (Zhang et al., 2021) 引入的损失权重。

我们基于 UDA 实现 MPL,如 (Pham et al., 2021) 所述,其中我们将温度设置为 0.8,wu 设置为 10。我们不对标记数据进行 MPL 的微调,因为我们发现微调会导致模型过拟合标记数据,特别是当标记数据非常少时。对于 Dash,我们使用 (Xu et al., 2021) 中的相同参数,唯一的不同是我们在标记数据上进行了 2 个时期的预热,因为过多的预热会导致过拟合(即 2,048 次训练迭代)。对于 FreeMatch,我们在所有实验中将 wu 设置为 1。此外,对于 CIFAR-10(10 个标签)、CIFAR-100(400 个标签)、STL-10(40 个标签)、ImageNet(100k 个标签)和 SVHN 的所有实验,我们将 wf 设置为 0.01。对于其他设置,我们使用 wf = 0.05。对于 SVHN,我们发现,在早期训练阶段使用较低的阈值会阻碍模型对未标记数据的聚类,因此我们为 SVHN 采用了两种训练技巧:(1)在仅使用标记数据的情况下对模型进行 2 个时期的预热,如 Dash 所述;(2)将 SAT 限制在范围 [0.9, 0.95] 内。详细的超参数介绍在附录 D 中。我们使用不同的随机种子训练每个算法 3 次,并报告所有检查点的最佳错误率 (Zhang et al., 2021)。
在这里插入图片描述
表1:CIFAR-10/100、SVHN和STL-10数据集的错误率。STL-10的完全监督结果不可用,因为我们没有其未标记数据的标签信息。粗体表示最佳结果,下划线表示次佳结果。每个数据集的显著性检验和平均错误率见附录E.1。

5.2定量结果

CIFAR-10/100、SVHN 和 STL-10 的 Top-1 分类错误率在表 1 中报告。每个类别 100 个标签的 ImageNet 结果在表 2 中。我们还在附录 E.3 中提供了精确度、召回率、F1 分数和混淆矩阵的详细结果。这些定量结果表明,FreeMatch 在 CIFAR-10、STL-10 和 ImageNet 数据集上实现了最佳性能,并且在 SVHN 上的结果与最佳竞争者非常接近。在 CIFAR-100 中,当标签数量为 400 时,FreeMatch 优于 ReMixMatch。ReMixMatch 在 CIFAR-100(2500)和 CIFAR-100(10000)上的良好表现可能得益于 mix up(Zhang et al., 2017)技术和自监督学习部分。在每个类别 100k 标签的 ImageNet 上,FreeMatch 的表现显著优于最新的对手 FlexMatch,提升了 1.28%。我们还注意到,FreeMatch 在 ImageNet 上的计算速度较快,如表 2 所示。需要注意的是,FlexMatch 的速度远慢于 FixMatch 和 FreeMatch,因为它需要维护一个记录每个样本是否干净的列表,这在大型数据集上需要大量的索引计算资源。

值得注意的是,在标签数据极其有限的情况下,FreeMatch 始终以较大幅度优于其他方法:在 CIFAR-10 上有 10 个标签时提高了 5.78%,在 CIFAR-100 上有 400 个标签时提高了 1.96%,而在 STL-10 上有 40 个标签时惊人地提高了 13.59%。与其他数据集相比,STL-10 是一个更现实且具有挑战性的数据集,包含 100k 张大量未标记的图像。这些显著的提升证明了 FreeMatch 的能力和潜力,可以在现实世界的应用中部署。
在这里插入图片描述
表2:ImageNet上的错误率和运行时间,每个类100个标签。

5.3定性分析

我们进行了一些定性分析:FreeMatch 为什么以及如何有效?它带来了哪些其他好处?我们评估 FreeMatch 在 STL-10(40)上的类别平均阈值和平均采样率(即在 STL-10 上的 40 个标记样本),以展示其如何与我们的理论分析相一致。在训练期间,我们记录阈值并计算每个批次的采样率。采样率的计算方式为:\frac{1}{\mu_B} \sum_b 1(\max(q_b) > \tau_t(\arg \max(q_b))) 我们还绘制了准确率的收敛速度和混淆矩阵,以显示 FreeMatch 中提议的组件如何帮助提高性能。从图 3(a) 和图 3(b) 中可以观察到,FreeMatch 的阈值和采样率变化与我们的理论分析大致一致。也就是说,在训练的早期阶段,FreeMatch 的阈值相对较低,与 FlexMatch 和 FixMatch 相比,导致更高的未标记数据利用率(采样率),加快了收敛速度。随着模型学习的改善和信心的增强,FreeMatch 的阈值升高到较高的值,以减轻确认偏差,导致稳定的高采样率。相应地,FreeMatch 的准确率大幅提高(如图 3© 所示),并且导致更好的类别准确率(如图 3(d) 所示)。值得注意的是,由于高采样率的应用,Dash 在 100k 次迭代之前未能正确学习。

为了进一步展示 FreeMatch 中类别特定阈值的有效性,我们在附录 E.8 的图 5 中展示了 FlexMatch 和 FreeMatch 在 STL-10(40)上的特征 t-SNE(Van der Maaten & Hinton, 2008)可视化。我们展示了每个类别的相应局部阈值。有趣的是,FlexMatch 对类别 0 和类别 6 采用了高阈值(即预定义的 0.95),但它们的特征方差非常大,与其他类别混淆。这意味着 FlexMatch 中的类别阈值无法准确反映学习状态。相反,FreeMatch 更好地聚类了大多数类别。此外,对于相似的类别 1、3、5 和 7,这些类别彼此混淆,FreeMatch 保留了一个高于 FlexMatch 的平均阈值(0.87 vs. 0.84),使得能够掩盖更多错误的伪标签。我们还在附录 E.9 中研究了伪标签的准确性,并显示 FreeMatch 可以减少训练过程中的噪声。

5.4消融研究

自适应阈值 我们对 FreeMatch 中 SAT 组件进行实验,并与 FlexMatch (Zhang et al., 2021)、FixMatch (Sohn et al., 2020)、类别平衡自训练 (CBST) (Zou et al., 2018) 和 AdaMatch (Berthelot et al., 2022) 中的相对阈值 (RT) 进行比较。消融实验在 CIFAR-10(40 个标签)上进行。如表 3 所示,SAT 在所有阈值方案中表现最佳。自适应全局阈值 τt 和局部阈值 MaxNorm(˜pt©) 的结果也与固定阈值 τ 可比,证明所提出的全局和局部阈值是良好的学习效果估计器。

在使用 CPLM(β©) 调整 τt 时,结果较固定阈值更差且具有更大的方差,表明 CPL 的潜在不稳定性。AdaMatch (Berthelot et al., 2022) 使用的 RT 可以视为在第 t 次迭代中基于标记数据的预测计算的全局阈值,而 FreeMatch 则对未标记数据进行 EMA 计算 τt,能够更好地反映整体数据分布。对于类别特定阈值,CBST (Zou et al., 2018) 维持预定义的采样率,这可能是其表现不佳的原因,因为正如我们在第 2 节中分析的,采样率在训练过程中应该发生变化。值得注意的是,我们在此消融实验中未包含 Lf,以确保公平比较。附录 E.4 和 E.5 中关于 FixMatch 和 FlexMatch 的不同阈值的消融研究表明,SAT 有助于减少超参数调优计算或在优化选择阈值的情况下缩短整体训练时间。
在这里插入图片描述
表3:不同阈值方案的比较。

自适应公平性 如表 4 所示,我们还对 CIFAR-10(10 个标签)中 SAF 的效果进行了实证研究。我们研究了公平性目标的原始版本,如 (Arazo et al., 2020) 所述。在此基础上,我们研究了通过直方图对概率进行归一化的操作,并表明对抗不平衡底层分布的影响确实有助于模型学习并实现更好的多样性。值得注意的是,仅仅添加原始公平性正则化就已经有助于提升性能。然而,在对数操作中添加归一化操作却会损害性能,这表明底层批量数据确实不是均匀分布的。我们还评估了 ReMixMatch (Berthelot et al., 2019a) 和 AdaMatch (Berthelot et al., 2022) 中的分布对齐 (DA) 以实现类别公平性,结果显示其效果不如 SAF。DA(AdaMatch)表现较差的一个可能原因是它仅使用标记批次的预测作为目标分布,而无法反映真实数据分布,尤其是在标记数据稀缺的情况下;将目标分布更改为真实均匀分布,即 DA (ReMixMatch),对于极少标记的情况更好。我们还证明了 SAF 可以轻松插入到 FlexMatch 中,并带来改进,详细内容见附录 E.6。EMA 衰减的消融和不平衡设置的性能分析见附录 E.5 和 E.7。
在这里插入图片描述
表4:不同类别公平性项目的比较。

6 相关工作

为了减少伪标记中的确认偏差(Arazo et al., 2020),提出了基于置信度的阈值技术以确保伪标签的质量(Xie et al., 2020a;Sohn et al., 2020;Zhang et al., 2021;Xu et al., 2021),其中仅保留置信度高于阈值的未标记数据。UDA (Xie et al., 2020a) 和 FixMatch (Sohn et al., 2020) 在训练过程中保持固定的预定义阈值。FlexMatch (Zhang et al., 2021) 根据通过有置信的未标记数据估计的每类学习状态以类特定的方式调整预定义阈值。一项同时进行的工作 Adsh (Guo & Li, 2022) 明确优化 SSL 目标中每个类别的伪标签数量,以获取适应性阈值以应对不平衡的半监督学习。然而,它仍然需要用户预定义的阈值。Dash (Xu et al., 2021) 根据标记数据的损失定义阈值,并根据固定机制调整该阈值。一项较新的工作,AdaMatch (Berthelot et al., 2022),旨在通过预定义的阈值乘以标记数据批次的平均置信度来统一 SSL 和领域自适应,以屏蔽噪声伪标签。它需要预定义的阈值,并忽视未标记数据分布,尤其是在标记数据稀少以至于无法反映未标记数据分布的情况下。此外,分布对齐 (Berthelot et al., 2019a;2022) 也在 Adamatch 中被利用,以鼓励对未标记数据的公平预测。由于忽视了模型学习状态与阈值之间的关系,以前的方法可能无法选择有意义的阈值。Chen et al. (2020) 和 Kumar et al. (2020) 尝试从理论角度理解自训练/阈值化。一个动机示例用于推导根据学习状态调整阈值的意义。

除了一致性正则化外,基于熵的正则化在半监督学习(SSL)中也被使用。熵最小化(Grandvalet et al., 2005)鼓励模型对所有样本做出自信的预测,而不考虑实际预测的类别。还提出了对所有样本的熵期望最大化(Andreas Krause, 2010;Arazo et al., 2020),以引导模型公平性,强制模型以相同的频率预测每个类别。但是,前者假设了底层数据分布是均匀的,并且忽略了批量数据分布。分布对齐(Berthelot et al., 2019a)根据标记数据分布和模型预测的指数移动平均(EMA)来调整伪标签。

7 结论

我们提出了 FreeMatch,该方法利用自适应阈值和类别公平性正则化来进行半监督学习(SSL)。在多种 SSL 基准测试中,FreeMatch 超越了强有力的竞争对手,特别是在几乎没有监督的设置下。我们相信,置信度阈值在 SSL 中具有更大的潜力。一个潜在的局限性是,这种自适应性仍然源于模型预测的启发式方法,我们希望 FreeMatch 的有效性能够激励更多关于最佳阈值的研究。

A “双月”数据集的实验细节

我们生成了仅有两个标记数据点(每个类别一个标签,分别用黑点和圆圈表示)和 1,000 个未标记数据点(以灰色表示),这些数据点位于二维空间中。我们训练了一个三层多层感知器(MLP),每层有 64 个神经元,使用 ReLU 激活函数,训练迭代次数为 2,000 次。红色样本表示那些置信值高于 FreeMatch 阈值但低于 FixMatch 阈值的不同样本。未标记数据的采样率通过以下公式计算:
在这里插入图片描述

附录B:定理2.1的证明

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

C 算法

我们展示了 FreeMatch 的伪代码算法。与 FixMatch 相比,FreeMatch 的每次训练步骤涉及从未标记数据批次更新全局阈值和局部阈值,并计算相应的直方图。FreeMatch 相比 FixMatch 引入了非常小的计算开销,这一点也在我们的主要论文中得到了证明。
在这里插入图片描述

附录D:超参数设置

为了便于重现实验,我们在表5和表6中分别展示了 FreeMatch 的详细超参数设置,包括依赖算法的超参数和不依赖算法的超参数。需要注意的是,对于 ImageNet 实验,我们使用了与其他实验相同的学习率、优化器方案和训练迭代次数,并采用了 128 的批次大小;而在 FixMatch 中,使用了 1024 的大批次大小和不同的优化器。从我们的实验中发现,仅用 220 次迭代训练 ImageNet 是不够的,模型在训练结束时才开始收敛。未来将探索在 ImageNet 上进行更长时间的训练迭代。我们在 CIFAR-10、CIFAR-100、SVHN 和 STL-10 的训练中使用单张 NVIDIA V100 显卡。在 CIFAR-10 和 SVHN 上训练大约需要 2 天,而在 CIFAR-100 和 STL-10 上训练则需要 10 天。
在这里插入图片描述

附录E:广泛的实验细节和结果

我们提供广泛的实验细节和结果,作为对主要论文中实验的补充。

E.1 显著性测试

我们使用 Friedman 测试进行了显著性检验。我们选择了在 4 个数据集上表现最好的 7 种算法(即 ( N = 4, k = 7 ))。然后,我们计算了 F 值 (\tau_F = 3.56),该值明显大于临界值 2.661 ((\alpha = 0.05)) 和 2.130 ((\alpha = 0.1))。这一测试表明,所有算法之间存在显著差异。

为了进一步展示我们的显著性,我们在表7中报告了每个数据集的平均错误率。可以看出,FreeMatch 在大多数 SSL 算法中显著优于其他算法。

E.2 CIFAR-10 (10) 标记数据

遵循 (Sohn et al., 2020) 的方法,我们通过为每个类别仅提供一个标记训练样本来研究 SSL 算法的局限性。所选的 3 个标记训练集在图4中可视化,这些训练集由 (Sohn et al., 2020) 使用排序机制 (Carlini et al., 2019) 获得。

E.3 详细结果

为了全面评估所有方法在分类设置中的性能,我们进一步报告了使用相同 10 个标签的 CIFAR-10、使用 400 个标签的 CIFAR-100、使用 40 个标签的 SVHN 以及使用 40 个标签的 STL-10 的精确率、召回率、F1 分数和 AUC (曲线下面积) 结果。如表8和表9所示,除了主要论文中报告的 top1 错误率外,FreeMatch 在精确率、召回率、F1 分数和 AUC 方面也表现最佳。
在这里插入图片描述

E.4 FixMatch 和 FlexMatch 上预定义阈值的消融实验

如表12所示,FixMatch 和 FlexMatch 的性能对预定义阈值 (\tau) 的变化非常敏感。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

E.5 CIFAR-10 (40) 上 EMA 衰减的消融实验

我们提供了关于公式 (5) 和公式 (6) 中 EMA 衰减参数 (\lambda) 的消融研究。可以观察到,不同的衰减 (\lambda) 在使用 40 个标签的 CIFAR-10 上产生接近的结果,表明 FreeMatch 对该超参数不敏感。较大的 (\lambda) 不被鼓励,因为它可能阻碍全局/局部阈值的更新。

E.6 FlexMatch 和 FreeMatch 上 SAF 的消融实验

在表13中,我们比较了在 CIFAR-10 上使用 10 个标签时不同类别公平性目标的表现。FreeMatch 在两种设置下均优于 FlexMatch。此外,当与 FlexMatch 结合时,SAF 也被证明是有效的。

E.7 不平衡 SSL 的消融实验

为了进一步证明 FreeMatch 的有效性,我们在不平衡 SSL 设置 (Kim et al., 2020; Wei et al., 2021; Lee et al., 2021; Fan et al., 2021) 中评估了 FreeMatch,其中标记数据和未标记数据均为不平衡。我们在 CIFAR-10-LT 和 CIFAR-100-LT 上进行了不同不平衡比例的实验。CIFAR 数据集上的不平衡比例定义为 (\gamma = N_{\text{max}}/N_{\text{min}}),其中 (N_{\text{max}}) 是头部(频繁)类别的样本数,(N_{\text{min}}) 是尾部(稀有)类别的样本数。需要注意的是,类别 (k) 的样本数计算为 (N_k = N_{\text{max}} \gamma^{-(k-1)/(C-1)}),其中 (C) 是类别数。遵循 (Lee et al., 2021; Fan et al., 2021),我们将 (N_{\text{max}}) 设置为 CIFAR-10 的 1500 和 CIFAR-100 的 150,每个类别的未标记数据数量是标记数据的两倍。我们使用 WRN-28-2 (Zagoruyko & Komodakis, 2016) 作为主干网络。我们使用 Adam (Kingma & Ba, 2014) 作为优化器。初始学习率为 0.002,采用余弦学习率衰减调度 (\eta = \eta_0 \cos(\frac{7\pi k}{16K})),其中 (\eta_0) 是初始学习率,(k(K)) 是当前(总)训练步数,我们对所有数据集设置 (K = 2.5 \times 10^5)。标记数据和未标记数据的批次大小分别为 64 和 128。权重衰减设置为 (4 \times 10^{-5})。每个实验在三个不同的数据分割上运行,我们报告最佳错误率的平均值。

结果总结在表14中。与其他标准 SSL 方法相比,FreeMatch 在所有设置中均取得了最佳性能。特别是在 CIFAR-10 上不平衡比例为 150 时,FreeMatch 比第二好的方法高出 2.4%。此外,当结合其他不平衡 SSL 方法 (Lee et al., 2021) 时,FreeMatch 在大多数设置中仍取得了最佳性能。

E.8 STL-10 (40) 上的 T-SNE 可视化

我们绘制了 FlexMatch (Zhang et al., 2021) 和 FreeMatch 在 STL-10 上使用 40 个标签的特征的 T-SNE 可视化。FreeMatch 展示出比 FlexMatch 更好的特征空间,集群混淆较少。

E.9 CIFAR-10 (10) 上的伪标签准确率

我们对三个随机种子计算了伪标签准确率的平均值,并在图6中报告了结果。这表明,像 FlexMatch 那样将阈值映射到高固定阈值可能会阻止未标记样本参与训练。在这种情况下,模型可能会过拟合于标记数据和少量未标记数据。因此,未标记数据的预测将包含更多噪声。在训练时引入适当的未标记数据可以避免过拟合于标记数据集和少量未标记数据,并带来更准确的伪标签。

E.10 CIFAR-10 (10) 混淆矩阵

我们在图7中绘制了 FreeMatch 和其他 SSL 方法在 CIFAR-10 (10) 上的混淆矩阵。值得注意的是,即使在我们的设置中使用最不具代表性的标记数据,FreeMatch 仍然取得了良好的结果,而其他 SSL 方法无法将未标记数据分离成不同的簇,显示出与 SSL 中低密度假设的不一致。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

全部评论 (0)

还没有任何评论哟~