最优传输论文(二十六):Sliced Wasserstein Discrepancy for Unsupervised Domain Adaptation论文原理
这篇论文提出了一种基于切片瓦瑟斯坦差异(SWD)的方法,用于无监督领域自适应任务。以下是论文的主要内容和贡献:
摘要
论文提出了一种新的无监督领域自适应方法,通过测量任务特定分类器之间的切片瓦瑟斯坦差异来对齐源域和目标域的特征分布。该方法利用了瓦瑟斯坦距离的几何意义,提出了切片瓦瑟斯坦差异(SWD),并将其集成到一个端到端可训练的框架中。实验结果表明,该方法在多个任务(如数字和符号识别、图像分类、语义分割和目标检测)上表现优异,优于现有方法。
引言
领域自适应是机器学习中的一个重要问题,特别是在跨领域任务中,如目标检测、图像分类等。然而,现有方法通常依赖于领域间对齐技术,如基于对抗网络的域内对齐,这些方法在实际应用中存在一些局限性。论文指出,任务特定的分类器输出之间的差异是衡量领域自适应的有效指标,并提出了基于切片瓦瑟斯坦差异的方法。
相关工作
MCD框架:通过最大化分类器输出之间的差异来对齐特征分布,但仅适用于分类任务。
JDOT方法:通过联合条件分布匹配,但需要多个鉴别器和复杂的优化过程。
SWD框架:通过切片瓦瑟斯坦差异直接衡量分类器输出之间的差异,无需中间对齐步骤,具有端到端可训练性。
方法
论文提出了一种基于切片瓦瑟斯坦差异的方法,具体步骤如下:
框架设置:定义无监督领域自适应的设置,包括源域和目标域的输入数据。
最优传输和瓦瑟斯坦距离:介绍瓦瑟斯坦距离及其在概率度量中的应用,特别是切片瓦瑟斯坦差异。
切片瓦瑟斯坦差异(SWD):通过将高维测度投影到多个方向,计算一维最优传输问题的变分形式。具体来说,使用多个随机投影θ将高维测度转化为一维,计算排序后的差异。
端到端训练:将SWD集成到特征生成器和分类器的优化过程中,通过最小化差异损失来对齐特征分布。
实验
实验在多个任务(如数字和符号识别、图像分类、语义分割和目标检测)上进行,使用了多个基准数据集(如SVHN、MNIST、USPS、GTSRB)。结果表明,所提出的方法在所有任务中均优于现有方法,特别是在数字和符号识别任务中,通过SWD自适应对从SVHN到MNIST的适应,生成了更多有区别的特征表示。
结论
论文提出了一种创新的方法,通过任务特定的分类器差异和切片瓦瑟斯坦差异,实现了高效的无监督领域自适应。该方法在多个任务中表现出色,未来的工作包括扩展到其他领域自适应任务,如开放集自适应和零触发自适应。
代码
论文提供了PyTorch版本的
文章目录
-
引言部分旨在阐述本研究的背景、意义及其研究目标。
-
本研究摘要主要阐述了所提出的方法及其在多个领域的应用效果。
-
介绍部分详细阐述了本文的研究背景、研究目标以及所采用的方法框架。
-
相关工作部分系统回顾了现有研究的成果及其与本文方法的异同点。
-
方法部分将详细阐述本文所提出的核心技术及其实现原理。
-
方法部分包含三个主要子部分,分别是方法框架设定、最优运输与Wasserstein距离,以及基于切片Wasserstein距离的学习方法。
-
方法框架设定部分介绍了本文所采用的系统架构设计。
-
最优运输与Wasserstein距离部分详细阐述了最优运输理论及其在度量空间中的应用。
-
基于切片Wasserstein距离的学习方法部分探讨了如何通过切片技术提升模型的泛化能力。
-
4. Experiments
-
- 4.1.数字和符号识别
-
结论
-
代码
-
前言
文章来自2019年的CVPR会议。
本文属于领域自适应与最优传输研究领域的第26篇系列文章。该系列文章的所有代码均已发布在**https://github.com/CtrlZ1/Domain-Adaptation-Algorithms**,希望各位读者能够给予关注与支持,点赞和星标不吝啬哦~
在阅读本文之前,建议读者先阅读以下两篇论文,以确保阅读体验更佳:
- Wasserstein barycenter及其在纹理混合中的应用
博文链接: - 基于无监督域适应的最大分类器不一致度
博文链接:
这些内容对应于本文的文献52和57。
摘要
在本研究工作中,我们将两个概念关联起来,用于领域间的无监督自适应:通过基于任务特定决策边界进行的领域间特征分布对齐[57]和瓦瑟斯坦度量[72]。我们提出切片瓦瑟斯坦差异(SWD),旨在捕捉任务特定分类器输出间的不相似性。该方法为检测与源域特征显著不同的目标样本提供了具有几何意义的指导,并通过端到端可训练的方式实现了分布对齐。在实验中,我们验证了该方法在数字和符号识别、图像分类、语义分割和目标检测方面的有效性与普适性。
1.介绍
深度卷积神经网络是现代机器感知系统发展的关键性技术,已在分类、语义分割、目标检测等核心任务中取得显著成果。然而,尽管深度学习模型在学习能力和泛化性能方面表现出色,但在领域转换问题上仍面临挑战——在合成与真实领域数据之间的转换关系难以有效建模。在这一领域,模型的泛化性能难以在不同领域间保持稳定,导致在目标领域上表现不佳。领域转换问题主要以以下三种形式存在:协变量转移(Covariate shift)、先验概率转移(prior probability shift)和概念转移(Concept shift)。
2.Related Work
无监督领域自适应方法的大量研究旨在通过统计矩匹配技术学习领域不变的特征表示,以缩小源领域与目标领域的差距。其中,一些方法利用最大平均差异(MMD) [38,39]来匹配深度神经网络中某些层的隐藏表示,而其他方法则采用中心矩差异(CMD)方法[75]以显式匹配高阶矩的每个阶和每个隐藏坐标。此外,自适应批处理标准化(AdaBN) [33]也被提出,以调整域间网络中所有批处理标准化层的统计数据。
另一种策略通过利用GANs的对抗性学习机制来解决领域适应问题[19]。该方法首先在特征级进行应用,其中训练域鉴别器以正确分类输入特征的域,并训练特征生成器以欺骗域鉴别器,从而使得到的特征分布达到域不变性[71,24,18]。随后,该技术被应用于像素级,以在原始输入空间中执行分布对齐,将源域转换为目标域的“样式”,并通过在转换后的源数据上训练的模型来实现目标[36,70,5,62,23,59,45]。最近的研究则通过假设输出空间包含用于特定任务(如语义分割)的相似空间结构,将该技术扩展到了输出级[69]。在[60,25]中,还提出了其他混合方法。
相比之下,Saito等人的工作则通过将任务特定的分类器作为鉴别器,直接对齐分布。他们的框架通过最大化两个分类器输出之间的差异,来检测源支持之外的目标样本,然后通过最小化差异来生成相对于决策边界在源支持之内的特征表示。这种方法与基于启发式的流形对齐不同,而是直接重塑目标数据区域的实际需求。
3. Method
在第3.1节中,我们首先介绍了无监督域自适应设置。然后,我们简要回顾了第3.2节中的最佳运输概念。最后,我们详细介绍了如何利用第3.3节中的切片瓦瑟斯坦差异来训练所提出的方法。
3.1. Framework Setup
给定从源集\{X_s,Y_s\}提取的输入数据X_s和相应的ground truth y_s,以及从目标集X_t提取的输入数据x_t,无监督领域自适应学习机制旨在建立从标记源集到无标记目标集的知识转移过程,如文献[47]所述。当两个数据分布X_s和X_t足够接近时,可以简化为仅最小化联合概率分布P(X_s,Y_s)的经验风险。然而,当这两种分布存在显著差异时,仅在源信息上优化模型会导致泛化能力的下降。
基于最大分类器差异(MCD)框架[57],我们训练特征生成器网络G以及分类器网络C_1和C_2,它们通过从G生成的特征响应,分别生成相应的条件概率p_1(y|x)和p_2(y|x)(如图1所示)。

