论文笔记- Learning from Synthetic Data: Addressing Domain Shift for Semantic Segmentation
论文信息
* 标题: Learning from Synthetic Data: Addressing Domain Shift for Semantic Segmentation
* 作者:Swami Sankaranarayanan, Yogesh Balaji, Arpit Jain, Ser Nam Lim, Rama Chellappa
* 机构:University of Maryland, College Park, MD; GE Global Research, Niskayuna, NY; Avitas Systems, GE Venture, Boston MA
* 出处:CVPR 2018
代码链接
* https://goo.gl/3Jsu2s
论文主要贡献
* 提出使用生成模型将源和目标的分布在特征空间中进行对齐,主要将 DCNN 获取的中间特征表示投影到图像空间中,训练使用重建模型(结合 L1 损失和对抗损失)
* 本文的域适应的对齐过程主要通过使源域特征生成目标域图像,结合对抗损失进行训练,或者是相反方向。随着训练推进,生成图像质量逐渐提高,生成的特征也逐渐趋于域无关的特征
论文要点翻译
摘要
* 视觉领域的域适应问题有极其重要的地位,之前的方法表明甚至深度神经网络也难以学习 domain shift 带来的信息表示问题,这个问题在一些手工标注数据复杂度高、代价大的场景中尤为突出
* 本文聚焦在适应分割网络学习的合成数据和真实数据对应的特征表示上,和之前方法不同,之前方法使用简单的对抗学习目标或者超像素信息指导域适应过程,本文提出了基于生成对抗网络的方法,使得学习的不同域之间的特征表示尽可能相近
* 为了验证方法的泛化能力和扩展能力,本文在两个场景的合成到真实的域适应场景下进行测试,并添加额外的探索性实验验证了本文提出的方法能够较好地泛化到未知数据域,并且方法可以使得源和目标的分布得以对齐
引言
* 深度神经网络带来新的计算机视觉革命,在许多诸如图像分类、语义分割、视觉问答等场景中获得较大的性能提升,这样的性能提升主要归功于丰富的标注训练数据带来的模型能力的提升,对于图像分类这样的任务而言,标注数据的获取相对简单,但是对于其他任务而言,标注数据可能是费时费力的,语义分割任务就是这样的任务,由于需要很多人类工作才能获取每个像素对应的语义标签,标注像素级的语义标签是非常困难的,而获取数据本身就不简单,户外的自动驾驶等自然场景图像容易获取标签,但是医学图像风本身数据难以采集,而且标注数据也需要更大的代价
* 一个有望解决这些问题的方法就是使用合成数据进行训练,然而,合成数据训练的模型往往在真实数据中性能较差,这主要是因为合成场景数据和真实场景数据之间存在 domain gap 的问题,域适应技术就是用来解决域之间的 domain gap 的问题的技术,因此,本文的主要目标是研究适用于语义分割的域适应算法,具体来说,本文主要关注目标域标签数据不可达的情况,也就是无监督域适应 UDA 问题
* 传统的域适应方法主要是将源和目标数据分布进行量化最小化,典型的量化方法是最大均值差异 MMD 和使用 DCNN 学习的距离度量,两类方法在图像分类领域已经取得成功,但是语义分割问题中的域适应还没有得到很好地解决
* 本文提出的工作使用对抗框架进行域的对齐,最近的解决该问题的技术手段主要包括 FCN,该方法使用对抗框架,不像之前的方法判别器直接在特征空间进行操作,本文的方法将特征投影到图像空间(利用生成器),使用对抗 loss 在投影的图像空间中进行判别,以此改进了性能
相关工作
* 域适应方法
方法
X∈RM×N×CX \in \mathbb{R}^{M\times N \times C} 表示任意输入的 C 通道图像,Y∈RM×NY \in \mathbb{R}^{M \times N} 表示该图像对应的语义标签图,给定输入 X 的情况下,将 CNN 输出表示为 Y^∈RM×N×Nc\hat Y \in \mathbb{R}^{M\times N \times N_c},其中的 NcN_c 表示类别数量,Y^(i,j)∈RNc\hat Y(i,j) \in \mathbb{R}^{N_c} 是表示类别对应概率分布的向量,源和目标分别表示为 Xs,XtX^s,X_t
网络描述:训练时对几个网络组件进行迭代优化
* 基本网络,预训练的 VGG-16 之类的模型,分成特征编码部分 F 和像素级分类部分 C,C 的输出是一个上采样到与输入同大小的标签图
* 生成器网络 G,输入学习的特征表示,重建 RGB 图像
* 判别器网络 D 两个不同任务:(1)将输入判断为真实或者假的输入,该过程对两个域是一致的(2)像素级的标签任务,与网络 C 类似,由于训练的目标数据无标签,因此该过程主要从源数据进行
源数据和目标数据的处理
* 给定输入的源数据图像和对应的标签 Xs,YsX^s,Y^s,网络首先通过网络 F 提取得到特征表示。分类器将 F 输出的特征表示作为输入,输出图像大小的标签图 Y^s=C(F(Xs))\hat Y^s=C(F(X^s))
* 生成器 G 在给定图像特征表示的前提下重建 RGB 图像 X^s=G(F(Xs))\hat X^s=G(F(X^s)),网络中使用 dropout 技术代替原有的将输入和噪声组合的方式;判别器 D 主要对真实的源数据与生成重建的源数据进行真假判断,并根据源数据生成像素级的标签图
* 给定目标数据输入 XtX^t,生成器 G 输入 F 得到的特征表示,以此重建目标图像,与之前的过程类似,使用 D 对重建图像的真假进行判断,此阶段 D 只有这一个作用
迭代优化
使用的 within-domain 和 cross-domain 损失如表所示:
| 类型 | 损失 | 描述 |
|---|---|---|
| Within | Ladv,Ds\mathcal L_{adv,D}^s | (源)判真为真;判假为假 |
| Ladv,Gs\mathcal L_{adv,G}^s | (源)判假为真 | |
| Ladv,Dt\mathcal L_{adv,D}^t | (目标)判真为真;判假为假 | |
| Ladv,Gt\mathcal L_{adv,G}^t | (目标)判假为真 | |
| Cross | Ladv,Ft\mathcal L_{adv,F}^t | 假的源输入判为真目标 |
| Ladv,Ft\mathcal L_{adv,F}^t | 假的目标输入判为真源 |
D 更新:LD=Ladv,Ds+Ladv,Dt+Lauxs\mathcal L_D=\mathcal L_{adv,D}^s+\mathcal L_{adv,D}^t+\mathcal L_{aux}^s,Lauxs\mathcal L_{aux}^s 表示辅助的分类 Loss
G 更新:LG=Ladv,Gs+Ladv,Gt+Lrecs+Lrect\mathcal L_G=\mathcal L_{adv,G}^s+\mathcal L_{adv,G}^t+\mathcal L_{rec}^s+\mathcal L_{rec}^t
F 更新:LF=Lseg+αLauxs+β(Ladv,Fs+Ladv,Ft)\mathcal L_F=\mathcal L_{seg}+\alpha\mathcal L_{aux}^s+\beta(\mathcal L_{adv,F}^s+\mathcal L_{adv,F}^t)
