Advertisement

半监督医学图像分割(二):Contour-aware consistency for semi-supervised medical image segmentation

阅读量:

BSPC-2024 Contour-aware consistency for semi-supervised medical image segmentation

  • 研究背景及动机#

    • 背景:
    • 动机:
  • 主要贡献

  • 方法

    • 1 轮廓增强解码器及损失计算
    • 2 自对比策略
    • 3 损失函数
    • 数据集
    • 结果可视化
    • 相关应用
  • 总结

pub:2024 Biomedical Signal Processing and Control
[ paper] [

复制代码
(https://github.com/SmileJET/CAC4SSL)]

## 研究背景及动机#

### 背景:

1 对于医学专家而言,绘制可靠标注的工作繁琐且耗时,而且由于专家的主观性,人工标注也可能造成一定的分割差异  
2 医疗机构中通常存在大量的未标记数据,充分发挥未标记数据的作用

### 动机:

1医学图像具有模糊的边缘,大多数方法没有对边缘明确建模,对于边缘区域的预测是不可靠的。  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/U3QEhKNalVJxb0O4qmfYzy7ovBT5.png)图1是MC-Net+在5%、10%和20%标签的分割结果,可以看到:该模型在边缘区域容易产生误差,数据越少,边缘的错误率则越高。  
2 现有方法很少同时利用两种信息

## 主要贡献

1 提出了**轮廓感知一致性框架** :由轮廓增强解码器和自对比策略组成。  
2 轮廓增强解码器,专门用于训练阶段。解码器根据**预测** 的**概率图计算轮廓** 以增强相应的特征  
3 引入一种**自对比策略** 来优化分割结果中的不确定区域。

## 方法

### 1 轮廓增强解码器及损失计算
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/rOi1dtfTan6yE0XJxMqVok9W5pUI.png)整体模型是VNet,编码器和跳跃连接层不变,在解码器部分添加轮廓增强解码器作为辅助解码器用以细化目标预测的边界,使用分割头𝑆来生成预测概率图。然后,通过模拟形态学操作中的腐蚀和膨胀操作,从概率图得到轮廓图,公式如下:  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/7TuY6fJvdRD59nQx4rshKFwqym8O.png)其中𝐷为类别数,𝑓𝑐为图像轮廓,𝑅𝑒𝐿𝑈为激活函数  
最后使用残差结构得到特征增强:  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/yRWDJqY8QTVp4iweC2mjkZSsNMX7.png)M为合并操作。  
代码解析(完整代码请参考github):
复制代码
		self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
    self.out_conv_8 = nn.Conv3d(n_filters*2, n_classes, 1, padding=0)
    self.out_conv_7 = nn.Conv3d(n_filters*4, n_classes, 1, padding=0)
    self.out_conv_6 = nn.Conv3d(n_filters*8, n_classes, 1, padding=0)
    self.out_conv_5 = nn.Conv3d(n_filters*16, n_classes, 1, padding=0)

    self.dropout = nn.Dropout3d(p=0.5, inplace=False)

    self.proj = nn.Sequential(nn.Conv3d(n_filters, n_filters, kernel_size=3, padding=1),
                              nn.PReLU(),
                              nn.Conv3d(n_filters, emb_num, kernel_size=3, padding=1)
                              )
                              
		out_seg = self.out_conv(x9)
    out_8 = self.out_conv_8(x8)
    out_7 = self.out_conv_7(x7)
    out_6 = self.out_conv_6(x6)
    out_5 = self.out_conv_5(x5)

    proj_out = self.proj(x9)
    
    return out_seg, out_8, out_7, out_6, out_5, proj_out


python
复制代码
轮廓增强解码器在VNet的解码器使用卷积核大小为1的普通卷积得到每个解码器的预测概率,proj_out见自对比策略部分。  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/QC7SorMimeIzG4daTtlB3W9E2j8q.png)对比损失的图示:(1)主解码器生成的预测。(2)辅助解码器生成的预测。(3)预测的不确定区域。

