【Unbiased Scene Graph Generation from Biased Training】场景图评估指标详解(附代码解读)
文章目录
- 前文
- 在第k个位置上的召回率为 R @ K
- 无图约束下的召回率在第k位(ngR @ k)
- 平均召回率在第k位(mR @ k)
- 零样本下的召回率在第k位(zR @ k)
- top-k准确度(A @ K)
前文
这篇文章所提到的主要场景图评估方法均基于Unbiased Scene Graph Generation from Biased Training的研究成果。下文是我引用作者表述基础上,并基于其源代码进行解读所得出关于每个评估指标更加详细的具体描述,并附有相关代码注释。
Recall@K (R@K)
该指标自是最具权威性的测量标准。
由卢老师在其论文https://arxiv.org/abs/1608.00187中首次提出。
鉴于VisualGenome数据集中的relation标签存在缺失,
其单一准确率无法充分评估模型性能。
卢老师将其视为一种检索任务,
并采用Recall作为评估指标,
不仅要求识别精确性,
还能有效区分非相关物体配对。
对于任意一张图片,在任意一张图片中
对于前K个未经排序且按顺序从前到后选取的预测关系(这K个预测关系直接从前往后取),找到与其相匹配的真实关系包括<主语, 动作, 宾语>。这些真实关系与预测集合的相关性计算中,并集的数量即为m。每张图像都包含其已标注的真实关联,并涉及的真实关联数量为n。其中,在判断预测与真实关联是否匹配时,默认条件是主语、动词和宾语完全一致,则其R@k值计算公式如下:
R@k = \frac{m}{n}
需要注意的是,在判断预测与真实关联是否匹配时,默认条件是主语、动词和宾语完全一致。
def calculate_recall(self, global_container, local_container, mode):
pred_rel_inds = local_container['pred_rel_inds']
rel_scores = local_container['rel_scores']
gt_rels = local_container['gt_rels']
gt_classes = local_container['gt_classes']
gt_boxes = local_container['gt_boxes']
pred_classes = local_container['pred_classes']
pred_boxes = local_container['pred_boxes']
obj_scores = local_container['obj_scores']
iou_thres = global_container['iou_thres']
# [主体序号, 客体序号, 最匹配的关系序号]
pred_rels = np.column_stack((pred_rel_inds, 1+rel_scores[:,1:].argmax(1)))
# 每个预测关系对应的关系分数
pred_scores = rel_scores[:,1:].max(1)
gt_triplets, gt_triplet_boxes, _ = _triplet(gt_rels, gt_classes, gt_boxes)
local_container['gt_triplets'] = gt_triplets
local_container['gt_triplet_boxes'] = gt_triplet_boxes
pred_triplets, pred_triplet_boxes, pred_triplet_scores = _triplet(
pred_rels, pred_classes, pred_boxes, pred_scores, obj_scores)
# Compute recall. It's most efficient to match once and then do recall after
pred_to_gt = _compute_pred_matches(
gt_triplets,
pred_triplets,
gt_triplet_boxes,
pred_triplet_boxes,
iou_thres,
phrdet=mode=='phrdet',
)
local_container['pred_to_gt'] = pred_to_gt
for k in self.result_dict[mode + '_recall']:
# the following code are copied from Neural-MOTIFS
match = reduce(np.union1d, pred_to_gt[:k])
rec_i = float(len(match)) / float(gt_rels.shape[0])
self.result_dict[mode + '_recall'][k].append(rec_i)
return local_container
No Graph Constraint Recall@K (ngR@K)
该指标最初由Pixel2Graph提出,并被命名为Neural-MOTIFS。其主要目标是解决传统Recall计算方法中的局限性:即在排序过程中仅允许一对物体之间通过单一的关系进行比较。然而ngR@K的方法则不同:它允许同一对物体之间的多个关系参与到排序过程中。例如,在人类-骑-马序列中(human(0.9)-riding (0.6)-horse (0.9)),总分计算为三个关系强度的乘积(即 ),而如果存在人类-在-马序列(human(0.9)-on (0.3)-horse (0.9)),总分则为三个关系强度的乘积(即 )。需要注意的是尽管后者得分较低但在某些情况下仍然可能成立。与传统的R@K相比 ngR@K的表现显著优于后者。
在预测关系对的选择机制上存在差异。具体而言,在No Graph Constraint场景下,在给定一张图片中所有主客体对<主体, 客体>分别与各自对应的关系类别中的分数相乘(即:每个主客体对<主体, 客体>分别与该主客体对所对应的每个关系类别的分数相乘)。假设共有 k 个关系类别,则每个主客体对<主体, 客体>会生成 k 个预测的关系三元组 <\text{主体}, \text{关系}, \text{客题}> )。随后将所有预测的关系三元组按照' 主题得分 × 客题得分 × 关系得分 '从高到低进行排序。
选择前K个预测的关系对,在图像上与之匹配的真实关系对被筛选出来,并集总数(去除重复项)设为m。每张图片上已标注的真实关系数目设为n。那么该图片上的检索准确率在前K步(即Recall@K)等于这个比值:R@K=\frac{m}{n}
注意: No Graph Constraint 也可以被用来介绍接下来将要讲解的mR@K和zR@K评估指标,并不仅限于使用 R@K 作为主要的标准。
def calculate_recall(self, global_container, local_container, mode):
obj_scores = local_container['obj_scores']
pred_rel_inds = local_container['pred_rel_inds']
rel_scores = local_container['rel_scores']
pred_boxes = local_container['pred_boxes']
pred_classes = local_container['pred_classes']
gt_rels = local_container['gt_rels']
# 每对关系的主语和宾语的分数相乘
obj_scores_per_rel = obj_scores[pred_rel_inds].prod(1)
# 主语和宾语分数的乘积与它们对应的每个关系类别的分数相乘
nogc_overall_scores = obj_scores_per_rel[:,None] * rel_scores[:,1:]
# 找出总乘积分数前一百对关系索引(主体客体可以单独重复)
nogc_score_inds = argsort_desc(nogc_overall_scores)[:100]
# 前一百对关系的主客体和最高分的类别 [主体序号, 客体序号, 分数最高的关系序号]
nogc_pred_rels = np.column_stack((pred_rel_inds[nogc_score_inds[:,0]], nogc_score_inds[:,1]+1))
# 前一百对关系的关系分数从高到低排序
nogc_pred_scores = rel_scores[nogc_score_inds[:,0], nogc_score_inds[:,1]+1]
nogc_pred_triplets, nogc_pred_triplet_boxes, _ = _triplet(
nogc_pred_rels, pred_classes, pred_boxes, nogc_pred_scores, obj_scores
)
# No Graph Constraint
gt_triplets = local_container['gt_triplets']
gt_triplet_boxes = local_container['gt_triplet_boxes']
iou_thres = global_container['iou_thres']
nogc_pred_to_gt = _compute_pred_matches(
gt_triplets,
nogc_pred_triplets,
gt_triplet_boxes,
nogc_pred_triplet_boxes,
iou_thres,
phrdet=mode=='phrdet',
)
local_container['nogc_pred_to_gt'] = nogc_pred_to_gt
for k in self.result_dict[mode + '_recall_nogc']:
match = reduce(np.union1d, nogc_pred_to_gt[:k])
rec_i = float(len(match)) / float(gt_rels.shape[0])
self.result_dict[mode + '_recall_nogc'][k].append(rec_i)
return local_container
Mean Recall@K (mR@K)
该指标是由我的VCTree以及另一位同学开发的KERN于2019年在CVPR会议上共同提出的研究成果之一。值得注意的是,在这项研究中虽然我是VCTree项目的合作者之一,在附录部分完整展示了实验结果表的数据支撑。然而由于VisualGenome数据集所呈现的长尾分布特征,在传统评估指标Recall中仅掌握核心领域的几个常见关系(如on、near等)便能达到较好的分类效果。这种现象并非我们所追求的目标状态。因此通过计算均值Recall的方法实现了对各类别关系均等关注的效果。这一改进使模型的学习目标从专注于大量重复出现的关系转向全面覆盖各类别关系的学习。
假设有一张图片,并设定其包含的关系类别总数为n个。对于每一个具体的关系类别i(其中i=1,2,...,n),我们都可以计算出对应的该特定类别的检索性能指标(记作R@K)。为了全面评估这张图片的整体检索性能表现,在完成所有具体类别的评估后需要将各具体类别的R@K值进行求平均处理
mR@K=\frac{\sum_1^n{R@K_i}}{n}
def collect_mean_recall_items(self, global_container, local_container, mode):
pred_to_gt = local_container['pred_to_gt']
gt_rels = local_container['gt_rels']
for k in self.result_dict[mode + '_mean_recall_collect']:
# the following code are copied from Neural-MOTIFS
match = reduce(np.union1d, pred_to_gt[:k])
# NOTE: by kaihua, calculate Mean Recall for each category independently
# this metric is proposed by: CVPR 2019 oral paper "Learning to Compose Dynamic Tree Structures for Visual Contexts"
recall_hit = [0] * self.num_rel
recall_count = [0] * self.num_rel
'''
统计该图片的所有真实关系对中出现的每一种关系类别的次数
'''
for idx in range(gt_rels.shape[0]):
local_label = gt_rels[idx,2]
recall_count[int(local_label)] += 1
recall_count[0] += 1
'''
统计该图片的所有有匹配到真实关系对的预测关系对中出现的每一种关系类别的次数
'''
for idx in range(len(match)):
local_label = gt_rels[int(match[idx]),2]
recall_hit[int(local_label)] += 1
recall_hit[0] += 1
'''
对每个关系类别进行统计,对每个关系类别计算该类别下预测关系对的召回率
'''
for n in range(self.num_rel):
if recall_count[n] > 0:
self.result_dict[mode + '_mean_recall_collect'][k][n].append(float(recall_hit[n] / recall_count[n]))
Zero Shot Recall@K (zR@K)
早期研究中也采用过Zero Shot Recall指标来辅助视觉关系识别任务的表现评估。
然而,在后来的研究与 advancements中逐渐被遗忘。
为了弥补这一缺陷,在本研究中重新引入了该指标。
Zero Shot Recall指的是那些虽然在训练数据集中未曾见过的关系类型,并非指从未见过的所有可能的关系类型。
具体而言,在training过程中只针对那些主语-谓语-宾语三元组组合 unseen的情况进行评估。
这些组合所涉及的具体object类别与relation类别仍然是已经被学习器完全掌握过的。
因此,在测试阶段仍可正常进行推理与预测。
按照作者所述,在ZeroShot关系对中指的主语、谓语、宾语的三元组组合是未曾出现在训练中的。
给定一张图片,在训练数据中未出现的真实语义关联被识别出作为真实ZeroShot关系对<主体, 关系, 客体>进行标记,并记其数量为n。随后从图像生成器中获得的前K个候选语义关联被与上述标记的关系进行一一比对筛选出与之匹配的K个候选零样本关联<主体, 关系, 客体>。将这些候选零样本关联的实际存在性进行统计汇总其对应的真实零样本关联对其并集的数量记为m即这些真实零样本关系的总数量(不计重复))。每张图片对应的零样本准确率在第K位(zeroR@K)即:
zeroR@K = \frac{m}{n}
def prepare_zeroshot(self, global_container, local_container):
# 123
gt_rels = local_container['gt_rels']
gt_classes = local_container['gt_classes']
# 这个就是作者所指的训练中没见过的<主, 谓, 宾>三元组组合
zeroshot_triplets = global_container['zeroshot_triplet']
sub_id, ob_id, pred_label = gt_rels[:, 0], gt_rels[:, 1], gt_rels[:, 2]
gt_triplets = np.column_stack((gt_classes[sub_id], gt_classes[ob_id], pred_label)) # num_rel, 3
'''
对真实关系对进行检查,如果某一真实关系对在zeroshot关系对(这是预先写到文件上的)出现,
则记录下这个真实关系对的序号
'''
self.zeroshot_idx = np.where( intersect_2d(gt_triplets, zeroshot_triplets).sum(-1) > 0 )[0].tolist()
def calculate_recall(self, global_container, local_container, mode):
pred_to_gt = local_container['pred_to_gt']
for k in self.result_dict[mode + '_zeroshot_recall']:
# Zero Shot Recall
# 真实关系对中有匹配到预测关系对的序号,记录的是真实关系对的序号
match = reduce(np.union1d, pred_to_gt[:k])
# 如果该图片中某一真实关系对在zeroshot关系对中出现
if len(self.zeroshot_idx) > 0:
# 如果match的类型不是list和tuple类型,将它转换为列表
if not isinstance(match, (list, tuple)):
match_list = match.tolist()
else:
match_list = match
# zeroshot_match: 预测关系对匹配到真实关系对中zeroshot关系对的数量
zeroshot_match = len(self.zeroshot_idx) + len(match_list) - len(set(self.zeroshot_idx + match_list))
zero_rec_i = float(zeroshot_match) / float(len(self.zeroshot_idx))
self.result_dict[mode + '_zeroshot_recall'][k].append(zero_rec_i)
Top@K Accuracy (A@K)
这个是不好的评估方法!
该指标源自某位研究者因误解PredCls与SGCls概念而得出的一种结论,并不建议将其纳入文章报告范围。为此列出此指标仅为警示作用,请各位学者切勿重蹈覆辙。具体而言,在实验过程中该同学不仅为每一个object提供了bounding box标注结果,在同时又对主语-宾语的所有possible pair进行了配对标记。这样一来实际上已偏离Recall这一核心指标的意义范畴,在这种设定下评估的对象已经演变为两个物体间的关联性是否正确。
在 PredCls 和 SGcls 中 ,通过 Top@K Accuracy(A@K)来反映 Recall 的表现 。这也是一种常见的误区 。因为 PredCls 和 SGcls 提供的是所有 object 的边界框信息 ,并非具体的主语-宾语对 。一旦有了 pair 信息 ,则 Recall 的排名就不再重要 。这个误区最初发现于对比损失函数的研究中 。通过持续近两个月的研究与实践 ,我逐渐明白了他所采用的方法为何如此高效 。
在本文中特别指出的一篇论文被称作《Graphical Contrastive Losses for Scene Graph Parsing》,其完整的研究工作主要基于图的对比损失函数用于场景图解析,并可参考原文链接获取更多信息
给定一张图片后,在图像上存在的所有预测的关系元组<主语, 动作, 对象>中,请您找出那些主语与对象同时存在于相应的真实事件中的情况,并将此情况标记为其真值。随后,请您检索与上述标记相关的实际事件元组集合。将此集合中去重后的结果数量作为变量m值。假设该图片已经具有其真实的关系列表,则变量n值即代表该列表中的元素数量。最后计算的结果公式即为:
A@K=\frac{m}{n}
作者意思的个人理解:
当pair信息被确定后,实际上不再存在recall排序,并且仅仅依赖于准确性评估。
通常所说的召回率(Recall),旨在衡量模型在信息检索中的能力。其计算关注的是预测与真实数据之间是否存在完整的对应关系,并非基于实体间的关系配对情况。该方法假设已有一份包含所有可能实体及其关联的对象列表,在这种情况下,召回率排名指标不应仅依赖于部分完整的配对情况
而是给定两个物体,来判断他们relation的正确率。
当给定两个具有相同主体和客体的关系预测对时(即具有相同的实体间联系),仅有的区别在于它们的谓词不同(即实体间联系的性质不同)。如果这些预测的关系谓词与真实的关系谓词不一致,则匹配结果为空。
使用Top@K Accuracy (A@K)来报告成Recall
据我所知,在实际应用中这个评估指标表现还算不错。我认为作者的主要观点是将该指标视为召回率(Recall)而导致召回率出现了显著提升。
def calculate_recall(self, global_container, local_container, mode):
pred_to_gt = local_container['pred_to_gt']
gt_rels = local_container['gt_rels']
for k in self.result_dict[mode + '_accuracy_hit']:
# to calculate accuracy, only consider those gt pairs
# This metric is used by "Graphical Contrastive Losses for Scene Graph Parsing"
# for sgcls and predcls
if mode != 'sgdet':
gt_pair_pred_to_gt = []
'''
pred_to_gt 表示预测关系对<主体, 关系, 客体>在真实关系对中是否出现,如果出现则该项为
该预测关系对对应真实关系对的序号。
self.pred_pair_in_gt 表示预测主客体对<主体, 客体>在真实关系对中是否出现(不考虑关系是否一致),
如果出现则该项为True,否则为False。
'''
for p, flag in zip(pred_to_gt, self.pred_pair_in_gt):
if flag:
gt_pair_pred_to_gt.append(p)
if len(gt_pair_pred_to_gt) > 0:
gt_pair_match = reduce(np.union1d, gt_pair_pred_to_gt[:k])
else:
gt_pair_match = []
self.result_dict[mode + '_accuracy_hit'][k].append(float(len(gt_pair_match)))
self.result_dict[mode + '_accuracy_count'][k].append(float(gt_rels.shape[0]))
参考:
KaihuaTang/Scene-Graph-Benchmark.pytorch
CVPR2020 | 最新最完善的场景图生成 (SGG) 框架,集成目前最全 metrics,已开源
