
Training RFBNet on the Kaggle RSNA dataset to detect pneumonia in chest X-rays


As a one-stage detector, RFBNet keeps a decent level of accuracy while staying fast, so I used it to train on the RSNA pneumonia dataset from Kaggle. This post focuses on the source-code changes needed to make RFBNet train on RSNA; if you are looking for an analysis of the RSNA data itself, the kernels on Kaggle are a better place to go.

Dataset introduction

One way RSNA differs from the common detection datasets (COCO, VOC, BDD100K, Cityscapes, etc.) is that an image may carry no annotations at all, i.e. it contains no foreground. I had a hunch the original code would not handle this case, and sure enough it crashed as soon as my dataloader was in place, so the source had to be modified.

Source code modifications

1. Write a dataloader that supports RSNA

Everyone writes these in their own style; the key points are the following (see the sketch after the list):

1. Read the DICOM files with SimpleITK.

2. When the current image has no annotations, have load annotation return np.zeros((1, 5)).
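A minimal sketch of such a dataset class is below. It assumes a pandas DataFrame built from the competition's stage_2_train_labels.csv; the class and attribute names (RSNADetection, labels_df, preproc) are my own illustrations rather than anything from the RFBNet repo, so adapt them to whatever interface your preprocessing code expects.

    # Minimal sketch of an RSNA-compatible dataset (illustrative names, not from the repo).
    import numpy as np
    import SimpleITK as sitk
    import torch.utils.data as data


    class RSNADetection(data.Dataset):
        def __init__(self, image_dir, labels_df, preproc=None):
            self.image_dir = image_dir
            self.labels_df = labels_df
            self.ids = labels_df['patientId'].unique().tolist()
            self.preproc = preproc

        def __len__(self):
            return len(self.ids)

        def __getitem__(self, index):
            patient_id = self.ids[index]
            # 1. read the DICOM with SimpleITK
            arr = sitk.GetArrayFromImage(
                sitk.ReadImage('%s/%s.dcm' % (self.image_dir, patient_id)))
            img = arr[0] if arr.ndim == 3 else arr  # single-slice DICOMs come back as (1, H, W)
            img = np.stack([img] * 3, axis=-1)      # grayscale -> 3 channels for the VGG backbone
            target = self._load_annotation(patient_id)
            if self.preproc is not None:
                img, target = self.preproc(img, target)
            return img, target

        def _load_annotation(self, patient_id):
            rows = self.labels_df[(self.labels_df['patientId'] == patient_id)
                                  & (self.labels_df['Target'] == 1)]
            if len(rows) == 0:
                # 2. no pneumonia boxes on this image: return a single all-zero row
                return np.zeros((1, 5))
            boxes = []
            for _, r in rows.iterrows():
                # (x1, y1, x2, y2, class) in pixel coordinates; class 1 = pneumonia
                boxes.append([r['x'], r['y'], r['x'] + r['width'],
                              r['y'] + r['height'], 1])
            return np.array(boxes, dtype=np.float32)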

2. Modify multibox_loss.py

The original code selects background priors in a fixed ratio to the number of foreground priors. If an image has no foreground, no background gets selected either, and computing the cross-entropy for the positive/negative classification then throws an error.

I added a small piece of logic: when there is no foreground, a minimum number of background priors (at least one per image, enforced via constant_min) is still selected for the computation. This corresponds to the constant_min / neg_min / num_neg lines in the code below (lines 55–58 of the original file).

    def forward(self, predictions, priors, targets):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
                and prior boxes from SSD net.
                conf shape: torch.size(batch_size,num_priors,num_classes)
                loc shape: torch.size(batch_size,num_priors,4)
                priors shape: torch.size(num_priors,4)
            ground_truth (tensor): Ground truth boxes and labels for a batch,
                shape: [batch_size,num_objs,5] (last idx is the label).
        """
        loc_data, conf_data = predictions
        priors = priors
        num = loc_data.size(0)
        num_priors = priors.size(0)
        num_classes = self.num_classes

        # match priors (default boxes) and ground truth boxes
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx)
        if GPU:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
        # wrap targets
        loc_t = Variable(loc_t, requires_grad=False)
        conf_t = Variable(conf_t, requires_grad=False)

        pos = conf_t > 0

        # Localization Loss (Smooth L1)
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)

        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard Negative Mining
        loss_c[pos.view(-1, 1)] = 0  # filter out pos boxes for now
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)

        # Modification for RSNA: keep a minimum number of hard negatives even when an
        # image has no foreground, so the cross-entropy below never gets an empty batch.
        constant_min = torch.ones(num_pos.shape, dtype=torch.int64)
        neg_min = torch.max(self.negpos_ratio * num_pos, constant_min.cuda())
        num_neg = torch.clamp(neg_min, max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos + neg).gt(0)]
        loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        N = max(num_pos.data.sum().float(), 1)
        loss_l /= N
        loss_c /= N
        return loss_l, loss_c
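To make the effect of the change concrete, here is a small standalone snippet (my own illustration, run on CPU, assuming the usual negpos_ratio of 3) comparing how many negatives the original and the patched logic keep when a batch contains no foreground at all:

    # Standalone check of the negative-count logic above; the values are made up
    # purely to show the behaviour, not taken from a real training run.
    import torch

    negpos_ratio = 3
    num_priors = 6
    num_pos = torch.zeros((2, 1), dtype=torch.int64)  # batch of 2 images, no foreground

    # original behaviour: no positives -> zero negatives -> empty cross-entropy batch
    num_neg_orig = torch.clamp(negpos_ratio * num_pos, max=num_priors - 1)
    print(num_neg_orig.squeeze(1))  # tensor([0, 0])

    # patched behaviour: always keep at least one hard negative per image
    constant_min = torch.ones(num_pos.shape, dtype=torch.int64)
    num_neg = torch.clamp(torch.max(negpos_ratio * num_pos, constant_min),
                          max=num_priors - 1)
    print(num_neg.squeeze(1))  # tensor([1, 1])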
