24-cvpr-Extracting Graph from Transformer for Scene Graph Generation 学习笔记
Formulation
One-stage Object Detector
我们的模型框架主要借鉴了DETR [1](一种单阶段物体检测器)的设计理念。其基础架构主要源自该架构师所提出的创新性解决方案。在这一过程中,图像特征表示主要是通过结合CNN提取器与Transformer编码器进行深度学习完成的。值得注意的是,在解码阶段,其解码过程借助自注意力机制与跨注意力机制协同工作以提升检测精度。此外,在最终的特征提取环节中,检测头不仅能够识别物体类别信息而且还能准确估计目标边界区域
该编码器通过多头自注意力机制来提升图像表示的效果;经过一个1×1卷积层将输入通道数从C缩减至d_{model};随后将特征图展平为长度为H_F×W_F的序列作为Transformer编码器的输入序列;其中位置编码被引入以捕获特征图的空间信息
Decoder
该解码器接收由N个独立的对象查询构成的输入序列,并输出相应的表征。每个独立的对象被单独检测,在这种情况下参数量常被设为足够大以确保能捕捉图像中的所有目标特征。解码器依次通过自注意力机制和交叉注意力机制处理信息,在此过程中未对自注意力层施加因果掩码操作从而允许各注意头之间进行无约束的信息传递。值得注意的是,在多头自注意机制的学习过程中产生了完整的N \times N维注意力权重矩阵如上所示:
对于每个层级l和每个注意力头h而言,
我们通过计算得到:
A^{(l)}_h = f(Q^{(l)}_h, K^{(l)}_h) = \text{softmax}\left(\frac{Q^{(l)}_h (K^{(l)}_h)^\top}{\sqrt{d_{\text{head}}}}\right)
其中,
A^{(l)}_h \in \mathbb{R}^{N \times N}
表示第l层第h个注意力头对应的权重矩阵。
而Q^{(l)}_h \in \mathbb{R}^{N \times d_{\text{head}}} 和 K^{(l)}_h \in \mathbb{R}^{N \times d_{\text{head}}} 分别代表该注意力头下的查询向量与键向量。
EGTR

我们研发了一种高效轻量型的关系提取器EGTR,并基于DETR架构设计的自注意力机制进行操作(如图3所示)。由于公式(1)中的自注意力权重计算了所有N×N个物体查询之间的相互关联,在整个LL层中我们通过将注意力查询视为主语并将其与键视为宾语的方式解析出谓词信息。
为了保持自注意力层中所学到的重要信息内容, 我们引入了另一个关系函数ff来进行关注与键之间的关联处理, 而不依赖于公式(1)中的标准点积注意力机制. 在处理过程中, 通过连接操作能够完整地保持表示. 我们将图3(a)所示的所有N×N个注意力查询与键的表示进行整合, 从而获得第l层的关系表示Ral∈RN×N×2dmodelR^l_a ∈ \mathbb{R}^{N × N × 2d_{model}}. 在配对连接之前, 在处理过程中加入了一个线性投影步骤以区分查询与键的角色作用.

其中,在数学空间Q^l ∈ \mathbb{R}^{N × d_{model}}和K^l ∈ \mathbb{R}^{N × d_{model}}中分别代表权重参数。
我们还采用了物体查询中最顶层的表示形式ZL∈RN×dmodelZ^L \in \mathbb{R}^{N \times d_{model}}这一特性,在物体检测任务中进行了应用:这些表示均按照一致的方式被应用于物体检测任务中

其中,在形状为 R^{d_{model} \times d_{model}} 的线性权重矩阵中定义了 WSW_S 和 WOW_O 两个参数块。基于此,在各层之间能够更加高效地共享特征表示的同时也能够更好地捕捉到不同位置之间的关联关系

其中,galg^l_a 和 gz∈RN×N×1g^z \in \mathbb{R}^{N \times N \times 1} 表示通过一个单线性层 WG∈R2dmodel×1W_G \in \mathbb{R}^{2d_{model} \times 1} 分别为 RalR^l_a 和 RzR^z 获得的门控值。最后,我们通过所有层关系表示的门控求和提取关系图,如下所示:

其中,在本研究中,
\hat{G} \in \mathbb{R}^{N \times N \times |C_p|}
表示预测的关系图,
多层感知机(MLP)_{rel}
是一个具有ReLU激活函数的三层网络。
值得注意的是,
我们使用sigmoid函数σ
来处理多个关系可以存在于对象之间的这一情况。
Learning and Inference
通过多任务学习的方法来训练EGTR系统

其中,由公式(2)定义的物体检测损失项为Lod\mathcal{L}{\text{od}}。另外两个关键指标分别为关系提取相关的损失函数Lrel\mathcal{L}{\text{rel}} 和连通性预测相关的损失函数 Lcon\mathcal{L}_{\text{con}}。值得注意的是,在使用显式定义的物体检测损失项后,EGTR能够有效识别场景图中的各个节点。具体各任务对应的损失计算细节见下文详细说明。
Relation Extraction

