最优传输论文(三):Joint distribution optimal transportation for domain adaptation论文原理
摘要
本文探讨了基于最优传输理论的无监督领域自适应(Domain Adaptation)方法。通过分析现有技术的假设和限制条件,本文提出了一种新的框架JDOT(Joint Distribution Optimal Transport),该框架利用联合分布的最优运输策略来最小化边缘分布和条件分布之间的差异。具体而言:
理论基础:
- Optimal Transport (OT):用于衡量两个概率分布之间的距离。
- Wasserstein距离:无需假设支撑集相同的前提。
- 边缘分布与条件分布对齐:通过优化目标函数实现源域和目标域样本及标签的对齐。
方法细节:- 损失函数设计:结合分类损失与OT代价矩阵。
- 优化问题:引入超参数α平衡分类误差与OT约束。
- 算法实现:采用块坐标下降法交替优化函数f和转移矩阵r。
实验验证:- 通过数值实验展示了JDOT在分类任务中的有效性。
- 图表说明了源域与目标域样本及标签对齐的情况。
结论与展望:- JDOT提供了一个有效的框架解决领域自适应问题。
- 未来研究可以探索更复杂的模型结构或扩展到多源或多目标领域自适应任务。
该研究为理解并解决复杂的学习任务提供了新的视角和技术支持。
目录
-
前言
-
原理阐述
-
- 摘要
- 介绍
- Joint distribution Optimal Transport
-
- Optimal transport in domain adaptation
- Joint distribution optimal transport loss
-
该问题中目标误差的界限
-
基于联合分布的最优传输学习
-
其他相关研究
-
代码实现
前言
- 本文属于我撰写的关于最优传输的专题系列中的第三篇。该专栏致力于分享本人硕士阶段关于最优传输论文的研究思路和复现过程。
- 本专栏文章的核心内容是对研究思路的详细阐述。
- 论文具体的翻译及复现代码可在文章的GitHub存储库中找到。
原理阐述
摘要
本文探讨了无监督跨域适应问题,在未带任何标注数据的情形下基于已标注源域的知识推断出一个预测函数f。研究者在此基础上提出了以下假设:两个领域在联合特征/标记空间中的分布间存在一种非线性转换关系,并且这种转换可借助最优传输进行建模。为此他们运用最优传输理论构建了一种解决方案方案的核心在于同步优化最佳耦合关系与预测函数f(X)以恢复目标分布 P^f_t = \{(X, f(X))\}。研究结果表明该方法能够达到最小的目标误差界限并提供了有效的算法实现同时确保算法收敛。
介绍
作者指出当下几种较为有效的解决方案间的显著差异均源于对数据分布变化的假设前提上。具体而言,在covariate shift这一假设下,我们假定条件概率P(Y|X)保持一致不变;然而边缘概率密度函数P(x)却会发生明显变化。
在后续部分中,请详细阐述这些方法的理论基础及应对策略。


但是作者认为这些要求虽然看似独立但实际上是共享一个支撑集的基础上提出的。其中最优传输通过定义推进算子T使得源分布与目标分布之间满足关系式P_s(X) = P_t(T (X)) 并且通过这一机制实现分布间的全局运输努力或成本最小化。这种相关的散度被称为Wasserstein距离 它具有自然形式的拉格朗日公式从而无需借助核估计方法来处理连续分布 因此它显著降低了对共享支撑集的需求与KL散度、JS散度不同 Wasserstein距离甚至在两个支撑集差异较大时也能有效衡量两分布之间的距离
本研究中提出了一种新的无监督域适应框架。
Joint distribution Optimal Transport
符号说明:设Ω∈R^d为一个d维紧致输入可测空间,在实际应用中我们发现很多文献中提到的紧致空间(Compact space)或紧集(Compact set)在非数学专业领域通常可以被简单的有界闭集(Bounded and closed set)所替代。这一简化假设对于实际应用已经足够满足需求。其中C代表标签集合,在分类问题中通常取值于实数范围内的有限类别集合。而P(Ω)则表示定义在该输入空间上的所有概率测度集合。具体而言,在本研究中我们关注的是源域数据(X_s, Y_s)与目标域数据(X_t, Y_t)之间的联合分布关系。为了分别刻画这两个领域的特征分布情况,则采用边缘概率分布的概念来描述各自变量的信息。具体来说,在源域中我们分别用P_s(X,Y)和P_t(X,Y)来表示联合概率分布,在目标域上则分别用µ_s(X)和µ_t(X)来表示各自变量的概率分布情况。
Optimal transport in domain adaptation
这一部分主要借鉴了前人研究的结果,在现有的理论基础上进行了拓展与补充。具体而言,在蒙日与Kantorovich问题的研究框架下,并结合了熵正则化的思想以及分类别处理的方法,在之前的博文中进行了探讨与分析。
Joint distribution optimal transport loss
- 作者认为核心在于同步配准边缘分布与条件分布(因为联合概率可分解为边缘概率与条件概率的乘积形式即P(x,y)=P(y|x)P(x))。
- 前期采用OT方法结合熵正则化与类别正则化确定最优运输策略继而通过源域插值获得新的源域样本这些样本在空间维度上与目标样本高度契合随后在此基础上构建并训练分类器最终实现跨域推理。
- 作者另辟蹊径避免先配准后分类而是将分类器直接嵌入到损失函数C中实现同步进行特征配准与类别配准这一创新性思路使得模型能够更高效地完成跨模态匹配任务。
- 其新的目标如下:

