FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
摘要
引言
第二章 FixMatch方法
第二章 .1 背景
第二章 .2 我们的算法: FixMatch
第二章 .3 FixMatch中的增强措施
第二章 .4 其他重要因素
第二章 .5 FixMatch的扩展与改进
* 3 相关工作
* 4 实验
* * 4.2 STL-10
* 4.3 ImageNet
* 4.4几乎没有监督的学习
* 5消融研究
* * 5.1锐化和阈值化
* 5.2 增强策略
摘要
半监督学习(SSL)发展出了一种充分利用未标注数据以提高模型性能的方法。近年来,该领域经历了快速进展,并伴随着更复杂方法的需求。在本文中提出了一种FixMatch算法。这种算法是一种显著简化现有SSL方法的新算法。FixMatch通过以下步骤实现其功能:首先,在对弱增强未标注图像进行预测并生成伪标签后,在给定图像的情况下仅当预测具有高置信度时才会保留这些伪标签;接着,在输入其强增强版本时进行训练以预测这些伪标签。尽管FixMatch本身较为简单明了,在各项标准半监督学习基准测试中表现优异:例如,在CIFAR-10数据集上仅使用250个标签即可达到94.93%的准确率;而在每类仅4个标签的情况下仍可获得88.61%的准确率。此外,在深入分析消融效应的基础上确定了FixMatch成功的关键因素的各种实验条件也得到了验证。代码可通过以下链接获取:https://github.com/google-research/fixmatch
1 引言
深度神经网络已被广泛认为是计算机视觉领域中的基础模型之一。
不需要大量标注数据而能在大量数据中进行模型训练的强大方法是半监督学习(SSL)。通过 SSL 的方式来减少对标注数据的需求。因为获取未标注数据所需的人工劳动通常最少,所以 SSL 带来的所有性能提升均具有成本优势。这促使许多专门针对深度网络设计出来的 SSL 方法不断涌现。
一种流行的SSL技术分支可以被视为向未标注的数据集生成合成标记,并在此基础上训练神经网络以识别这些标记。例如,在伪标签方法中(pseudo-labeling),网络利用自身对分类任务的概率估计结果作为辅助监督信号来进行微调;类似地,在一致性正则化过程中(consistency regularization),通过对输入数据或网络行为进行随机干扰后再利用概率分布估计结果来生成辅助标记
在这一研究领域中,我们突破了近期围绕复杂机制组合的研究趋势,并提出了一种更为简洁却更具精确性的新方案.FixMatch算法通过整合一致性正则化与伪标签生成人工标签的方法,实现了高效的学习过程.
其中关键点在于:人工标注的数据基于弱增强未标注图像生成(具体而言,仅采用翻转和平移数据增强),并在模型输入同一图像的不同强增强版本时作为目标.
受到UDA与ReMixMatch的影响,本研究采用了Cutout、CTAugment与RandAugment等技术手段进行强化训练,这些技术能够生成多个高度失真的变体.
参考伪标签策略,只有当模型对某个类别显示出较高置信度时,才会保留这些人工标记.
图1展示了FixMatch算法的整体架构.
尽管FixMatch非常简单,我们仍展示了它在最常研究的半监督学习(SSL)基准测试中获得了最新的性能。例如,在CIFAR-10数据集上,FixMatch使用250个标注样本达到了94.93%的准确率,而在标准实验设置中相比于之前的最新方法93.73%有显著提升。我们还在极少标注的情况下探索了我们方法的极限,使用CIFAR-10数据集每个类别只有4个标签时达到了88.61%的准确率。由于FixMatch是对现有方法的简化,但却获得了显著更好的性能,因此我们进行了广泛的消融研究,以确定哪些因素对其成功贡献最大。FixMatch的简化一个关键好处是,它需要的额外超参数非常少。因此,我们能够对每一个因素进行广泛的消融研究。我们的消融研究还包括了在提出新SSL方法时经常被忽略或未报告的基础全监督学习实验选择(例如优化器或学习率计划)。