为了进行关系抽取任务, 我们采用了二元交叉熵损失函数. 该方法旨在通过优化损失函数来提升模型对实体间关系的学习能力. 在这一过程中, 我们首先将真实三元组集合E转换为编码形式后得到真实关系图G∈RN×N×∣Cp∣, 其中未满足真实关系条件的关系区域被置零处理. 接着, 在对象检测机制中获取的位置索引用于对预测图中的相关部分进行重新排列. 最终, 通过对比置换后的预测关系图\hat{G}'与重新构造的真实关系图\hat{G}, 计算出基于二元交叉熵损失的关系抽取损失Lrel=Lr(G^′,G). 然而, 由于预测关系图随着实体数量平方增长而变得非常稀疏. 当实体数目NN设定为200时, 如Visual Genome [12]所示的数据集验证集密度仅为10−1410^{-14}. 因此, 我们将原始的关系矩阵分解为三个独立的部分: (1) 真实的关系区域GT; (2) 负样本区域BG; 和 (3) 不匹配区域OM. GT表示在原始矩阵中值为1的位置; BG由主语和宾语均为实际存在的实体但不存在明确关联的所有三元组构成; OM则对应于填充零值的位置. 每个区域均采用不同的训练策略与实体检测模块相结合以提高模型性能
Adaptive Smoothing
我们开发了一种新型自适应平滑技术来处理 GT 区域的问题。针对 GT 区域中的 GijkG_{ijk}实例,在此场景下模型被设计用于推断主体实体 viv_i 和客体实体 vjv_j之间的第k个谓词类别。然而,在训练初期阶段,与 viv_i 和 vjv_j分别对应的检测结果 v^i′\hat{v}'_i和 v^j′\hat{v}'_j缺乏足够的表示能力以捕捉真实对象特征。因此,在某些情况下预测一个实例对的概率值接近于1可能会带来误导性结论。通过自适应平滑处理每个候选对象的检测性能评估结果,并将其映射到关系标签上。首先通过相应的二分匹配成本衡量每个对象候选的不确定性;对于对象候选 v^i′\hat{v}'_i我们定义其不确定性如下:

其中,
\text{cost}_i
代表匹配成本,
而
\text{cost}_{\text{min}}
则指当
\hat{v}'_i
完全吻合
v_i
时的具体数值。
参数
\alpha
则用于衡量最小的不确定度。
我们定义
G_{ijk} = (1 - u_i)(1 - u_j),
这一设定旨在量化各不确定因素之间的相互影响。
通过采用具有不确定度调节的关系标签,
使得目标检测与关系抽取的多任务学习能够根据检测对象的质量进行动态优化。
Negative and Non-matching Sampling.
我们并非全部采用所有负样本,而是从"negative region"这一特定区域中进行抽样。受Liu等人的[22]研究启发,我们基于预测的关系得分为依据,对所有候选的负样本进行排序后选取前k_neg × |E|个最具挑战性的Negative sample作为训练集。同样地,我们从中提取k_non × |E|个具有代表性的Difficult sample用于进一步的数据增强过程。由于未匹配区域通常占据了图 G的大部头面积,这种方法显著降低了数据稀疏性带来的性能损失。
Connectivity Prediction
基于Graph-RCNN [43]通过相关性预测来修剪对象对的影响的基础上, 我们提出了一种基于关系的连通性预测方法.该方法不仅能够判断两个对象节点之间是否存在至少一条连接边, 还能进一步分析其间的具体关系类型.为了构建连通性图, 我们采用了与公式(6)类似的方式, 利用相同的特征表示方法获得了一个\hat{E} \in \mathbb{R}^{N \times N \times 1}的连通性图.针对多标签场景, 我们设计了一个MLP网络con_{\text{con}}来进行二分类任务.最后, 我们从排序后的连通性图中计算出二分类用的交叉熵损失函数: \mathcal{L}_{\text{con}} = \mathcal{L}_{\text{r}}(\hat{E}', E)
Inference
在模型推理过程中,在计算每个三元组(Gijk, Eijk)的得分时,在计算过程中我们采用了如下方法:首先将谓词得分Gijk与其对应的类别的得分vci和vcj进行相乘运算得到该三元组的整体得分;其次为了防止主语与宾语指代同一个实体导致出现自连接的问题我们将Giii设定为零值;此外我们还引入了连通性得分Eijk来提升各个三元组的整体得分水平;最后通过对这些整体得分为基础并结合连通性得分来进行计算处理从而能够有效地剔除那些不存在主语与宾语之间真实关系的不合理的三元组
