
Losses for Vehicle Re-Identification

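    import torch
    import torch.nn as nn

    # normalize, euclidean_dist and hard_example_mining are helper functions
    # defined elsewhere in the original codebase; they are not shown in this post.

    class TripletLoss(object):
        """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
        Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
        Loss for Person Re-Identification'."""

        def __init__(self, margin=None):
            self.margin = margin
            if margin is not None:
                # Hard-margin triplet loss: max(0, dist_ap - dist_an + margin)
                self.ranking_loss = nn.MarginRankingLoss(margin=margin)
            else:
                # Soft-margin triplet loss: log(1 + exp(dist_ap - dist_an))
                self.ranking_loss = nn.SoftMarginLoss()

        def __call__(self, global_feat, labels, mask=None, normalize_feature=False):
            """
            :param global_feat: [N, D] features of the batch
            :param labels: [N] identity labels
            :param mask: [N, N] visibility mask. Entries marked invisible are never
                selected; if all entries are invisible, the result is multiplied by 0.
            :param normalize_feature: L2-normalize the features before computing distances
            :return: (loss, dist_ap, dist_an)
            """
            if normalize_feature:
                global_feat = normalize(global_feat, axis=-1)
            dist_mat = euclidean_dist(global_feat, global_feat)
            # Hardest positive and hardest negative distance for each anchor
            dist_ap, dist_an = hard_example_mining(
                dist_mat, labels, mask=mask)
            # Target y = 1: dist_an should be larger than dist_ap
            y = torch.ones_like(dist_an)
            if self.margin is not None:
                loss = self.ranking_loss(dist_an, dist_ap, y)
            else:
                loss = self.ranking_loss(dist_an - dist_ap, y)
            return loss, dist_ap, dist_an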
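The class depends on three helpers that the post does not show. Below is a minimal sketch of them, assuming PyTorch and the usual batch-hard mining scheme. For each anchor, the hardest positive is the farthest sample with the same ID, and the hardest negative is the closest sample with a different ID. This is an illustration, not the original open-reid code. In particular, the visibility mask is not supported here: the sketch raises an error if one is passed.

```python
import torch


def normalize(x, axis=-1):
    """L2-normalize x along `axis` (the epsilon avoids division by zero)."""
    return x / (torch.norm(x, p=2, dim=axis, keepdim=True) + 1e-12)


def euclidean_dist(x, y):
    """Pairwise Euclidean distances between rows of x [m, d] and y [n, d] -> [m, n]."""
    # ||x||^2 + ||y||^2 - 2 x.y, clamped before sqrt for numerical stability
    xx = x.pow(2).sum(dim=1, keepdim=True)        # [m, 1]
    yy = y.pow(2).sum(dim=1, keepdim=True).t()    # [1, n]
    dist = xx + yy - 2.0 * x @ y.t()
    return dist.clamp(min=1e-12).sqrt()


def hard_example_mining(dist_mat, labels, mask=None):
    """Batch-hard mining: hardest positive and hardest negative per anchor.

    Sketch only: the visibility `mask` of the original code is NOT handled.
    """
    if mask is not None:
        raise NotImplementedError("mask handling is omitted in this sketch")
    n = dist_mat.size(0)
    is_pos = labels.view(n, 1).eq(labels.view(1, n))  # [N, N], diagonal included
    is_neg = ~is_pos
    # Hardest positive: largest distance among same-ID samples
    dist_ap = dist_mat.masked_fill(~is_pos, float("-inf")).max(dim=1).values
    # Hardest negative: smallest distance among different-ID samples
    dist_an = dist_mat.masked_fill(~is_neg, float("inf")).min(dim=1).values
    return dist_ap, dist_an


feats = torch.tensor([[0., 0.], [0., 1.], [3., 0.], [3., 2.]])
labels = torch.tensor([0, 0, 1, 1])
dist_ap, dist_an = hard_example_mining(euclidean_dist(feats, feats), labels)
print(dist_ap.tolist(), dist_an.tolist())
```

Both returned tensors have shape [N], one value per anchor. These are exactly the dist_ap and dist_an that TripletLoss feeds into ranking_loss.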

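ranking_loss is nn.MarginRankingLoss(margin=margin) when a margin is given, and nn.SoftMarginLoss() otherwise. Internally, their forward passes call F.margin_ranking_loss(input1, input2, target, ...) and F.soft_margin_loss(input, target, reduction=self.reduction), respectively.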
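PyTorch defines MarginRankingLoss(x1, x2, y) as mean(max(0, -y * (x1 - x2) + margin)). It defines SoftMarginLoss(x, y) as mean(log(1 + exp(-y * x))). In the class above, y = 1, x1 = dist_an, x2 = dist_ap and x = dist_an - dist_ap. The two branches therefore reduce to:

- mean(max(0, dist_ap - dist_an + margin)) for the hard-margin branch;
- mean(log(1 + exp(dist_ap - dist_an))) for the soft-margin branch.

The snippet below checks this numerically on made-up distances. margin = 0.3 is only an example value, not one taken from the post.

```python
import torch
import torch.nn as nn

# Made-up hardest-positive / hardest-negative distances for 4 anchors
dist_ap = torch.tensor([0.5, 1.2, 0.8, 2.0])
dist_an = torch.tensor([1.0, 1.0, 1.5, 1.9])
y = torch.ones_like(dist_an)  # y = 1: dist_an should exceed dist_ap
margin = 0.3                  # example value only

# Hard-margin branch vs. its closed form
hard = nn.MarginRankingLoss(margin=margin)(dist_an, dist_ap, y)
hard_manual = torch.clamp(dist_ap - dist_an + margin, min=0).mean()

# Soft-margin branch vs. its closed form (no margin hyperparameter)
soft = nn.SoftMarginLoss()(dist_an - dist_ap, y)
soft_manual = torch.log1p(torch.exp(dist_ap - dist_an)).mean()

print(hard.item(), hard_manual.item())
print(soft.item(), soft_manual.item())
```

The soft-margin version never reaches exactly zero. It keeps pushing dist_an beyond dist_ap, with a smoothly decaying penalty, instead of stopping once the margin is met.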

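In both cases the goal is the same: shrink the anchor-positive distance (ap) and enlarge the anchor-negative distance (an).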
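This pull/push behaviour can be read off the gradients. Both losses increase with dist_ap and decrease with dist_an, so a gradient-descent step lowers dist_ap and raises dist_an. Here is a small autograd check on made-up values, with margin = 0.3 as an example. The values are chosen so that every triplet violates the margin and the hard-margin branch is active.

```python
import torch
import torch.nn as nn


def loss_grads(use_margin, margin=0.3):
    """Return d(loss)/d(dist_ap) and d(loss)/d(dist_an) for one branch."""
    dist_ap = torch.tensor([1.0, 0.9], requires_grad=True)
    dist_an = torch.tensor([0.8, 1.0], requires_grad=True)
    y = torch.ones(2)
    if use_margin:
        loss = nn.MarginRankingLoss(margin=margin)(dist_an, dist_ap, y)
    else:
        loss = nn.SoftMarginLoss()(dist_an - dist_ap, y)
    loss.backward()
    return dist_ap.grad, dist_an.grad


for use_margin in (True, False):
    g_ap, g_an = loss_grads(use_margin)
    # Positive grad on dist_ap, negative grad on dist_an:
    # descent shrinks dist_ap and grows dist_an
    print("margin" if use_margin else "soft", g_ap.tolist(), g_an.tolist())
```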

dist_an and dist_ap are both computed within the same batch; here batch_size = 64.
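Since the positives and negatives are mined inside one batch, each anchor needs at least one other image of the same vehicle ID in that batch. The paper cited in the docstring builds such batches by sampling P identities and then K images per identity.

Below is a minimal pure-Python sketch for batch_size = 64. It assumes P = 16 and K = 4; the post gives only the batch size, so this split is an assumption. pk_batch is a hypothetical helper, not part of the original code.

```python
import random
from collections import defaultdict


def pk_batch(labels, p=16, k=4, seed=None):
    """Return dataset indices for one P x K batch (hypothetical helper).

    labels: identity label of every image in the dataset.
    Picks p distinct identities and k images of each; identities with
    fewer than k images are sampled with replacement.
    """
    rng = random.Random(seed)
    by_id = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_id[lab].append(idx)
    ids = rng.sample(sorted(by_id), p)
    batch = []
    for pid in ids:
        pool = by_id[pid]
        if len(pool) >= k:
            batch.extend(rng.sample(pool, k))
        else:
            batch.extend(rng.choices(pool, k=k))
    return batch


# Toy dataset: 40 vehicle IDs with 3 to 7 images each
toy_labels = [vid for vid in range(40) for _ in range(3 + vid % 5)]
batch = pk_batch(toy_labels, p=16, k=4, seed=0)
print(len(batch))  # 64 = 16 identities x 4 images
```

With K = 4, every anchor has three positives and 60 negatives in the batch. From these, hard_example_mining picks the hardest of each.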
