车辆重识别相关Loss
发布时间
阅读量:
阅读量
class TripletLoss(object):
    """Batch triplet loss built on a pairwise distance matrix.

    Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in the paper 'In Defense of the
    Triplet Loss for Person Re-Identification'.

    With a numeric ``margin`` the loss is ``nn.MarginRankingLoss`` (hinge
    form); with ``margin=None`` it is the soft-margin variant
    ``nn.SoftMarginLoss`` applied to ``dist_an - dist_ap``.
    """

    def __init__(self, margin=None):
        """
        :param margin: margin for ``nn.MarginRankingLoss``. ``None`` selects
            the soft-margin formulation (``nn.SoftMarginLoss``) instead.
        """
        self.margin = margin
        if margin is not None:
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        else:
            self.ranking_loss = nn.SoftMarginLoss()

    def __call__(self, global_feat, labels, mask=None, normalize_feature=False):
        """Compute the triplet loss for one batch.

        :param global_feat: batch of feature vectors; presumably shape [N, D]
            (TODO confirm against callers).
        :param labels: identity labels of the N samples, forwarded to
            ``hard_example_mining``.
        :param mask: optional [N, N] visibility mask, forwarded to
            ``hard_example_mining``. Per the original author's note, pairs
            that are not visible are never selected, and if nothing is
            visible the result is multiplied by 0 -- NOTE(review): that
            behavior lives in ``hard_example_mining``; confirm there.
        :param normalize_feature: if True, normalize ``global_feat`` along
            the last axis before computing distances.
        :return: tuple ``(loss, dist_ap, dist_an)``.
        """
        if normalize_feature:
            global_feat = normalize(global_feat, axis=-1)
        # Pairwise distances between every sample in the batch.
        dist_mat = euclidean_dist(global_feat, global_feat)
        dist_ap, dist_an = hard_example_mining(dist_mat, labels, mask=mask)
        # Target y = 1 means "dist_an should be larger than dist_ap".
        # new_ones keeps dist_an's dtype/device, replacing the legacy
        # new().resize_as_().fill_(1) chain.
        y = dist_an.new_ones(dist_an.shape)
        if self.margin is not None:
            # Per sample: max(0, dist_ap - dist_an + margin), then mean.
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
            # Per sample: log(1 + exp(dist_ap - dist_an)), then mean.
            loss = self.ranking_loss(dist_an - dist_ap, y)
        return loss, dist_ap, dist_an
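ranking_loss 有两种选择:设置了 margin 时使用 nn.MarginRankingLoss(margin=margin),其内部调用 F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction);margin 为 None 时使用 nn.SoftMarginLoss(),其内部调用 F.soft_margin_loss(input, target, reduction=self.reduction)。
代入 y = 1 后,前者的单样本损失为 $\max(0,\ d_{ap} - d_{an} + m)$,后者(输入为 $d_{an} - d_{ap}$)为 $\log\left(1 + e^{\,d_{ap} - d_{an}}\right)$,默认都对 batch 取平均。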
两种形式的目的相同:缩小 anchor 与正样本之间的距离 (dist_ap),增大 anchor 与负样本之间的距离 (dist_an)。
dist_ap 和 dist_an 都是由 hard_example_mining 从同一个 batch 的距离矩阵中挖掘出来的,这里的 batch_size=64。
全部评论 (0)
还没有任何评论哟~
