Advertisement

5分钟速成半监督医学图像分割

阅读量:

在医学影像领域追求快速精准的信息提取对于生命的延续至关重要。然而面对复杂的影像数据和海量信息传统基于全监督的分割算法因过分依赖耗时费力的大规模标注数据而显得力不从心这种局限不仅制约了分析效率也加重了医疗工作者的工作负担设想一下假如我们能在短短五分钟内实现高精度的医学影像分割这将如何改变医疗界的实践?

目录

概述

核心逻辑

复现过程

写在最后


概述

这里我将阐述一篇发表于MICCAI 2023的一篇医学图像分割领域的研究文章https://link.springer.com/chapter/10.1007/978-3-031-43907-0_53

对于半监督的学习任务而言,在医学图像分割领域中采用伪标签策略时会剔除置信度较低的像素单元,并且传统的统一正则化手段未能充分挖掘高置信度与低置信度样本的价值优势。鉴于此,在无标签样本应用方面仍显不足。本文提出了一种新型解耦型半监督医学图像分割框架。具体而言,在预测结果的基础上通过动态阈值划分预测结果为一致区域和不一致区域。对于一致区域,则采用交叉伪监督策略实施优化训练;而对于不一致区域,则进一步将其划分为可能位于决策边界附近的不可靠样本以及可能位于高密度区域的引导样本两类子集。这些不可靠样本将朝向引导样本的方向进行优化训练以实现方向一致性目标;同时为了更好地发挥数据潜力,在训练过程中引入特征图并计算特征一致性损失函数以辅助模型收敛。

核心逻辑

这篇文章的模型图如下图所示:

如图所示, DC-Net架构包含一个编码器及两个一致性的解码器.其中,A型解码器采用双线性插值完成上采样,B型则采用反卷积完成上采样.针对带标签的数据,我们计算其与真实目标间的差异损失Lseg;针对一致的部分,引入交叉伪监督式的损失Lcps;针对不一致的部分,则定义方向一致性相关的损失Ldc;而对于特征图层,则定义了特征一致性相关的损失Lf.

动态一致性阈值:研究者[23]通过实验表明,在训练初期阶段应当将γ设置为相对较小的数值以充分利用未标注数据并促进伪标签的多样化。随着模型训练过程的发展γ应维持一个稳定的伪标签比例其中B表示批量大小λ是一个会随着时间递增的因素系数我们将其设定为λ = t/tmax(t代表当前迭代次数tmax代表最大迭代次数)。为了尽可能多地获取未标注数据我们对pA和pB设置了动态评估机制并采用较小的阈值作为我们的统一判别标准初始时我们将λt设为1/C其中C表示分类器类别数量。

分解一致性 :该文章将不一致的元素分解为不可靠的数据与引导的数据。其中不可靠的数据可能位于决策边界附近,而引导的数据更容易出现在高密度区域。这些部分具有相同的索引信息,在其区别在于引导的数据通常比不可可靠的数据更具信心。基于平滑假设(smoothness assumption),这些部分的输出应当保持一致性,并集中于高密度区域。因此,在优化策略中应特别关注决策边界附近的像素,并通过增强这些像素的信任度使它们趋向于靠近高密度区域。

复现过程

这是从我们的源码库中获取的项目目录。随后,在此项目中, 作者为我们提供了基于ACDC数据集构建的预训练模型, 在使用10%标注数据样本的情况下.接下来, 我们将上述预训练模型直接应用于测试环节.

我们是从网络上获取了这个源码目录,并在此基础上提供了相关说明和实验代码的具体位置说明。接着,在10%标注数据样本条件下,作者提供了ACDC数据集的预训练模型,并要求我们将该模型直接用于测试目的。问题在于,作者将该模型放置在ACDC_7目录中,并建议只需将该模型移动至ACDC_mcnet_kd_DCNet_7_labeled目录即可完成相应的测试操作。运行test_acdc.py文件后,预期结果即可显现.下图是实现的结果:

除了这个指标性能,我们还可以得到预测到的3D医学图像:

建议采用该款专业的3D图像可视化工具ITK-SNAP进行操作演示,并观察其具体表现为。

在这里,在启动test_acdc.py文件之前,在线准备好必要的数据集。第一步需要获取ACDC和PROMISE12两个数据集。接下来我会展示一些核心代码内容,并加入了一些注释来辅助说明。