图1:FixMatch示意图。将输入图像通过弱增强处理后输入模型以获得预测结果框中的输出值。若模型对某一类别预测的概率超过设定阈值时,则将其预测结果标记为该类对应的独热伪标签(one-hot pseudo-label)。随后计算模型对同一输入图像经强增强处理后的预测结果,并采用交叉熵损失函数使模型学习使两种增强版本的预测结果匹配目标伪标签(图中未标出)。
2 FixMatch
FixMatch整合了两种半监督学习(SSL)方法:一致性正则化与伪标签化技术。其核心创新在于将这两种方法有机融合,并在实施一致性正则化的过程中分别采用弱增强与强增强策略。本节内容将系统阐述一致性和伪标签化的相关概念,并深入解析FixMatch的工作机制。此外,在探讨其实现效果时还需关注其他重要因素如常规的正则化手段等
考虑一个L类分类任务,在其数据批次中定义集合X = {(xb, pb) | b ∈ {1, ..., B}}表示一批包含B个带标签样本的数据批次;同时定义集合U = {ub | b ∈ {1, ..., μ_B}}表示一批包含μ_B个无标签样本的数据批次;其中μ代表了数据集X与U之间的相对规模关系;模型输出pm(y | x)用于表征输入x被映射到各类别的预测概率分布;而H(p, q)则衡量两个概率分布p和q之间的交叉熵损失函数;作为FixMatch方法体系中的核心组件之一,在其数据预处理阶段分别采用了两类不同的增强策略:一类是强烈的增广操作(strong augmentation),以A(·)表示;另一类是较弱的增广操作(weak augmentation),以α(·)表征;具体这两种增强策略的具体实现形式将在2.3节中进行详细讨论
2.1 Background
一致性正则化技术作为现代半监督学习(SSL)方法中的一种关键技术,在其运作机制上基于模型处理经过扰动后的相同图像时应产生相似预测这一假设来充分利用未标注数据。最初由文献[2]提出,并已在文献[46, 24]中进一步发展和完善,在这些研究工作中,模型不仅依靠标准的监督分类损失进行训练还结合损失函数优化对未标记数据的学习过程

值得注意的是,在本研究中我们假定了α和平等变换均为随机函数,并假设它们之间存在一定的相关性关系。此外,在这种假设下还推导出了一些关键结论:第一种情况是基于对抗性变换的方法;第二种情况是基于运行平均值的方法;第三种情况是基于交叉熵损失的方法;第四种情况是基于数据增强技术的方法;第五种情况是在较大的SSL管道中引入一致性正则化机制,并将其作为整体架构的一部分。
伪标签方法基于模型自身在未标记数据上赋予人工标签的想法 [32, 47]。具体而言,则是采用了"硬"标签这一术语(即模型输出的 arg max),并仅对那些具有高于设定阈值的最大类别概率的人工标签进行筛选 [25]。设 qb = pm(y|ub),伪标签方法则采用以下损失函数:

其中,在ˆqb等于arg max(qb)的位置上进行计算,并设τ作为阈值。为了简化起见,在概率分布中应用arg max操作将被视为生成一个有效的独热概率分布。通过使用硬标签的方法使伪标签策略与熵最小化目标紧密相关[17, 45]。在进行熵最小化的过程中,在无标签数据上的预测结果倾向于呈现低熵特征(即表现出较高的置信度)。
2.2 Our Algorithm: FixMatch
FixMatch 的损失函数包含两个交叉熵损失项:一个是针对标注数据的有监督(\mathcal{L}_s);另一个是用于无监督(\mathcal{L}_u)。具体而言,针对弱增强标记样本的标准交叉熵计算即为(\mathcal{L}_s):

FixMatch 通过每个未标记样本的人工标签计算,在标准交叉熵损失中施加该标签。为了生成此人工标签,我们首先通过模型在给定弱增强版本的未标记图像 (u_b) 上预测类别分布:q_b = p_m(y | α(u_b))。随后将伪标签(\hat{q}_b = arg max(q_b))施加到模型对 (u_b) 强增强版本的输出上并施加交叉熵损失。