图1展示了建议的切片瓦瑟斯坦差异(SWD)计算示例。SWD旨在通过量化特定任务分类器C_1和C_2之间概率测度p_1,p_2 ∈ R^d的不同性,这些概率测度是从特征生成器g中获取输入的。SWD通过使用单位球面S^{d-1}上的均匀测度进行径向投影,结合瓦瑟斯坦度量的变分公式,实现了端到端的直接训练。这种方法为检测那些偏离源支持域的目标样本,提供具有几何意义的指导。详情请参考第3.3节。
优化过程包含三个步骤:(1) 基于源域\{X_s,Y_s\},训练生成器G以及分类器集合(C_1,C_2),用于对齐或回归源域样本:

其中,L_s可以是任何关注的损失函数,例如交叉熵损失函数或均方误差损失函数。
(2) 冻结生成器G的参数并更新分类器(C_1、C_2)以最大化目标集合X_t上两个分类器的输出之间的差异,识别特定任务决策边界外的目标样本。

在步骤(3)中,固定两个分类器的参数,并通过优化生成器G,以最小化目标集合X_t上两个分类器输出之间的差异。其中,差异损失函数L_{DIS}(X_t)采用文献[57]中提出的L_1形式。同时,目标函数L_s(X_s,Y_s)也被纳入该优化过程,以保持源域信息。

