Advertisement

【深度学习】U-Net 网络分割多分类医学图像解析

阅读量:

【深度学习】U-Net 网络分割多分类医学图像解析

复制代码
    文章目录
    【深度学习】U-Net 网络分割多分类医学图像解析
    1 U-Net 多分类
    2 Keras 利用Unet进行多类分割
    	2.1 代码实现
    	2.2 结果
    3 多分类标签验证
    4 数据变换
    	4.1 概述
    	4.2 图像数据变化代码(为了满足多分类需求)
    	4.3 随机亮度(为了数据增强)
    5 Unet训练自己的数据
    
    
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述

1 U-Net 多分类

Unet图像分割技术在大多数开源项目的实现中主要聚焦于二分类问题,在理论上而言这一方法同样适用于多分类场景

在这里插入图片描述

其前半部分主要负责特征提取功能,在其后半部分则实施上采样过程。在学术界常将其称为编码器-解码器架构。值得注意的是, 该网络的整体架构形似大写字母'U',因此得名U-Net。与大多数其他主流分割网络相比,U-Net采用了完全不同的特征融合机制:其主要特征融合手段为拼接操作,而另一类常见的方法,如FCN,则采用逐点相加法进行特征融合,这种方法并不产生更为丰富的特征层。

所以语义分割网络在特征融合时有两种办法:

在TensorFlow框架中, FCN架构通过调用tf.add()函数来实现对应点的相加运算. 相较于这一机制,U-Net架构则采用通道维度上的拼接融合机制,在TensorFlow中这可以通过调用tf.concat()函数来实现. 相较于上述创新性的特征融合方法,在以下方面U-Net架构展现出显著的优势: 1. 通过设计5个池化层结构,U-Net能够有效提取图像的不同尺度特征. 2. 在上采样过程中,系统会整合特征提取模块输出的信息. 这种连接关系贯穿整个网络结构的各个阶段,你可以观察到图中所示的四个整合过程,而传统的FCN网络仅在最后一层完成这种整合.

2 Keras 利用Unet进行多类分割

2.1 代码实现

采用了基于CamVid平台的数据集合进行研究,在实验中设定统一的图片大小为360×480像素。该研究包含了367张训练样本、另有101张用于校准的数据样本以及233张用于测试评估的数据样本。总计参与实验的图片数量为701张。所使用的深度学习框架包括TensorFlow和Keras技术。更多详细信息可访问以下链接:https://github.com/preddy5/segnet/tree/master/CamVid

我们从main2.py开始查看整个项目的实现步骤。

复制代码
    # -*- coding:utf-8 -*-
    # Author : Ray
    # Data : 2019/7/25 2:15 PM
    
    
    from datapre2 import *
    from model import *
    import warnings
    
    warnings.filterwarnings('ignore')	#忽略警告信息
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'	#设置GPU
    #测试集233张,训练集367张,校准集101张,总共233+367+101=701张,图像大小为360*480,语义类别为13类
    
    aug_args = dict( #设置ImageDataGenerator参数
    rotation_range = 0.2,
    width_shift_range = 0.05,
    height_shift_range = 0.05,
    shear_range = 0.05,
    zoom_range = 0.05,
    horizontal_flip = True,
    vertical_flip = True,
    fill_mode = 'nearest'
    )
    
    train_gene = trainGenerator(batch_size=2,aug_dict=aug_args,train_path='CamVid/',
                        image_folder='train',label_folder='trainannot',
                        image_color_mode='rgb',label_color_mode='rgb',
                        image_save_prefix='image',label_save_prefix='label',
                        flag_multi_class=True,save_to_dir=None
                        )
    val_gene = valGenerator(batch_size=2,aug_dict=aug_args,val_path='CamVid/',
                       image_folder='val',label_folder='valannot',
                       image_color_mode='rgb',label_color_mode='rgb',
                       image_save_prefix='image',label_save_prefix='label',
                       flag_multi_class=True,save_to_dir=None
                       )
    tensorboard = TensorBoard(log_dir='./log')
    model = unet(num_class=13)
    model_checkpoint = ModelCheckpoint('camvid.hdf5',monitor='val_loss',verbose=1,save_best_only=True)
    
    history = model.fit_generator(train_gene,
                              steps_per_epoch=100,
                              epochs=20,
                              verbose=1,
                              callbacks=[model_checkpoint,tensorboard],
                              validation_data=val_gene,
                              validation_steps=50   #validation/batchsize
                              )
    # model.load_weights('camvid.hdf5')
    test_gene = testGenerator(test_path='CamVid/test')
    results = model.predict_generator(test_gene,233,verbose=1)
    saveResult('CamVid/testpred/',results)
    
    
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