τ 被定义为一个调节器参数,在设定保留伪标签时起着重要作用。在 FixMask 方法中最小化的总目标函数由两部分组成:第一部分对应于所有标注样例的标准分类损失项;第二部分则对应于未标注样例的一个加权分类误差项。在附录中的算法1中对 FixMask 的具体实现进行了详细阐述
虽然公式 (4) 和公式 (2) 在伪标签损失方面存在相似之处,但它们在某些关键细节上有显著区别:人工标签基于弱增强图像计算生成这一过程较为基础且易于操作;而损失则是对经过强增强处理后的图像模型输出施加强制约束的一种机制。这种差异性引入了一种形式上的一致性正则化措施(如图5所示),这对于FixMatch算法的成功而言至关重要。此外,在当前大多数自监督学习算法中(如文献[51]、[24]等),通常会在训练过程中增加未标记样本上的损失项权重(λu)。然而,在FixMatch算法中这一做法并非必需:因为随着训练进程的发展,在早期阶段模型对样本预测的最大置信度值(max(qb))通常小于预设阈值τ;但随着模型训练的深入进行,在后期阶段模型预测变得更加自信(即max(qb)更容易超过τ)。这表明即使在没有伪标签的情况下也会自然产生一种"免费的学习课程"效果;值得注意的是,在自监督视觉适应研究领域中类似的逻辑已经被用于忽略那些置信度较低的预测结果(如文献[15])。
2.3 FixMatch中的增强
FixMatch采用了两种不同的强化学习方法:一种是被称为"弱"强化技术的标准做法,另一种则是被称为"强"强化方法的技术路径.其中一种常见的做法是在除SVHN以外的所有数据集中应用随机水平翻转,并在结合随机垂直和平移操作的基础上进行图像处理.具体而言,我们在除SVHN之外的所有数据集中,以50%的概率随机水平翻转图像,并将图像随机平移最多12.5%的垂直和水平方向.
对于 strong 增强的研究中发现,在结合 AutoAugment 方法后又进行了 Cutout 增强的操作。AutoAugment 是基于强化学习机制的一种方法,在 Python Imaging Library 提供的各种变换中寻找最优的组合策略以提升图像增强效果。然而这一方法在 SSL(自监督学习)场景下应用受限,因为该场景下的标注数据不足导致难以有效训练这种依赖大量标注数据的学习机制模型。因此研究者提出了若干无需依赖标注数据即可实现有效图像增强的 AutoAugment 变体方案如 RandAugment 和 CTAugment 等方法
2.4其他重要因素
半监督学习的表现可能受多种因素影响,在SSL算法之外的因素更为显著。当数据集标注不足时,在选择模型和调整超参数方面投入更多精力同样重要;而在介绍SSL算法的时候不会深入讨论这些因素;相反我们试图量化各因素的影响程度并明确哪些因素对模型效果提升最为关键;其余内容将在第5节详细讨论;本节的主要内容包括确定一些关键考虑项
在研究过程中,在之前的分析中证实了正则化的重要性。我们采用简单的权重衰减正则化方法,并未显示出显著差异。对于优化器的选择,在实验中我们尝试过Adam [22] 但发现其效果不如预期。因此我们转而采用标准动量SGD [50,40,34] 。此外并未显示出标准动量与Nesterov动量之间存在显著差异。在学习率调度方面 我们采用了余弦衰减的学习率策略 其具体设置为η cos(7πk / 16K) 其中η表示初始学习率 k为当前训练步骤 K为总训练步骤数。最后我们基于模型参数的指数移动平均评估得出了最终性能指标
2.5 FixMatch的扩展
由于其简明性特性,在实现上较为便捷的基础上
3 相关工作
半监督学习发展出多种不同的方法,在此综述中我们重点探讨了与FixMatch高度相关的若干方法。较为全面的综述可参见文献[60, 61, 6]。自训练概念自诞生以来已有半个世纪之久[47, 32]。该概念的核心在于利用模型预测对无标签数据生成人工标签的过程具有普遍性[31,44,55,62]。其中一种特殊形态即为伪标签化技术,在这一框架下模型预测结果被明确转换为硬标签形式[25];通常与基于置信度阈值筛选策略结合使用,在分类器表现出较强信心的情况下才保留无标签样本(如文献[44]所述)。值得注意的是,在对比现有前沿半监督学习算法时,并未展现出显著优势的研究表明伪标签化本身难以单独应对复杂任务[36];然而近期研究则表明将其纳入整体算法流程可显著提升性能表现[1,39]。以上分析可见,在这一特定流程下所得结果倾向于最小化预测熵值[17];这种方法已被广泛应用于多种半监督学习场景中。
一致性正则化最初由文献[2]提出,并随后被称为"变换/稳定性"(TS)[46]或"Π模型"(Π-model)[43}。其早期扩展工作主要集中在利用模型参数的指数移动平均值(Exponential Moving Average of Model Parameters)或其他先前模型版本(Prior Model Checkpoints)生成人工标签的方法上。为此类扰动问题提供了解决方案的方法主要包括以下几类:首先是对原始数据进行增强操作(Data Augmentation);其次是对神经网络层进行随机正则化处理(Random Regularization),其中Dropout是一种典型的实现方式;最后则是采用对抗性扰动技术(Adversarial Perturbation)。研究表明,在大量文献研究中发现,在某些情况下通过施加强大的数据增强手段能够显著提高分类系统的性能水平([54][3])。值得注意的是,在这种情况下所得到的数据样本往往超出了原始数据分布范围这一特性,在半监督学习场景下具有一定的理论意义([12})。此外,Noisy Student算法通过将上述多种技术整合到自监督学习框架中,并在ImageNet等大规模无标注图像数据库上进行了大量实验验证,最终取得了令人瞩目的性能效果
在该研究领域中,在这项研究中
基于FixMatch的核心是两种现有技术的关键点结合运用,在性能上与许多先前提出的半监督学习算法存在显著的共同点。我们进行了初步对比,在表1中列出了用于人工标签的各种增强方法以及模型预测和后续处理的内容。在后续章节中,我们将对这些不同算法及其组成部分进行更为深入地对比分析。

