Advertisement

医学图像语义分割

阅读量:

语义分割在生物医学图像分析领域具有重要且广泛的用途:包括X射线设备、磁共振成像(MRI)技术、数字病理学研究以及显微镜观察等前沿技术。该网站https://grand-challenge.org/challenges提供了众多具有挑战性的问题供研究人员探索和解决。

在技术层面来看,在处理语义分割任务时,给定一个RGB图像(即尺寸为N \times M \times 3),我们需要建立其对应的标签图(尺寸为N \times M \times k,其中k代表类别数量)。为了实现这一目标,存在多种解决方案。在这里我想重点介绍两个主流的架构:U-Net和U-Net++。

在这一领域中积累了诸多评价与讨论,在此架构下实现了对问题的有效解决方案。该设计框架由编码器和解码器两部分构成,在此过程中前者通过特征提取完成对图像的理解阶段而后者则利用这些特征进行图像分割。通过特征映射连接(用灰色箭头表示)实现了不同分辨率信息的有效融合,并在此基础上构建出完整的分割体系。这种架构的成功应用已为其奠定了重要地位

在这里插入图片描述

随后我们将采用一个经过训练的编码器。针对图像分类问题,在这里我们旨在构建图像的表征。从而使得不同类别在这一特征空间得以区分。在这里我们几乎可以应用任何卷积神经网络(CNN)作为编码器,在此过程中从编码器中提取特征后,则传递给我们的解码器进行处理。据我所知,Iglovikov & Shvets 使用了VGG11和resnet34这两种模型分别作为Unet解码器来生成更优的表征并提升性能。

在这里插入图片描述

Unet++是最近对Unet体系结构的改进,它有多个跳跃连接。

在这里插入图片描述

研究表明,在经典的深度学习框架中,基于卷积神经网络的模型表现出了显著的优越性。与经典的Unet架构相似,在本研究中我们采用多级编码器(骨干模块)来提取出丰富的表征信息。

在选择编码器方面存在哪些考虑?本文将重点阐述两种模型:Unet及其增强版Unet++。为了便于比较分析,在本研究中我们选择了基于胸部X射线图像数据集进行肺部分割。这是一个典型的二值分割问题:我们对每个像素赋予一个概率值p(x),若该像素属于目标区域,则p(x)=1;否则p(x)=0。让我们先了解数据的基本特征与分布情况吧!

在这里插入图片描述

这些是一些非常大的图像,在分辨率上通常具有2000×2000像素。它们包含相当大的遮罩区域,在视觉效果上来看,并不会影响到肺部区域的识别效果。为了提升模型性能与灵活性,在PyTorch框架下我们借助于segmentation_models_pytorch库支持了15种不同的预训练编码器。通过该库的支持我们构建了一个高效的数据处理流水线用于模型训练,并结合Albumentations(一个强大的图像数据增强工具),这使得数据预处理过程更加灵活多变的同时还能有效提升模型性能。此外我们还设计了一个复杂的特征提取模块以提高模型对复杂场景的理解能力以及预测精度。

构建数据集并进行增强处理。将图像尺寸统一设置为256×256像素,并在训练数据集中实施一系列复杂的增强操作以提升模型鲁棒性。

复制代码
    import albumentations as A
    from torch.utils.data import Dataset, DataLoader
    from collections import OrderedDict
    
    class ChestXRayDataset(Dataset):
    def __init__(
        self,
        images,
        masks,
            transforms):
        self.images = images
        self.masks = masks
        self.transforms = transforms
    
    def __len__(self):
        return(len(self.images))
    
    def __getitem__(self, idx):
        """Will load the mask, get random coordinates around/with the mask,
        load the image by coordinates
        """
        sample_image = imread(self.images[idx])
        if len(sample_image.shape) == 3:
            sample_image = sample_image[..., 0]
        sample_image = np.expand_dims(sample_image, 2) / 255
        sample_mask = imread(self.masks[idx]) / 255
        if len(sample_mask.shape) == 3:
            sample_mask = sample_mask[..., 0]  
        augmented = self.transforms(image=sample_image, mask=sample_mask)
        sample_image = augmented['image']
        sample_mask = augmented['mask']
        sample_image = sample_image.transpose(2, 0, 1)  # channels first
        sample_mask = np.expand_dims(sample_mask, 0)
        data = {'features': torch.from_numpy(sample_image.copy()).float(),
                'mask': torch.from_numpy(sample_mask.copy()).float()}
        return(data)
    
    def get_valid_transforms(crop_size=256):
    return A.Compose(
        [
            A.Resize(crop_size, crop_size),
        ],
        p=1.0)
    
    def light_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf(
            [
                A.Transpose(),
                A.VerticalFlip(),
                A.HorizontalFlip(),
                A.RandomRotate90(),
                A.NoOp()
            ], p=1.0),
    ])
    
    def medium_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf(
            [
                A.Transpose(),
                A.VerticalFlip(),
                A.HorizontalFlip(),
                A.RandomRotate90(),
                A.NoOp()
            ], p=1.0),
        A.OneOf(
            [
                A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
                A.NoOp()
            ], p=1.0),
    ])
    
    
    def heavy_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf(
            [
                A.Transpose(),
                A.VerticalFlip(),
                A.HorizontalFlip(),
                A.RandomRotate90(),
                A.NoOp()
            ], p=1.0),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf(
            [
                A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
                A.NoOp()
            ], p=1.0),
    ])
    
    def get_training_trasnforms(transforms_type):
    if transforms_type == 'light':
        return(light_training_transforms())
    elif transforms_type == 'medium':
        return(medium_training_transforms())
    elif transforms_type == 'heavy':
        return(heavy_training_transforms())
    else:
        raise NotImplementedError("Not implemented transformation configuration")
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读
  1. 建立模型与损失函数体系。本研究中我们采用了带有regnety_004编码器的Unet++架构作为基础模型,并配置了RAdam优化算法与Lookahed辅助技术相结合的方式进行参数更新。为了衡量模型性能与训练效果,在损失函数方面我们综合运用了DICE相似性度量指标与二分类交叉熵(BCE)损失函数,并采用两者的加权和形式作为整体目标函数进行优化。

  2. 构建模型及损失函数方案。具体而言,在本研究中我们选择性地采用了带有regnety_004编码器的Unet++架构作为核心模块,并搭配RAdam优化算法与Lookahed辅助技术实现参数的有效更新机制。在评估模型性能方面,则通过整合DICE相似性度量指标与二分类交叉熵(BCE)损失函数两种指标形成复合型目标函数来进行系统训练。

