Advertisement

论文阅读:Scene Dynamics: Counterfactual Critic Multi-Agent Training for Scene Graph Generation

阅读量:

本文提出了一种基于图级度量的多代理策略梯度方法来优化scene graph的质量。作者指出现有方法在将目标识别与关系识别结合时存在不足,即假设每个object的重要性相同,而实际上hub node(如bike)对关系的影响远大于non-hub node(如tree)。为此,他们提出通过图级Recall和SPICE度量作为监督信号,并采用多代理策略梯度方法直接优化检测结果。为了提高计算效率,他们引入了counter-factual critic,并通过蒙特卡洛方法进行策略评估。实验表明该方法在检测与通信模型上表现良好,在不影响关系识别模型的前提下提高了预测准确率。

Scene Dynamics

长期以来,在研究场景图(scene graph)这一领域时,我认为最关键的因素是强化对关系识别能力的提升。然而,在阅读本文后我发现,并非必须完全割裂两者的关系——通过将目标识别与关系识别进行有机整合,在某种程度上可以取得更好的效果。

本文主张:现有的关系检测算法普遍未从graph层面进行目标检测。
虽然message passing方法在一定程度上被应用到这里 但现有算法仍未能充分挖掘graph结构的特点。
对于目标的识别 现有[算法](
作者通过观察发现 这种假设存在问题。
举个例子来说 在图b中将bike错误地判断为man会导致之后的4个关系识别出错 而将tree错判为man只会影响单个关系。
这种情况下 将具有较大影响的对象称为hub node 而对影响较小的对象则称为non-hub node。
为了进一步优化检测效果 我们考虑在训练过程中依据各物体间的关系网络统计频率 将那些hub node节点识别出来 并在构建损失函数时赋予它们更高的权重 从而使得模型能够更加关注对整体关系网络有重要影响的对象。

值得注意的是,在hub节点上的误判相比non-hub节点会产生更大的负面影响 ,因此问题的核心在于如何将此指标用作模型训练过程中的监督信号源。文章采用了基于multi-agent policy-gradient的方法来最大化上述graph-level指标。这种方法之所以被称为multi-agent方法的原因在于:它将每张图片中的每一个object都视为独立的agent;而其策略则专注于对每个object进行类别识别。具体而言,

\mathcal{L} = \sum_{v} \sum_{c} w_{v,c} \cdot \text{loss}(y_v^c, \hat{y}_v^c)

其中,

  • v表示所有物体(agents)索引集合;
  • c表示类别标签;
  • w_{v,c}代表各分类任务的重要度权重;
  • \text{loss}(y_v^c, \hat{y}_v^c)表示各分类任务对应的损失函数。

在本文中, 策略π等价于分类概率, 因此可用符号p来表示. 大部分policy-gradient算法(如actor-critic架构)通常会使用单一的神经网络来估计Q值以实现策略优化. 而本文则采用蒙特卡洛方法来进行策略优化

其中v代表物体类别(即动作),h代表经过处理后的物体特征(状态),θ代表策略网络参数(涉及检测与通信)。奖励函数R是一个关键组件,在其定义中包含了两部分:首先是一个关系识别模型;其次结合了基于图层度量计算的方法。具体来说,在使用H和V数据后系统会先推导出关系识别的结果;接着通过计算Recall或SPICE指标来评估策略表现。值得注意的是,在优化算法收敛性方面采取了创新措施:将原始奖励函数减去一个基准值以获得优势形式;这种设计使得整个框架更具竞争力和适应性。 baseline 的计算过程非常有趣

对于每一个object, 我们会将其类别的判断替换为其他类别, 并基于分类概率对每次获得的graph-metric进行加权汇总. 这种方法被称为counter-factual critic, 它的一个显著优点是可以有效地区分每个agent的具体行动.

观察到对于求和中的每一项R,在该过程中agent i在采取动作v_i时所对应的收益增量代表了其相对于其他action所能带来的额外收益。由于在这一过程中其他agent均保持固定不动,最终我们基于多智能体策略的方法论基础构建了相应的理论框架。

在采用policy gradient方法时, 我们也可以将其与其它损失函数计算出的梯度相加, 从而实现网络的整体优化. 其中值得注意的是, CMAT并不会直接影响relationship recognition model的相关参数, 而只会对前面提到的communication和detector中的未被冻结(freeze)的参数产生影响.

关于训练的具体细节:

本研究采用了双阶段训练策略。在第一个阶段中,我们按照常规的关系检测算法框架,并采用交叉熵损失函数进行优化;随后,在第二个阶段中引入了基于CMAT的方法进一步提升性能。值得注意的是,在整个训练过程中,我们对RoIAlign模块中的相关参数均进行了固定设置,以确保后续的学习过程不受其影响。经过上述两个优化步骤后,最终模型达到了预期的性能目标

对于每次遇到的物体类别识别起来存在疑问,在模型中通常是通过比较各类别的概率值来判断。具体来说,在模型输出结果中是否存在两种不同的可能性:一种仅选择具有最高概率的类别;另一种则是采用基于概率分布的采样方法进行判断。为了更好地理解这一过程以及相关的参数设置,请尽量通过查阅文献或联系作者获取进一步的信息

在进行CMAT计算时,在对每个object进行处理时,在所有其他类别上逐一尝试会产生巨大的计算负担。经过一系列实验分析后发现,并非需要在所有类别的概率值上都做改动才能达到预期效果。事实上只需替換两个概率值较低的类别以及背景类(background)即可,并将带来大约70倍的速度提升,并且性能损失极少。

最后观察一下效果如何, 相比而言, 虽然不如LinkNet, 但整体上也有不错的进步. 在评估过程中发现, predcls的召回率有所提升, 这一发现也不可避免地引起了我的注意, 因为CMAT对于relationship recognition model本身并没有造成任何影响, 然而CMAT却对detector和communication model造成了直接的影响. 不过这一结果也可能暗示着multi-task学习可能在某种程度上促进彼此间的协同作用.

全部评论 (0)

还没有任何评论哟~