复制代码
 output1_soft = F.softmax(output1, dim=1)

    
 output2_soft = F.softmax(output2, dim=1)
    
 output1_soft0 = F.softmax(output1 / 0.5, dim=1)
    
 output2_soft0 = F.softmax(output2 / 0.5, dim=1)
    
 # 这里是预测输出的锐化过程
    
 with torch.no_grad():
    
     max_values1, _ = torch.max(output1_soft, dim=1)
    
     max_values2, _ = torch.max(output2_soft, dim=1)
    
     percent = (iter_num + 1) / max_iterations
    
  
    
     cur_threshold1 = (1 - percent) * cur_threshold + percent * max_values1.mean()
    
     cur_threshold2 = (1 - percent) * cur_threshold + percent * max_values2.mean()
    
     mean_max_values = min(max_values1.mean(), max_values2.mean())
    
  
    
     cur_threshold = min(cur_threshold1, cur_threshold2)
    
     cur_threshold = torch.clip(cur_threshold, 0.25, 0.95)
    
  
    
 mask_high = (output1_soft > cur_threshold) & (output2_soft > cur_threshold)
    
 mask_non_similarity = (mask_high == False)
    
 # 这里是动态阈值部分的实现,这里阈值的初始值是0.25,也就是类别的倒数,然后这个值会快速地上升,最大值为0.95. 这里由这个阈值可以得到一致的高阈值区域和不一致区域。
    
  
    
 new_output1_soft = torch.mul(mask_non_similarity, output1_soft)
    
 new_output2_soft = torch.mul(mask_non_similarity, output2_soft)
    
 high_output1 = torch.mul(mask_high, output1)
    
 high_output2 = torch.mul(mask_high, output2)
    
 high_output1_soft = torch.mul(mask_high, output1_soft)
    
 high_output2_soft = torch.mul(mask_high, output2_soft)
    
  
    
 pseudo_output1 = torch.argmax(output1_soft, dim=1)
    
 pseudo_output2 = torch.argmax(output2_soft, dim=1)
    
 pseudo_high_output1 = torch.argmax(high_output1_soft, dim=1)
    
 pseudo_high_output2 = torch.argmax(high_output2_soft, dim=1)
    
  
    
 max_output1_indices = new_output1_soft > new_output2_soft  # output1 距离近的像素的位置
    
  
    
 max_output1_value0 = torch.mul(max_output1_indices, output1_soft0)
    
 min_output2_value0 = torch.mul(max_output1_indices, output2_soft0)
    
  
    
 max_output2_indices = new_output2_soft > new_output1_soft  # output2 距离远的像素的位置
    
  
    
 max_output2_value0 = torch.mul(max_output2_indices, output2_soft0)
    
 min_output1_value0 = torch.mul(max_output2_indices, output1_soft0)
    
 # 上面这段代码就是利用一致性区域和非一致性区域的处理过程
    
  
    
 loss_dc0 = 0
    
 loss_cer = 0
    
 loss_at_kd = criterion_att(encoder_features, decoder_features2)
    
  
    
  
    
 loss_dc0 += mse_criterion(max_output1_value0.detach(), min_output2_value0)
    
 loss_dc0 += mse_criterion(max_output2_value0.detach(), min_output1_value0)
    
  
    
 loss_seg_dice += dice_loss(output1_soft[:labeled_bs, ...], label_batch[:labeled_bs].unsqueeze(1))
    
 loss_seg_dice += dice_loss(output2_soft[:labeled_bs, ...], label_batch[:labeled_bs].unsqueeze(1))
    
  
    
  
    
 if mean_max_values >= 0.95:
    
      loss_cer += ce_loss(output1, pseudo_output2.long().detach())
    
      loss_cer += ce_loss(output2, pseudo_output1.long().detach())
    
 else:
    
      loss_cer += ce_loss(high_output1, pseudo_high_output2.long().detach())
    
      loss_cer += ce_loss(high_output2, pseudo_high_output1.long().detach())
    
  
    
  
    
 consistency_weight = get_current_consistency_weight(iter_num // 150)
    
 supervised_loss = loss_seg_dice
    
 loss = supervised_loss + (1-consistency_weight) * (1000 * loss_at_kd) + consistency_weight * (1000 * loss_dc0 ) + 0.3 * loss_cer
    
    
    
    
    代码解读

写在最后

在医学图像分析领域中,准确高效的分割技术起到关键作用。然而,在医学图像领域中获取高质量标注样本往往面临巨大挑战。本研究提出的新型速成方案——5分钟掌握半监督医学图像分割技术——作为一个创新性解决方案,在短时间内提供快速实现高质量分割的可能性。该方案不仅显著减少了对大量高质量标注数据的需求,并且利用智能算法的高度高效性和灵活性,在提升分割精度的同时大幅提高了处理速度。

随着半监督学习技术不断取得新的进展

完整复现流程中涉及的项目源代码、详细的数据集以及经过预先训练的模型均可从该文章下方链接附件获取

全部评论 (0)

还没有任何评论哟~