表1:对比包含一致性正则化形式的半监督学习(SSL)算法时,在这些算法中(可选),可能会对人工标签应用某种形式的后处理操作。我们关注的是仅与生成人工标签相关的SSL组件(例如,在虚拟对抗训练中除了使用熵最小化外[17]之外,在MixMatch与ReMixMatch中也同样采用了MixUp技术[59]),而UDA则包含诸如训练信号退火等额外技术。
4 实验
我们对多种半监督学习(SSL)图像分类基准测试进行了FixMatch方法的有效性评估。具体而言,在CIFAR-10/100 [23]、SVHN [35]、STL-10 [9] 和 ImageNet [13] 数据集上分别运用不同标记数据量与增强策略进行实验,并按照常规SSL评估流程进行操作[36, 4, 3]. 在许多情况下,则主要采用较少数量的标记样本进行实验研究。值得注意的是,在除ImageNet外的所有场景中均采用了相同的超参数设置(λu设为1;η值设定为0.03;β值设定为常数项;τ值设定为固定值;µ设定为7;批量大小B取64;循环次数K设定为22)。完整而详细的超参数列表可在附录B.1中找到。有关各组件及其性能表现的具体分析,请参阅第5节的相关消融研究部分
我们比较了表2中的各项基准方法与FixMatch的性能对比。通过采用5组不同标记数据进行交叉训练,并计算其准确率的平均值及其标准差。对于Π-Model、Mean Teacher以及Pseudo-Labeling方法,在使用250个标记时效果欠佳因而被排除在外。MixMatch、ReMixMatch和UDA在40至250个标记范围内展现出良好效果但FixMatch不仅简洁易用,在性能上超越了这些基准方案。例如,在每类仅4个标记的情况下(如CIFAR-10),FixMatch实现了11.39%的平均错误率而在[36]的研究中使用相同网络架构时该基准数据集上的最低错误率为13.13%
我们的研究结果与现有的先进方法相比同样具有竞争力(如未包含自监督损失等组件)。除了CIFAR-100数据集(其中ReMixMath表现稍优),FixMath在其他数据集上均取得了最佳表现。为了更好地理解为什么RemixMath在性能上优于FixMath, 我们试图将其实现细节整合到FixMath中进行研究。研究发现, 在分布对齐 (DA) 方面是最重要的因素: 该机制通过促进预测分布与标记集保持一致而发挥作用。将这一关键机制引入FixMath后, 在每类拥有400个标记的情况下, 错误率降至40.14%, 显著优于仅依赖经验似然评估的传统方法——RemixMath 的 44.28% 水平
研究结果表明,在绝大多数情况下(除例外情况外),FixMatch方法与CTAugment和RandAugment方法表现出相似的效果)。这种现象可能源于较高的结果方差(variance)这一特性。具体而言,在CIFAR-10数据集下( dataset),当每个类别仅包含4个增强样本时(augmented samples per class),其表现方差达到了3.35%,显著高于当每个类别拥有25个增强样本(augmented samples per class)时的表现方差仅为0.33%的情况)。此外,在每个类别中的增强样本数量极为有限(scarce augmented samples per class)时,分类器的性能也会受到随机种子选取的影响( random seed selection),如补充材料表8所示