复制代码
    import torch
    import segmentation_models_pytorch as smp
    import numpy as np
    import matplotlib.pyplot as plt
    from catalyst import dl, metrics, core, contrib, utils
    import torch.nn as nn
    from skimage.io import imread
    import os
    from sklearn.model_selection import train_test_split
    from catalyst.dl import  CriterionCallback, MetricAggregationCallback
    encoder = 'timm-regnety_004'
    model = smp.UnetPlusPlus(encoder, classes=1, in_channels=1)
    #model.cuda()
    learning_rate = 5e-3
    encoder_learning_rate = 5e-3 / 10
    layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}
    model_params = utils.process_model_params(model, layerwise_params=layerwise_params)
    base_optimizer = contrib.nn.RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
    optimizer = contrib.nn.Lookahead(base_optimizer)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=10)
    criterion = {
    "dice": DiceLoss(mode='binary'),
    "bce": nn.BCEWithLogitsLoss()
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读
  1. 定义回调函数并训练!
复制代码
    callbacks = [
    # Each criterion is calculated separately.
    CriterionCallback(
       input_key="mask",
        prefix="loss_dice",
        criterion_key="dice"
    ),
    CriterionCallback(
        input_key="mask",
        prefix="loss_bce",
        criterion_key="bce"
    ),
    
    # And only then we aggregate everything into one loss.
    MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum", 
        metrics={
            "loss_dice": 1.0, 
            "loss_bce": 0.8
        },
    ),
    
    # metrics
    IoUMetricsCallback(
        mode='binary', 
        input_key='mask', 
    )
    
    ]
    
    runner = dl.SupervisedRunner(input_key="features", input_target_key="mask")
    runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=callbacks,
    logdir='../logs/xray_test_log',
    num_epochs=100,
    main_metric="loss",
    minimize_metric=True,
    verbose=True,
    )
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

如果我们将不同类型的编码器应用于Unet及其增强版模型(Unet++)中进行性能测试,则可以观察到每个训练周期所达到的验证效果如何;通过不同编码器对这两个核心模型进行测试后,则能够观察到每个训练周期所达到的验证效果如何;我们能够观察到每个训练模型的验证结果,并据此得出结论;

在这里插入图片描述

值得注意的是,在所有编码器体系中运用Unet++展现出显著的性能优势。然而,在个别情况下两者的性能差异并不显著,并不能断定二者之间存在根本性的差异——建议我们在多轮次实验后进行评估。值得注意的是ResNest-200模型展现出卓越的质量,并且其参数规模依然保持合理水平。此外,在相关基准测试中的表现同样具有竞争力——参考https://paperswithcode.com/task/semantic-segmentation我们可以对比分析采用不同架构(如Unet++与标准Unet)结合ResNest-200模型进行预测的效果。

在这里插入图片描述

基于resnest200e编码器的预测结果表明

在这里插入图片描述

我们可以使用类似的代码在这个数据集上训练Unet++模型,如下所示:

在这里插入图片描述

基于验证集的Unet++得分评估结果显示

在这里插入图片描述

基于ResNeSt-200e与RegNetY_002的预测

毫无疑问,在此数据集上任务难度较大——不仅mask精度不够理想(即mask不够精确),而且个别细胞核被错误地归类至其他类别(即个别核被分配到错误的类别)。然而,在使用基于ResNeSt-200e编码器设计的Unet++架构中仍表现出良好的性能。

并非旨在全面覆盖语义分割的所有方案的一种指导性框架。相反地,则更像是一种概念性的起点思路。选择合适的基线模型有助于构建可靠的基础框架,在后续研究中可在此基础上进行深入探索与改进优化。现有的诸多模型架构在细节上有诸多差异:包括FPN、DeepLabV3、Linknet等模块组件的组合方式以及解码策略的选择等维度均有所区别。其中许多都沿袭了类似于U-Net的设计理念,并在此基础上进行了诸多创新性改进:例如采用双编码器联合设计的MAnet、PraNet等变体结构;以及近年来兴起的U²-net等新型架构设计均值得深入研究与实践探索。

This article explores the optimal strategy for semantic segmentation in biomedical imaging, emphasizing the critical role of deep learning models and data augmentation techniques. Supervised learning frameworks are employed to establish robust mappings between input images and their corresponding semantic annotations. The study highlights the significance of incorporating specialized loss functions tailored to medical imaging datasets, which enhance the accuracy of pixel-level classification tasks. Furthermore, the article delves into optimization strategies aimed at minimizing computational overhead while maintaining high prediction precision. By integrating these advanced methodologies, the research underscores the potential for achieving state-of-the-art performance in analyzing complex biomedical images. This comprehensive evaluation framework provides valuable insights for both researchers and practitioners in advancing medical imaging technologies through innovative approaches.

全部评论 (0)

还没有任何评论哟~