从上述代码中可以看到,整个项目的思路比较明确,概括而言就是:

复制代码
    数据准备
    模型训练
    预测结果
    
    
      
      
      
    
    代码解读

先利用trainGenerator、valGenerator准备训练数据,利用tensorboard = TensorBoard(log_dir=’./log’)保存日志信息,能够查看训练过程的loss、acc变化。调用ModelCheckpoint保存模型。模型训练时采用fit_generator模式输入。最后在调用predict_generatorj进行预测时先准备好测试集数据,然后调用saveResult保存预测结果。
注意到main.py中在model.fit_generator()中设置callbacks=[model_checkpoint,tensorboard],我们可以在训练结束之后查阅loss和acc的训练过程。
打开terminal,进入项目所在路径,输入tensorboard --logdir ./log(是你代码中所写的),然后出现以下界面就说明你的路径正确了
在浏览器键入http://http:localhost:6006即可查阅loss及acc的训练过程。

2.2 结果

在这里插入图片描述

3 多分类标签验证

该颜色值通常用于将RGB值转换为十六进制表示...
b = a.copy()
在处理一维数组或列表时,
通过调用np.unique函数对数组A进行处理,
它会去除其中重复的元素并按元素由大到小返回一个新的无元素重复的元组或者列表。

复制代码
    import numpy as np
    A = [1, 2, 2, 5,3, 4, 3]
    a = np.unique(A)
    B= (1, 2, 2,5, 3, 4, 3)
    b= np.unique(B)
    C= ['fgfh','asd','fgfh','asdfds','wrh']
    c= np.unique(C)
    print(a)
    print(b)
    print(c)
    #   输出为 [1 2 3 4 5]
    # [1 2 3 4 5]
    # ['asd' 'asdfds' 'fgfh' 'wrh']
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

将变量c和s分别赋值为数组b的唯一元素及其对应的索引位置。
返回参数True指示函数不仅返回唯一的数值结果(存储于变量c),还提供这些数值在原始数据中的具体位置信息(存储于变量s)。
将变量a和s分别赋值为数组A的唯一元素及其对应的索引位置。
打印数组a的结果。
打印数组s的结果。
运行程序后得到以下输出结果:
[1 2 3 4 5]
[0 1 4 5 3]

a, s,p = np.unique(A, return_index=True, return_inverse=True)
其中参数return_inverse设为True时,默认返回旧列表元素在新列表中的索引位置,并将这些索引存储在一个列表中返回。

数组a、s和p分别通过np.unique函数实现对输入数组A的一次去重运算并返回索引信息
其中返回索引并设为True
随后分别打印出数组a的值以及对应的索引信息
运行上述代码后得到的结果如下:
输出结果为数组[1,2,3,4,5]
输出索引结果为数组[0,1,4,5,3]

复制代码
 对于一维列表或数组A: 
    
    import numpy as np
    A = [1, 2, 2, 3, 4, 3]
    a = np.unique(A)
    print a            # 输出为 [1 2 3 4]
    a, b, c = np.unique(A, return_index=True, return_inverse=True)
    print a, b, c      # 输出为 [1 2 3 4], [0 1 3 4], [0 1 1 2 3 2]
    2. 对于二维数组(“darray数字类型”): 
    
    A = [[1, 2], [3, 4], [5, 6], [1, 2]]
    A = np.array(A)   #列表类型需转为数组类型
    a, b, c = np.unique(A.view(A.dtype.descr * A.shape[1]), return_index=True, return_inverse=True)
    print a, b, c     #输出为 [(1, 2) (3, 4) (5, 6)], [0 1 2], [0 1 2 0]
    可以看出, Python中unique函数与Matlab完全一致. 
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

