Advertisement

5分钟速成半监督医学图像分割

阅读量:

5分钟速成半监督医学图像分割

本文所述的所有资源均可在传知代码平台上获取。

文章目录

    • 5分钟速成半监督医学图像分割
    • 摘要
      • 一、研究方法概述
      • 研究框架介绍
      • 数据预处理方案
    • 二、模型架构分析
    • 三、技术实现细节
      • 基于PyTorch的实现框架
      • 训练策略探讨
  • 动态一致性阈值的计算方法及其实现细节
    • 分解一致性指标的具体计算公式

    • 阈值设定依据与优化策略

      • 四、复现过程

        • 部署方式

概述

本文旨在介绍一篇发表于MICCAI 2023的医学图像分割领域的文章《Decoupled Consistency for Semi-supervised Medical Image Segmentation》。该研究提出了一种新型的半监督医学图像分割框架。这种新方法通过充分运用预测数据实现了对其的有效分解与重组,在最大化各功能优势的同时提升了整体性能。

一、论文思路

对于半监督的学习场景而言,在医学分割领域中采用伪标签方法存在不足之处:一方面它能够过滤掉低置信度像素点;另一方面却未能充分挖掘高置信度与低置信度样本的价值。因此这两种策略均未能有效利用无标签样本的价值特性。
本文提出了一种全新的解耦一致性半监督医学图像分割框架。
该框架首先通过动态阈值将预测结果分解为一致部分与不一致部分。
对于一致部分,则采用了交叉伪监督的方式进行优化训练。
而对于不一致部分,则进一步将其划分为两类:一类是可能靠近决策边界的少数关键样本;另一类则是更容易集中在高密度区域的辅助样本。
这些关键少数样本将被重新导向辅助样本的方向进行优化训练。
值得注意的是,在这一过程中我们特别强调了方向一致性这一特性。
此外为了最大化利用现有数据资源我们将特征图引入到模型训练流程中并对特征图的一致性进行了损失计算。

二、模型介绍

这篇文章的模型图如下图所示:

在这里插入图片描述

如图所示,在DC-Net架构中包含一个编码器以及两个独立但高度相关的解码器网络。其中A解码器通过双线性插值实现上采样过程而B解码器则采用反卷积操作完成上采样任务。在训练过程中针对带有标签的数据样本我们采用标准的监督学习方法计算其与真实目标之间的交叉熵损失Lseg;而对于网络中的一致区域则引入交叉伪监督损失Lcps来指导其收敛;针对不一致区域则采用方向一致性损失Ldc来进行优化;最后对所有特征层均计算其对应的特征一致性损失Lf以确保模型能够有效捕捉到多尺度的空间关系信息

三、细节分析
动态一致性阈值

研究表明,在训练初期,
以提升无监督数据利用效率的同时促进伪标签多样性,
研究者建议将γ设置为相对较小的值。
随着模型训练进展,
研究者建议维持一个稳定的伪标签比例。
从而得出了一致性阈值的定义。

在这里插入图片描述

其中B被定义为批量大小,在机器学习模型训练过程中逐渐增大的权重系数λ满足λ = t/tmax的形式。通过这一机制,在模型训练初期赋予较大的权重系数以促进更快的学习速率,在后期则逐步降低以避免过拟合现象的发生。为了获取更多的未标注数据样本以提升模型泛化能力,在计算过程中我们需要对pA和pB这两个概率估计结果进行置信度评估,并选取较低的置信度作为一致性的基准标准。在初始化阶段将λt设为1/C(其中C表示类别数量),随后根据实际训练情况动态调整其取值范围以适应不同任务需求。最终的一致性阈值γt被系统性地确定和优化。

在这里插入图片描述
分解一致性

本文将不一致的部分分解为不可靠数据与引导数据两类指标,并分析了这两类指标在不同区域的表现特征。其中不可靠数据可能位于决策边界区域,而引导数据通常会出现在高密度区域。两者的索引信息一致,但其区别在于引导数据相较于不可靠数据表现出更高的信心程度。基于平滑假设,这两部分的预测结果应当保持一致,并且集中在高密度区域。因此,在优化模型时应重点关注决策边界附近的像素点。我们首先提升这些像素置信度水平。
以下是具体的优化步骤:

在这里插入图片描述

其中o代表模型输出变量(output variable),而T取值于区间(0,1),用于调节锐化程度(degree of sharpening)。在实验设置中,默认将温度系数设定为T=0.5(default temperature setting)。通过对比分析SpA与SpB的结果(results of SpA and SpB),我们可以分别提取出高置信度区域(high confidence region)及其对应的hSpA和hSpB(hSpA and hSpB)以及低置信度区域(low confidence region)及其对应的lSpA和lSpB(lSpA and lSB)。为了衡量两区域间的相似性(similarity),我们采用了L2范数作为衡量标准(standard measure)。需要注意的是,在实际计算过程中仅对低置信度样本进行优化处理(optimization processing),而无需对高置信度样本进行梯度反向传播计算(reverse propagation calculation)。基于此原则的方向一致性损失项则可表示为:

在这里插入图片描述

对于一致性部分:采用交叉伪监督的方法对系统的一致性进行优化处理。具体细节如下:

在这里插入图片描述

PLA和PLB分别代表对应的伪标签。
在特征部分这一块内容中,在模型训练过程中实施了特征图的提取与整合操作,并通过进一步挖掘和利用数据来提升模型性能。
在实现这一目标的过程中,在模型训练阶段实施了平均映射操作来降低计算复杂度。

在这里插入图片描述

映射过程如下:

在这里插入图片描述

其中p值大于1,在本研究中fm表示第m层的特征图而fmi则代表fm在通道维度上的第i个切片最终经过映射运算得到的结果记为f¯m。为了确保实验的有效性我们在实际操作中采用了动态调节的方式将参数设置为固定值即p=2在此基础上我们定义了基于特征一致性的损失函数用于评估模型性能。

在这里插入图片描述

其中 N 表示 f¯mi 这一像素层所包含的像素数量;网络层数设为 n;解码器中的第 m 层特征图与编码器中的第 m 层特征图分别具有各自的第 i 个像素值。本文仅采用解码器 B 对应的特征层来进行损失计算。
最后计算得到的整体损失由上述提到的分割 loss 以及另外三个子损失之和构成:

在这里插入图片描述

我们采用 Lseg 作为在小批量带标签数据上应用的一种Dice损失函数。在实验设计中,默认情况下设置了一个与迭代次数相关的预热机制[29]。具体而言,在每一轮迭代中,默认情况下设置参数 β 的值为 \beta = e^{-5(1-\frac{t}{t_{\text{max}}})^2}。其中,在实际运行过程中则根据实验结果动态地调整参数 λ 和 α 的值分别为 \lambda=1-\beta\alpha=0.3

四、复现过程
在这里插入图片描述

这是我们在云端存储的源码根目录。接着该研究团队提供了基于10%标注数据的预训练模型,并将其命名为ACDC数据集。随后我们将该模型导入测试环境进行评估。具体而言,在处理过程中发现作者将预训练好的模型放置于特定的文件夹中。只需将其重新组织至目标路径即可。最后通过执行test_acdc.py脚本即可获取所需结果。

在这里插入图片描述

除了这个指标性能,我们还可以得到预测到的3D医学图像:

在这里插入图片描述

在该场景中我们可以采用专业的3D图像可视化软件ITK-SNAP来进行数据展示,在此过程中能够清晰地观察到各部分的空间关系。

在这里插入图片描述
在这里插入图片描述

这里在运行test_acdc.py文件之前,还需要做好数据集的准备,首先需要获取ACDC和PROMISE12数据集。这里可以我之前的一篇博客半监督的医学分割数据集(LA, Pancreas, ACDC和PROMISE12)分享,这篇博客提供了四个公开的数据集LA, PANCREAS, ACDC以及PROMISE12数据集,都是处理好的数据集,放对地方就可以直接运行代码。
下面我将展示一下核心代码,也添加了一些注释进行讲解说明:

复制代码
    output1_soft = F.softmax(output1, dim=1)
    output2_soft = F.softmax(output2, dim=1)
    output1_soft0 = F.softmax(output1 / 0.5, dim=1)
    output2_soft0 = F.softmax(output2 / 0.5, dim=1)
    # 这里是预测输出的锐化过程
    with torch.no_grad():
    max_values1, _ = torch.max(output1_soft, dim=1)
    max_values2, _ = torch.max(output2_soft, dim=1)
    percent = (iter_num + 1) / max_iterations
    
    cur_threshold1 = (1 - percent) * cur_threshold + percent * max_values1.mean()
    cur_threshold2 = (1 - percent) * cur_threshold + percent * max_values2.mean()
    mean_max_values = min(max_values1.mean(), max_values2.mean())
    
    cur_threshold = min(cur_threshold1, cur_threshold2)
    cur_threshold = torch.clip(cur_threshold, 0.25, 0.95)
    
    mask_high = (output1_soft > cur_threshold) & (output2_soft > cur_threshold)
    mask_non_similarity = (mask_high == False)
    # 这里是动态阈值部分的实现,这里阈值的初始值是0.25,也就是类别的倒数,然后这个值会快速地上升,最大值为0.95. 这里由这个阈值可以得到一致的高阈值区域和不一致区域。
    
    new_output1_soft = torch.mul(mask_non_similarity, output1_soft)
    new_output2_soft = torch.mul(mask_non_similarity, output2_soft)
    high_output1 = torch.mul(mask_high, output1)
    high_output2 = torch.mul(mask_high, output2)
    high_output1_soft = torch.mul(mask_high, output1_soft)
    high_output2_soft = torch.mul(mask_high, output2_soft)
    
    pseudo_output1 = torch.argmax(output1_soft, dim=1)
    pseudo_output2 = torch.argmax(output2_soft, dim=1)
    pseudo_high_output1 = torch.argmax(high_output1_soft, dim=1)
    pseudo_high_output2 = torch.argmax(high_output2_soft, dim=1)
    
    max_output1_indices = new_output1_soft > new_output2_soft  # output1 距离近的像素的位置
    
    max_output1_value0 = torch.mul(max_output1_indices, output1_soft0)
    min_output2_value0 = torch.mul(max_output1_indices, output2_soft0)
    
    max_output2_indices = new_output2_soft > new_output1_soft  # output2 距离远的像素的位置
    
    max_output2_value0 = torch.mul(max_output2_indices, output2_soft0)
    min_output1_value0 = torch.mul(max_output2_indices, output1_soft0)
    # 上面这段代码就是利用一致性区域和非一致性区域的处理过程
    
    loss_dc0 = 0
    loss_cer = 0
    loss_at_kd = criterion_att(encoder_features, decoder_features2)
    
    
    loss_dc0 += mse_criterion(max_output1_value0.detach(), min_output2_value0)
    loss_dc0 += mse_criterion(max_output2_value0.detach(), min_output1_value0)
    
    loss_seg_dice += dice_loss(output1_soft[:labeled_bs, ...], label_batch[:labeled_bs].unsqueeze(1))
    loss_seg_dice += dice_loss(output2_soft[:labeled_bs, ...], label_batch[:labeled_bs].unsqueeze(1))
    
    
    if mean_max_values >= 0.95:
     loss_cer += ce_loss(output1, pseudo_output2.long().detach())
     loss_cer += ce_loss(output2, pseudo_output1.long().detach())
    else:
     loss_cer += ce_loss(high_output1, pseudo_high_output2.long().detach())
     loss_cer += ce_loss(high_output2, pseudo_high_output1.long().detach())
    
    
    consistency_weight = get_current_consistency_weight(iter_num // 150)
    supervised_loss = loss_seg_dice
    loss = supervised_loss + (1-consistency_weight) * (1000 * loss_at_kd) + consistency_weight * (1000 * loss_dc0 ) + 0.3 * loss_cer
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

部署方式

这里只需按照要求安装指定的PyTorch版本(1.8.1)、CUDA版本(10.1)以及Python 3.6.13。
除了这些硬性要求之外,
还需额外安装一些其他Python库,
当程序运行时会自动检测缺少哪些模块,
然后可以通过pip命令进行安装。

访问文章中的代码资源库,请通过下方链接进行[附件下载]操作。

全部评论 (0)

还没有任何评论哟~