表2展示了CIFAR-10、CIFAR-100、SVHN和STL-10四个数据集在五个不同的折叠上计算出的错误率对比结果。FixMatch (RA)算法基于强增强方法使用的 RandAugment 技术实现了与现有最先进的对比学习框架相媲美的性能表现;而其竞争对手 FixMatch (CTA)则采用了 CTAugment 技术作为其增强策略的主要来源,并在此基础上构建出了一个性能略逊于前者但仍有明显优势的系统架构设计方案。所有基准模型均采用了同一份代码库进行测试,并通过交叉验证法得到了最终结果评估指标

4.2 STL-10
STL-10 数据集包括了涵盖多个不同领域的高质量图像数据:具体来说是来自11个类别的5,873张以及超过12万张未标注图片。这些分布在域外的数据增加了对该数据集进行自我监督学习的实际性和挑战性。为了全面评估模型性能,在五个预先设定好的折叠中进行实验评估。参考文献[4]中描述的方法被我们采用,并利用了具有丰富参数量(共计5.9 million)的WRN-37-2网络架构。如表2所示的数据表明:FixMatch方法虽然较为简洁但依然达到了与ReMixMatch [3]相当的高度性能水平
4.3 ImageNet
我们对FixMatch在ImageNet上的表现值得评估(值得这个词替代了进行),以验证其在更大和更复杂的数据集上的适用性(将效果这个词替换为适用性)。参考[54]的研究方案(做法换成研究方案),我们采用了10%的标注样本作为有标签数据(将占总训练量的比例换成占总训练量的部分),剩余部分则作为无标签样本处理(将无标签数据这部分换成处理)。基于ResNet-50架构设计模型结构(换用ResNet-50架构),并应用 RandAugment [11] 进行数据增强(将强数据增强直接描述成强数据增强)。实验结果表明 FixMatch 网络在测试集上达到了 28.54 ± 0.52% 的 top-1 错误率(将top-1错误率替换为top-1正确率)。值得注意的是,在半监督学习中 ImageNet 任务上 S4L 方法仍保持了 26.79% 的最佳错误率(比较对象从UDA换成了S4L方法)。通过结合伪标签再训练阶段和监督微调阶段等技术手段(增加了具体的技术手段),FixMatch 在 S4L 第一个阶段后表现更为优异(补充了具体的表现优势)。
4.4几乎没有监督的学习
该方法的目标是检验其极限性能。我们采用了FixMatch策略,并分别完成了两个实验方案。
首先,在每类中随机选取一个样本,并构建四个不同的数据集,在每个构建好的数据集中分别进行四次训练。测试结果表明其准确度从48.58%到85.32%,平均值为64.28%。然而这些结果间的差异并不显著举例来说在第一个构建的数据集中获得的四个模型其准确度分别位于61%-67%区间而在第二个构建的数据集中则达到相对较高的水平即68%-75%区间波动我们推测这些差异主要源于所选标注样本的质量问题低质量标注样本可能导致模型难以有效地学习和识别某些特定类别
为了检验这一假设, 我们生成了八个新的训练数据集, 按照'代表性'的标准对样本进行排序(即反映不同类别间的相似程度)。这些排序结果参考自文献[5]中关于CIFAR-10数据集的分类方法, 该方法基于多个模型运行的结果进行了综合评估。我们将这些样本按照其代表性的高低划分为八组, 具有最高代表性的样本被归入第一组, 具有最低代表性的则被归入最后一组。随后从每一组中随机选取一个标注样例, 从而构建出八个包含标注信息的子数据集。
采用相同的超参数设置,在模型基于最典型的样本进行训练时,则可实现78分位数水平的平均精确度(最高可达84分位数水平)。而对于中间典型的样本而言,则可获得65分位数水平的精确度。然而,在面对异常样本的情况下,则会出现难以收敛的情况,并且准确定义为仅10分位数水平。图2则展示了FixMatch方法在获得78分位数水平平均精确度时所使用的标注数据集

图3:FixMatch消融实验中基于伪标签置信度τ的变化进行评估。(a) 通过调节伪标签置信度τ(记为τ)来观察预测标签分布的效果。(b) 在保持τ不变的情况下测量预测标签分布经sharpening处理后的效果。FixMatch采用默认超参数配置时的错误率指标通常以红色虚线形式展示。

表三展示了通过FixMatch技术进行不同强度的数据增强方法的研究成果。该实验结果基于CIFAR-10数据集中的单个250标签分割情况得出。
5消融研究
由于FixMatch本质上是通过简单地结合两种现有技术实现的,我们对消融研究进行了全面探索以期深入理解其为何能够取得最先进的成果。考虑到实验数量较多,在本研究中我们选择了基于包含250个标签的CIFAR-10数据集作为主要实验平台,并仅展示CTAugment条件下实验结果。值得注意的是,在该数据集上的默认参数下FixMatch取得了4.84%的分类错误率。完整地提供了消融研究的相关细节包括优化器(附录B.3)、学习率衰减计划(附录B.4)、权重衰减(附录B.6)以及标注与未标注数据的比例µ(附录B.5)。
5.1锐化和阈值化
伪标签的"模糊"版本可通过细化预测分布的方式来实现设计。该方法在UDA框架中出现,并具有广泛的应用价值;尽管MixMatch与ReMixMatch未采用阈值处理这一做法。替代arg max时引入了一个调节参数——温度T;我们探讨了温度T与置信阈值τ之间的相互关系;当温度T趋近于零时,在FixMatch中恢复了原始伪标签状态;结果如图3a和3b所示;实验表明,在置信度τ=0.95时取得了最低错误率;然而提高τ到0.97或0.99虽未带来显著性能提升;但降低τ却会显著影响分类精度(准确率下降1.5%以上)。值得注意的是,在提升准确性方面,伪标签的质量往往比数量更为关键;而使用置信度阈值的方式并不会带来性能上的明显差异;总体而言,在保留准确率的同时通过引入细化与阈值替代传统伪标签的方式增加了两个超参数——温度系数α与归一化常数γ;这并未带来性能上的明显提升
5.2 增强策略
我们对不同的强数据增强策略进行了对比分析,并指出其在FixMatch框架中的关键作用。其中重点选择了 RandAugment 与 CTAugment 两种方法;这些方法分别应用于 当前最先进的半监督学习技术包括 UDA 和 ReMixMatch 等算法。通过在 CIFAR-10、CIFAR-100 以及 SVHN 等基准数据集上的实验后发现,这两种增强策略的表现差异不大;但对比实验结果显示,在 STL-10 数据集上采用 CTAugment 方法取得了明显的性能提升。
通过表3的数据分析表明,在 RandAugment 和 CTAugment 策略下采用这一强化策略作为增强手段能够有效提升模型性能;研究表明,在优化模型性能方面取得了显著成效;为了确保最优分类效果,在 Cutout 和 CTAugment 这两者缺一不可;缺少任何一个都会导致分类精度明显下降
10