这一步使目标特征流形更接近源。
3.2.最优运输和瓦瑟斯坦距离
在MCD框架中,域自适应的效果主要由差异损失的可靠性决定。若无差异损失的学习,实质上就是在训练过程中省略了步骤2和步骤3,仅在源域上进行监督学习。

其中,T^\#(\mu) = \nu代表从测度\mu到测度\nu的一种单射前推操作,作用于所有Borel子集族A \subset \Omega。测地线度量函数c: \Omega \times \Omega \rightarrow \mathbb{R}^+可以是线性的或二次的。值得注意的是,由于概率测度不满足分裂条件,解T^\#的存在性不能被保证,例如在将Dirac测度映射至非Dirac测度的情形下。
Kantorovitch在文献[27]中提出了方程4的一个较为松散版本,该方法寻求一个联合概率分布的运输计划,属于概率测度空间P(Ω×Ω),以满足特定的优化条件。

其中,

π_1和π_2表示Ω×Ω到Ω的两个边缘投影。解γ^*被称为最优运输计划或最优耦合[72]。
- 对于q ≥ 1,P(Ω)中µ与ν之间的q-Wasserstein距离定义为:

最小化总成本的目标基于最优运输计划的优化。我们的方法采用了1-Wasserstein距离,也称为地球动子距离(EMD),该指标衡量的是概率分布间的差异。
3.3. Learning with Sliced Wasserstein Discrepancy
在本研究中,我们建议采用1-瓦瑟斯坦距离应用于3.1节中描述的域自适应框架。在该框架中,我们基于几何的1-瓦瑟斯坦距离作为步骤2和步骤3的差异度量标准。具体而言,我们关注分类器输出概率p_1(y|x)和p_2(y|x)的离散形式。计算W_1(p_1,p_2)需要通过求解线性规划问题[27]来获得最佳传输耦合γ^*,这一过程在实际应用中并不高效。尽管已有多种优化方法[11:Sinkhorn distances: Lightspeed computation of optimal transport,16]被提出,但如何通过端到端可训练的方式有效地直接优化W_1(p_1,p_2)仍是一个未解之谜。

其中,R_θ对应于概率测度µ或ν上的一维线性投影操作 ,θ是R^d空间中单位球面S^{d-1}上的均匀分布测度 ,满足\int _{S^{d-1}}dθ= 1。通过这种方式,计算切片的瓦瑟斯坦差异相当于对应于几个一维最优运输问题,这些问题具有明确的解析解[52]。
具体而言,假设α和β是排列n个样本的n个一维线性投影的排列,满足∀0≤i< N-1,有R_θ µ_{α(i)}≤ R_θµ_{α(i+1)}以及R_θ v_{β(i)}≤ R_θv_{β(i+1)},那么最小化这种一维瓦瑟斯坦距离的最佳耦合γ^*通过排序算法将 R_θ µ_{α(i)}分配给R_θν_{β(i)}。对于离散概率度量,我们的SWD的具体计算公式可以表示为:

基于随机采样的θ和c的二次损失,除非另有说明。SWD本质上是原始Wasserstein距离的一种变分形式,其计算开销相对较小[4]。值得注意的是,SWD以其紧致形式具有可微性,从而,我们可以将最优传输理论可靠地应用于特征生成器和分类器的优化。具体方法在算法1中概述,图1则展示了SWD的计算流程。
- 算法1:

θ是定义在R^d单位球面S^{d-1}上的均匀测度,满足\int _{S^{d-1}}dθ= 1,该测度将高维空间中的测度转换为一维表示,其本质上相当于一个归一化的系数,同时与测度p进行线性组合。该测度的模长为1,确保了其在测度空间中的标准化特性。为了构建完整的测度系统,我们通常选取M=128个这样的θ。以分类数为10为例,这样可以得到一个10维的高维测度表示。根据输入的batch数量,输入的p尺寸为[batchsize,10],这与M个10x1的θ矩阵进行乘法运算,得到一个形状为[batchsize,128]的矩阵。该矩阵的每一行对应一个样本与所有θ的点积结果。为了评估预测结果的准确性,我们采用二次损失函数,并对两个测度所生成的[batchsize,128]矩阵进行MSE(均方误差)计算。
4. Experiments
基于我们的方法,它适用于任何领域自适应任务,并且无需在输入或输出空间中引入相似性假设。
4.1.数字和符号识别
在本实验研究中,我们采用了五个标准化基准数据集进行性能评估,包括街景门牌号(SVHN) [46]、MNIST [31]、USPS[26]、合成交通标志(SYNSIG) [42]以及德国交通标志识别基准(GTSRB) [65]。针对每对域偏移,我们采用了斋藤等人[57]提出的精确CNN架构设计。在所有实验任务中,我们采用了batch大小为128的Adam优化器 [28]。为了训练网络模型,我们应用了梯度反转层(GRL) [17],因此无需调整生成器与分类器之间的更新频率。我们的研究方法的独特之处在于关注径向投影M的数量这一超参数。通过实验研究,我们详细分析了不同M值对模型性能的敏感性,具体结果展示在图2(a)和图2(b)中。

图2展示了不同径向投影数量对分类精度的影响,其中(a) SVHN对MNIST具有良好的适应性,(b) SYNSIG对GTSRB同样表现出良好的适应性。实验结果表明,当设置参数M=128时,能够稳定实现优化过程并获得较高的分类精度。T-SNE方法[40]通过仅利用源域特征和自适应距离度量(©)实现了从SVHN到MNIST特征的可视化效果。图中,蓝色和红色点分别标识源样本和目标样本。与仅依赖源域特征的设置相比,本方法能够生成更多具有显著区分度的特征表示。
- SVHN → MNIST :我们首先检查从谷歌街景图像获得的真实世界的门牌号[46]到手写数字[31]的适应性。这两个域展示了不同的分布,因为来自SVHN数据集的图像包含来自街道的杂乱背景和图像边界附近的裁剪数字。我们使用标准训练集作为训练样本,测试集作为源域和目标域的测试样本。特征生成器包含三个5×5 conv图层,在前两个conv图层之后放置两个3×3最大池。对于分类器,我们使用3层全连通网络 。
- SYNSIG → GTSRB :在这个设置中,我们评估了从合成图像SYNSIG到真实图像GTSRB的适应能力。我们随机选择了31367个样本进行目标训练,并对其余样本的准确性进行了评估。特征生成器包含三个5×5 conv图层,在每个conv图层后放置两个2×2最大池。对于分类器,我们使用2层全连通网络 。对两个域之间的43个公共类进行了性能评估。
- MNIST ↔ USPS :对于双向 域转换实验,我们还遵循[57]提供的协议,即我们使用标准训练集作为训练样本,测试集作为源域和目标域的测试样本。特征生成器包含两个5×5 conv图层,在每个conv图层之后放置步长为2的2×2最大池。对于分类器,我们使用3层全连通网络 。
- 结果 :表1列出了通过四个不同的域偏移获得的目标样本的准确度。我们观察到,我们的社署方法在所有情况下都优于竞争方法。所提出的方法也大大优于直接可比方法MCD [57],四种设置的绝对精度平均提高了2.8%。图2(a)和2(b)显示了消融研究对径向投影数M的敏感性。在我们的实验中,我们根据经验发现M = 128 在所有情况下都很有效。我们还在图2©和2(d)中可视化了所学的特性。与仅源设置相比,我们的方法生成了更多有区别的特征表示。
表1:

表1展示了跨数字和交通标志数据集在无监督领域自适应任务中的结果。为了确保结果的可靠性,我们重复每个实验5次,并计算了准确度的平均值和标准偏差。与直接可比方法MCD [57]以及其他方法相比,我们的方法在性能上具有显著优势。
- 有趣的是,任务特定的差异感知方法,如MCD [57]、DeepJDOT [12]和建议的SWD,是当前处理此处任务的主要方法。这证明了利用任务特定的决策边界(差异)来指导迁移学习过程的重要性,而不是在大多数其他分布匹配方法中简单地匹配像素、特征或输出空间中源域和目标域之间的分布 。特别是,基于对抗训练的方法需要一个单独的生成器和多个鉴别器,它们通常比主任务网络本身更大。例如[23]中的方法使用10层生成器、6层图像级鉴别器和3层特征级鉴别器,而主任务网络是4层网络。此外,辅助鉴别器在训练完成后被丢弃。
- 此外,所提出的方法和DeepJDOT [12]方法之间的主要区别在于DeepJDOT需要一个多阶段的训练过程——它训练一个CNN并迭代地解决一个线性规划任务。DeepJDOT还假设,当将伪标签从源域传播到目标域时,小批量中每对样本之间的真正最佳传输耦合会收敛,而实际情况往往并非如此。这强调了选择几何上有意义的差异度量的重要性,该度量不假设标签空间中的最佳传输耦合,并且是端到端可训练的,优化一个差异损失,而不是独立解决多个损失。
- 我们注意到[63]中的方法通过各种工程努力,例如使用实例归一化、添加高斯噪声和利用更深的18层网络,在SVHN到MNIST自适应任务上获得了99.4%的结果。这种基于聚类假设的方法实现了与我们相同的性能,我们将这种架构搜索留给未来探索。
- 后面作者又在多个领域进行了实验,这里就不继续说了。
结论
- 在本文中,我们开发了一种新的无监督域自适应方法,该方法通过测量任务特定分类器之间的切片瓦瑟斯坦差异 来对齐分布。与瓦瑟斯坦度量的联系为以有效的方式更好地利用其几何上有意义的嵌入铺平了道路,在过去,这主要局限于在标签空间中获得一对一的映射。我们的方法是通用的,并且在数字和符号识别、图像分类、语义分割和对象检测任务上取得了优异的结果。未来的工作包括扩展我们的域随机化方法[67],开放集自适应[58],以及零触发域自适应[48]。
- 所谓Sliced Wasserstein distance,其实就是利用投影 θ将特征投影到一维空间,θ是R^d上的单位球面S^{d-1}上的均匀测度,即\int _{S^{d-1}}dθ= 1,它负责将高维的测度转化为一维,也就是一个系数,与测度p线性组合,另外其模长为1。同时我们需要M个θ,一般M取128个 。举个栗子,那我们以分类数为10 为例,那么得到的高维测度就是10维 的,根据batch的数量,输入的p尺寸是[batchsize,10],对应一个M个10x1的θ,由于M一般取128,那么M个θ就是[10,128],且在列上是归一化的。然后θ与p相乘后得到[batchsize,128]的矩阵,某一行的每个维度都是一个θ的输出结果。在一维情况下Wasserstein问题有闭式解,对投影之后进行排序,然后按照公式计算即可。我们根据代码来详细理解一下:
s = p1.size(1)
if s > 1:
# For data more than one-dimensional, perform multiple random projection to 1-D
theta = torch.rand(10, 128)
theta = (theta / torch.sum(theta,dim=0)).to(device)
p1 = torch.matmul(p1,theta)
p2 = torch.matmul(p2,theta)
p1 = sort_rows(p1)
p2 = sort_rows(p2)
wdist = (p1-p2)**2
代码
- 官方发布:https://github.com/apple/ml-cvpr2019-swd
- PyTorch版本:我的GitHub地址是https://github.com/CtrlZ1/Domain-Adaptation-Algorithms
- 官方采用的是moons数据集,而我采用的是usps数据集,经过usps→mnist的数据转换,目前尚未进行参数调优,当前实验结果尚不理想,但整体方法框架已经较为完善。