(4)对比损失的不确定区域样本。左边红点表示𝑘th被选中的点,右边红点表示作为正样本的另一个解码器中的对应点,绿色的点代表负样本。

如图4所示,对比学习主要改进了预测的不确定区域,即两个解码器的预测不一致的区域:  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/ZqvBL6dGfl47hYPyuFXkb9oej2D8.png)

### 2 自对比策略

在训练阶段,两个解码器预测结果不一致的区域被认为是不确定区域。由于在未标记数据的类别预测结果是不准确的,因此我们将来自不同解码器的相同像素视为正样本,而将所有其他像素视为负样本。该策略在不考虑类别的情况下,有效地减轻了预测噪声的影响。该策略旨在减少解码器对同一像素的预测之间的差异,并增强其与其他像素的区分。此外,自对比只需要一个原始图像,不需要额外的增强或其他负样本图像。  
具体公式如下:  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/1x0emgfJpB7Pb24IHTqjnsMODyXQ.png)  
在解码器最后一层添加了一个project head,用于对比学习。包括两个卷积层和一个𝑃 𝑅𝑒𝐿 _U_ 激活函数,将高维特征进行向量归一化。  
对于正、负样本的选择:随机选择不一致区域𝐾中的𝑄点作为锚点,对于锚点中给定的像素𝑘,其正样本𝑘+是另一个编码器对应位置上的𝑘th像素。负样本𝑘−是锚点中𝑘th像素以外的点。将主解码器的对比损失定义为:  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/ScDxAfC4kM08WrLY9ZdEwuJ5PRaF.png)  
Γ是温度参数,𝑠𝑖 _m_ 是余弦相似度,公式为:  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/C6FI7hHyrNA2xQLMbptDkRBPiZVO.png)  
总的对比损失由主编码器损失和辅助编码器损失两部分构成:  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/laYMn3deQhOAZocwRT9qS7EVBFNJ.png)
复制代码
class Contrast(nn.Module):

def __init__(self, temperature: float = 0.07, sample_num: int = 50, bidirectional: bool = True):
    super(Contrast, self).__init__()
    self.tau = temperature

def forward(self, proj_list, idx, pseudo_label, mask, sample_num=5):
    batch_size = mask.shape[0]
    loss = 0
    
    curr_proj = None
    pos_proj = []
    for i in range(len(proj_list)):
        try:
            proj = proj_list[i].permute(0, 2, 3, 1)
        except:
            proj = proj_list[i].permute(0, 2, 3, 4, 1)

        proj = proj.contiguous().view(proj.shape[0], -1, proj.shape[-1])
        if i == idx:
            curr_proj = F.normalize(proj, dim=-1)
        else:
            pos_proj.append(F.normalize(proj.unsqueeze(1), dim=-1))
    pos_proj = torch.cat(pos_proj, dim=1)

    mask = mask.contiguous().view(batch_size, -1).long()
    fn_mask = 1-mask
    
    for b_idx in range(batch_size):
        mask_ = mask[b_idx]
        fn_mask_ = fn_mask[b_idx]
        c_proj = curr_proj[b_idx]
        p_proj = pos_proj[b_idx]

        hard_indices = fn_mask_.nonzero()

        num_hard = hard_indices.shape[0]

        hard_sample_num = min(sample_num, num_hard)

        hard_perm = torch.randperm(num_hard)
        hard_indices = hard_indices[hard_perm[:hard_sample_num]]
        indices = hard_indices

        c_proj_selected = c_proj[indices].squeeze(dim=1)
        p_proj_selected = p_proj[:, indices].squeeze(dim=2)


        pos_loss_item = F.cosine_similarity(c_proj_selected, p_proj_selected, dim=-1).sum(0)
        pos_loss_item = torch.exp(pos_loss_item / self.tau)
        matrix = F.cosine_similarity(c_proj_selected.unsqueeze(dim=1), c_proj_selected.unsqueeze(dim=0), dim=-1)
        matrix = torch.exp(matrix / self.tau)
        neg_loss_item = matrix.sum(dim=0) - torch.diagonal(matrix)

        loss += -torch.log(pos_loss_item / (pos_loss_item + neg_loss_item + 1e-8)).mean()

    return loss / batch_size