其中新的代价D:

通过分析这个代价函数可以看出,该代价函数由两部分组成:第一部分旨在对齐边缘概率分布,第二部分则致力于对齐条件概率分布。初次见到这个式子时会产生一定的困惑,不过随着深入理解,其核心含义变得清晰:其中d是求解最优运输矩阵过程中所涉及的一个关键变量,它反映了n_source_samples与n_target_samples之间元素间的传输成本。这些标签之间的差异实际上构成了衡量源域样本与目标域样本间传输成本的重要指标,具体而言,当标签越接近时,对应的损失L值会显著减小,从而降低传输成本(当然前提是目标域预测结果足够准确)。为了实现这一目标,通常的做法是在使用源域数据进行预训练后进行微调训练效果会更好,因为随机初始化一个分类器直接上场可能会导致预测结果不够理想)。此外还涉及到两个超参数α、λ它们分别用于调节两个不同损失项之间的相对重要性而c则用于量化样本在不同域之间的分布差异程度)。

用于衡量模型性能时,在分类任务中通常会采用分类损失函数来评估预测结果与真实标签之间的差异程度。其中一种方法是使用 hinge 损失函数或者交叉熵损失函数作为衡量标准。可以看出,在这里作者引入了标签相关的损失项,并结合了分布对齐的要求来优化模型性能。从这一观点出发,在源域样本与目标域样本之间距离较近的情况下(即 d(x_1,x_2) 值较小),对应的预测标签之间的相似度 L(y_1,y_2) 也会较小从而有助于减少整体误差。
- 然后就是关于这个式子最小值存在的证明了。作者提到,可以证明,只要D()是下半连续的(见[18],定理4.1),则式子(3)的极小值总是存在并且是唯一的,这是当D()是范数 并且对于每个通常的损失函数[19]时的情况。
- 另外由于y^t_j是未知的,所以作者寻找一个函数映射来作为标签,即:

下面是源域和目标域的经验联合分布:

针对Wasserstein距离的优化问题进行转换时发现,在缺乏对应的目标域标注信息的情况下, 我们采用了一种基于映射生成的人工辅助标注数据来替代真实标签. 这种情况下, 优化目标函数被重新定义为:f(x)

其中W_1是一阶推土机距离:

该研究者对其合理性的理论证明具体展示了通过f替代目标域标签的合理性,并通过图1提供具体案例来支持这一观点。

该方法被应用于回归任务领域。通过观察左图中的y边缘概率分布发现其与另一侧具有较高的相似度。相比之下x边缘概率分布之间的差异显著。左二图展示了源域与目标域联合概率分布之间的显著差异而右二图则表明了所提出的运输方案r与实际值之间具有高度的一致性。通过观察右图我们发现所提出的JDOT方法生成的目标域联合概率分布预测结果与真实值高度吻合。
- 关于 α 的选择: 研究者建议设定 α = 1/\max_{i,j} d(x^s_i, x^t_j) 。在数值试验中显示这种设定在三个测试案例中表现优异,在其中两个案例中展现出特别好的性能。值得注意的是,在某些特定场景下(如当数据分布存在特殊特性时),采用交叉验证方法可能进一步提高该参数的性能。
A Bound on the Target Error
- 定义目标域和源域损失:

err_S(f)同样地。假设损失函数L满足:被限制的、对称且满足k-Lipschitz条件以及符合三角不等式。(We assume the loss function L to be bounded, symmetric, satisfies a k-Lipschitz condition and meets the triangle inequality.)
- 首先是PTL定义:


从直觉上讲,在考虑由两个测度 μ_s 和 μ_T 定义的一个耦合 Π 的情况下, 我们关注的是在 (1/λ) 球内找到与 Π 相关联的不同标记的源-目标实例对发生的概率是多少。
- 上面PTL是为下面的定理3.1服务的:

附录中有相关的证明,见下:
首先回顾一下k-lipschtiz条件:

其他一些简单公式:

正题:

证明:

如何理解第十一行的内容呢?其实P_t^f就是前面提到在目标域标签无法直接获得的情况下,并通过函数f(x)来表示这一概率密度关系。由于我们并不知道函数f的具体形式,并且根据文章开头所述的假设条件,在这种情况下才能从(10)式推导出(11)式的结果。进一步可以看出,在原理论文中考虑的是所有可能的映射函数f的情况;但在当前研究中,并没有直接使用原来的映射函数f而是引入了一个满足特定条件(PTL)的目标域误差表示方法f^*。这种做法是为了在后续推导中能够得到更加有效的优化公式,并且这种基于PTL条件的方法在后续章节中将发挥重要作用,请读者注意后续的具体应用过程。

看起来像(16、17)和(18、19)的是相同的。此外需要注意的是,在Kantorovitch-Rubinstein定理的对偶形式中表现为公式(13)。该定理表明,在任何情况下Π ∈ Π(P_s, P^f_t)时成立。

不等式对于任何耦合都成立。然后我们继续往下看后面的证明,

(18)式的转移至(20)式是基于k-lipschtiz条件这一原理;而(22)式则非常有趣。作者指出前面所述的情况:即f^*与耦合Π^*是以概率1-Φ(λ)满足PTL条件的;换句话说,则存在以概率Φ(λ)不满足PTL条件的情形。由于f^*是H上的一个Lipschitz标签函数,并且满足\left|f^*(x_1)-f^*(x_2)\right|\leq M这一性质;因此引入了一项kMΦ(λ)来表示不满足PTL的情况;而在满足时,则对应于概率为1-Φ(λ)的情形。

即概率测度P(|f(x_1)-f(x_2)| <= λd(x_1, x_2)) >= 1 - Φ(λ)。
首先我们有积分不等式\int k|f^*(x_t) - f^*(x_s)| <= kMΦ(λ) + (1 - Φ(λ))\int λ d(x_t, x_s) dΠ。
其次通过进一步分析可以得出\int λ d(x_t, x_s) dΠ。
再看(24)式,则是基于距离d的对称性质,并设定参数α = kλ;同时遵循L空间中的三角形不等式规则。
接下来我们将继续应用三角形不等式的相关结论进行推导。

(28)式子则是由定理D.1得到的,这里不展开了,有兴趣可以参照原文。
- 看过证明我们就能理解这段话了:

前两项直接关联于目标函数(5),我们主张最小化相关联的采样界限。最后一项φ(λ)用于衡量概率利普希茨性质不成立的可能性。与文献[23,24]中的现有研究相似, 涉及f*的其他两个术语直接关联于联合误差, 表明仅在两个领域均表现出良好预测能力时, 域自适应才能发挥作用。若剩余各项数值足够微小,则通过有效地对齐Ps和Pft^f_t使自适应成为可能的前提条件是最低化联合误差的目标函数具备φ-Lipschitz可转移性。
Learning with Joint Distribution OT
- 本文假设f所属的函数空间H是RKHS,或者是被一些参数w∈R^p参数化的函数空间。这个框架包括线性模型、神经网络和核方法。因此,我们将在f上定义一个正则化项Ω(f)。根据H的定义方式,Ω(f)要么是由RKHS诱导的平方范数的非递减函数(以便表示定理(representer theorem ) 是适用的),要么是向量参数上的平方范数。我们进一步假设Ω(f)是连续可微的。如前所述,f是根据下面的优化问题学习的:

其中损失函数L是连续的并且相对于它的第二个变量是可微的。
-
然后是优化过程 :根据上述关于f和L的假设,问题(6)是光滑的(特指无穷阶可导的函数) ,约束可根据f和γ分离 。因此,解决问题(6)的一个自然方法是依靠交替优化参数γ和f。该算法被称为块坐标下降(Block Coordinate Descent,BCD)或高斯-塞德尔方法(Gauss-Seidel method,该算法的伪代码在论文附录中给出)。下面将详细讨论块优化步骤。
-
用固定f求解归结为一个经典的OT问题(确实,固定了f的话,求解该问题就是一个最优传输问题了),其中损失矩阵C,这样C_{ij}= αd(x^s_i,x^t_j) + L(y^s_i,f(x^t_j))。我们可以使用经典的OT求解器,如网络单纯形算法,但也可以考虑其他策略,如正则化OT [25](即使用对偶和Sinkhorn)或随机版本[26](著名的论文:A. Genevay, M. Cuturi, G. Peyré, and F. Bach. Stochastic optimization for large-scale optimal transport. In
NIPS, pages 3432–3440, 2016.)。 -
具有固定γ的优化问题导致新的学习问题,表示为