数据预处理最重要的一步就是要对gt进行one-hot编码,

4 数据变换

4.1 概述

图像方面的数据增强可以从下面几个角度来看.

仿射变换(Affine transformation: 随机裁剪、随机翻转、随机旋转、随机缩放、随机错切、随机平移等技术);
彩色失真(Color distortion: 随机Gamma校正、随机亮度调整、随机色调变化、随机对比度调整及高斯噪声干扰等技术);
信息丢弃(Information erasure: Gridmask、Cutout、Random Erasing及Hide-and-seek等技术);
多图融合(Image fusion: Mixup、Cutmix及FMix等混合增强技术);
另类(Alternative methods: Augmix等先进技术)

4.2 图像数据变化代码(为了满足多分类需求)

该代码的主要目标是调整不符合网络输出标准的图像内容,并模拟一个规范化的操作流程。

复制代码
    
    
    
    
    class CenterCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
    
    def __call__(self, img, mask):
        assert img.size == mask.size
        w, h = img.size
        th, tw = self.size
        x1 = int(math.ceil((w - tw) / 2.))
        y1 = int(math.ceil((h - th) / 2.))
        return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))
    
    
    class SingleCenterCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
    
    def __call__(self, img):
        w, h = img.size
        th, tw = self.size
        x1 = int(math.ceil((w - tw) / 2.))
        y1 = int(math.ceil((h - th) / 2.))
        return img.crop((x1, y1, x1 + tw, y1 + th))
    
    
    class CenterCrop_npy(object):
    def __init__(self, size):
        self.size = size
    
    def __call__(self, img, mask):
        assert img.shape == mask.shape
        if (self.size <= img.shape[1]) and (self.size <= img.shape[0]):
            x = math.ceil((img.shape[1] - self.size) / 2.)
            y = math.ceil((img.shape[0] - self.size) / 2.)
    
            if len(mask.shape) == 3:
                return img[y:y + self.size, x:x + self.size, :], mask[y:y + self.size, x:x + self.size, :]
            else:
                return img[y:y + self.size, x:x + self.size, :], mask[y:y + self.size, x:x + self.size]
        else:
            raise Exception('Crop shape (%d, %d) exceeds image dimensions (%d, %d)!' % (
                self.size, self.size, img.shape[0], img.shape[1]))
    
    
    
    
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

4.3 随机亮度(为了数据增强)

随机调整亮度基于像素值的变化来实现亮度变化的效果显著提升,在此我们提供相应的代码实现以达到图像数据增强的目的

defrandom_brightness():

results = np.copy(batch_xs)

fori inrange( 9):

image = sess.run(tf.image.random_brightness(batch_xs[i].reshape( 28, 28), 0.9))

results[i, :, :, :] = image.reshape( -1, 28, 28)

show_images(results, “random_brightness”)

在这里插入图片描述

5 Unet训练自己的数据

整个模型训练过程如下图所示:

在这里插入图片描述

Unet网络调参优化研究
针对Unet网络的参数调整过程主要包括以下几点:(1)引入批归一化层以加速训练并提升模型稳定性;(2)将最后一层激活函数更换为ReLU以促进非线性特征提取能力的增强;(3)采用均方误差损失函数替代交叉熵损失以优化回归任务的表现。在多分类场景中,默认的全连接层末尾通常采用软最大(softmax)函数。然而,在实际实验中发现该配置下分割效果并不理想。于是尝试采用此前在图对图项目中使用的激活函数进行替代改进。这种做法纯属经验层面的尝试,并无相应的理论基础支撑。同样地,在交叉熵损失得到广泛应用的同时也遇到了性能瓶颈问题。因此最终选择将损失函数更换为均方根误差作为替代方案进行测试评估。

全部评论 (0)

还没有任何评论哟~