python
复制代码
### 3 损失函数

采用深度监督机制:解码器的不同层次添加分割头,以产生不同尺度的预测。  
对于标记的数据,我们利用Dice Loss进行监督,其定义为:  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/egiafMluQ5p4NnG9DUcsRPbHwxEF.png)  
𝑝𝑟i为主编码器,aux为辅助编码器,𝐷表示解码器的𝑑th级,$\varepsilon$为平滑参数

使用两个解码器来输出预测结果${\widehat y}_{pri},{\widehat y}_{aux}$,然后对结果求平均值,得到伪标签  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/vUabeFq2TBlQVy5wH3zSIp1jP8WK.png)  
为了监督带有伪标签的训练,采用Dice loss和Cross Entropy loss:![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/oBQa0hSKN7I5MVnuLxlc1R2HJGY8.png)  
总损失为:  
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/xPvH3j12k8BuadmftS7silUAehzG.png)
复制代码
			outputs = model(volume_batch)
        proj_list = outputs[-2:]
        outputs = outputs[:-2]

        num_outputs = len(outputs) // single_out_num

        y_ori = []
        y_ori_softmax = []
        y_pseudo_label = []
        y_mask = []

        for idx in range(single_out_num):
            y_ori.append(torch.zeros((num_outputs,)+outputs[idx].shape, device=volume_batch.device))
            y_ori_softmax.append(torch.zeros((num_outputs,)+outputs[idx].shape, device=volume_batch.device))


        loss_seg = 0
        loss_seg_dice = 0 

        for idx in range(num_outputs):
            for sub_idx in range(single_out_num):
                true_idx = idx * single_out_num + sub_idx
                y = outputs[true_idx][:labeled_bs,...]
                y_prob = F.softmax(y, dim=1)
                loss_seg_dice += dice_loss(y_prob[:,1,...], label_batches[sub_idx].squeeze(1)[:labeled_bs,...] == 1)

                y_all = outputs[true_idx]
                y_ori[sub_idx][idx] = y_all
                y_prob_all = F.softmax(y_all, dim=1)
                y_ori_softmax[sub_idx][idx] = y_prob_all
                
        for idx in range(single_out_num):
            out_0 = y_ori[idx][0].argmax(dim=1)
            out_1 = y_ori[idx][1].argmax(dim=1)
            mask = out_0==out_1
            y_mask.append(mask)
            y_pseudo_label.append(y_ori_softmax[idx].mean(dim=0).argmax(dim=1))

        loss_consist = 0
        for i in range(single_out_num):
            for j in range(num_outputs):
                loss_consist += F.cross_entropy(y_ori[i][j], y_pseudo_label[i].long())
                loss_consist += dice_loss(y_ori_softmax[i][j][:, 1, ...], y_pseudo_label[i]==1)

        loss_contrast = 0
        for i in range(num_outputs):
            loss_contrast += contrast_loss_fn(proj_list, i, y_pseudo_label[0], mask=y_mask[proj_mask_idx], sample_num=sample_num)


python
复制代码
### 数据集
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/6txcKqeynI2VauZwsEN5pGvmQSdO.png)

### 结果可视化
![在这里插入图片描述](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/zIHB29sZtyfVa6vnC3G0j5DNTEQu.png)

### 相关应用

[【语义分割】——STDC-Seg快又强 + 细节边缘的监督]()

## 总结

本文提出了一种基于轮廓感知的半监督医学图像分割一致性框架,有效地利用了有限的标记数据和大量的未标记数据。通过结合轮廓增强辅助解码器和自对比策略缓解了边缘区域预测不准确的问题。

全部评论 (0)

还没有任何评论哟~