下面作者说:

主要意思是f()的参数设置必须基于r", 而其规模较大带来了较大的计算负担。进一步地,在H是RKHS的情况下, 通过核技巧和表示定理, 问题(7)得以被重新表述为一个优化问题, 其中共有N_t个参数均属于实数域R,并且这些N_t个参数实际上指的是f()在每个测试样本上的表现。
当前已固定变量r的取值范围,请确定f的具体数值,则问题等价于求解Estimating f for least square regression problems.

即损失函数被称为平方损失函数。在随后的过程中, 作者对每个源域的标签进行了加权操作。特别关注的是y'_j

所以才有公式(8)的样子。
- 作者还考虑了损失函数是铰链损失函数的情况:Estimating f for hinge loss classification problems.


作者的主要致力于利用one-to-all方案来估计多类分类器。
在这一方案中,
我们主要采用二进制矩阵P来构建相应的判别模型,
以实现对多类别数据的有效分类。

优化问题可以表示为:

其中

其中 P^{\ast} 实际上就是衡量源域样本对应哪一类别的一种指标,在这种情况下每一行都采用了 one-hot 编码的方式进行表示:仅当对应的类别标记为 1 时才为 1(即赋值为 1),其余情况则为 0(即赋值为 0)。进一步说明的是:矩阵 r^{\ast} 的维度则是 N_s \times N_t, 而其转置后的矩阵 \hat{P}^{\ast} 维度则变为 N_t \times k, 其具体含义即是基于转移矩阵 r^{\ast}, 源域样本来自各分类的情况来进行推断。
总体而言, 作者采用了坐标下降法的基本思路, 即固定某一变量, 进而优化另一个变量. 比如说, 在本研究中, 可以选择固定r参数, 进而优化f函数; 或者选择固定f函数, 进而优化r参数. 具体而言, 前者可以通过最小二乘回归方法和 hinge loss 函数实现; 后者则等价于纯 OT (最优运输) 问题, 可借助 OT 问题的专用求解算法.
此外, 作者所构建的联合概率框架具有重要意义. 同样不可或缺的是后续的数学证明部分. 因为从理论层面讲, 这两项内容共同构成了支撑论文创新性研究的基础.
其他
这些内容主要涉及数学优化和证明方面的知识,在此我的阐述可能不够清晰。如有兴趣进一步了解,请参考原文。
代码
- JDOT内容其实是比较简单的,其实就是在利用各种方法计算OT矩阵的时候,考虑的损失不单单是源域样本和目标域样本的损失(比如l2损失),同时考虑使用某个分类器对样本进行分类之后的分类损失。所以说是同时对齐输入数据的边缘概率分布和条件概率分布。
- 然后在官方代码里,作者生成了一些数据来进行域适应。在分类任务中,用SVM作为分类器,在损失中加入这个分类器的分类损失。整体的流程就是计算C,然后用C计算OT矩阵,然后应用OT得到预测标,重复迭代即可。
- 回归任务是先确定了源域和目标域的两个回归函数,然后添加噪声,并根据这两个函数生成源域和目标域数据,看下图:

蓝线分别代表源域的真实函数、绿线则代表目标域的真实函数;红线则是通过数据xvisu进行预测而得到的结果。
- 官方代码地址:https://github.com/rflamary/JDOT
- 我的代码:https://github.com/CtrlZ1/Domain-Adaptation-Algorithms/tree/main/pytorch-JDOT
如果可以,请您动一下勤劳的小手,在github上给个Star哦~
主要是在官方代码的基础上添加了部分注释,以及将原有的keras框架换做了pytorch,另外单独做了生成数据然后分类的工作,分类器不是官方的SVM,而是神经网络,然后加入了Label propagation的内容,注意不能使用神经网络分类的模型来进行分类,具体原因是分类模型只是用来在训练过程中衡量每个源域样本分别与每个目标域样本的差异的,而不是用来做预测的,目标预测应该使用label propagation方法,即通过转移矩阵G来预测标签。而SVM分类器可以,因为它分类的依据就是label propagation,见下面的代码:

Yst实际上是基于Label propagation技术实现的。其预测能力源自模型本身,在代码实现中将详细阐述其内在机理,请关注后续更新。如需进一步了解相关内容,请访问我的另一篇博客:
