
Medical Image Segmentation: Deep-Learning-Based Liver Tumor Segmentation in Practice (Part 1)


As mentioned in a previous blog post, my knowledge at this stage is still limited (currently mainly Python programming basics and deep learning framework usage), so this article likely contains many shortcomings; corrections from readers are welcome.

Main text:

Contents

  • Goal
  • Original data
  • Overall approach
    • 1. Data extraction
    • 2. Data preprocessing
    • 3. Data augmentation
    • 4. Data storage
    • Storing the test data
    • 5. Building the network
    • 6. Training and testing

Goal

Segment the liver region from abdominal CT images.

Original Data

The data come from the 3Dircadb dataset and consist of abdominal CT images. Each patient is stored in a separate folder; the 3Dircadb1 release contains 20 patients. Each patient folder holds a PATIENT_DICOM directory with the full stack of body slices and a MASKS_DICOM directory with segmentation masks for the individual anatomical structures, such as the liver and liver tumors, as shown below.

(figure: 3Dircadb folder structure)

Viewed in a DICOM viewer, PATIENT_DICOM looks like this; this patient's series contains 129 slices (one .dcm file per slice).
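To take a quick look at a slice yourself, pydicom plus matplotlib is enough. A short sketch (the path and the image_0 file name are assumptions about a local copy of the dataset):

    # Hedged sketch: display one DICOM slice.
    import os
    import pydicom
    import matplotlib.pyplot as plt
    
    path = os.path.expanduser('~/3Dircadb/3Dircadb1.1/PATIENT_DICOM/image_0')
    ds = pydicom.dcmread(path)
    plt.imshow(ds.pixel_array, cmap='gray')
    plt.show()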

(figure: PATIENT_DICOM slices in a DICOM viewer)

The liver masks under MASKS_DICOM look like this:

(figure: liver segmentation masks)

Overall Approach

1. Data Extraction

Data reading: load the raw .dcm files into the NumPy array format we need.

    #part1
    import numpy as np
    import pydicom
    import os
    import matplotlib.pyplot as plt
    import cv2
    from keras.preprocessing.image import ImageDataGenerator
    from HDF5DatasetWriter import HDF5DatasetWriter
    from HDF5DatasetGenerator import HDF5DatasetGenerator
    
    for i in range(1, 18):  # the first 17 patients form the training set
        full_images = []  # holds this patient's target slices (reset per patient)
        full_livers = []  # same, for the liver masks
        # note: path separators differ between operating systems
        label_path = '~/3Dircadb/3Dircadb1.%d/MASKS_DICOM/liver' % i
        data_path = '~/3Dircadb/3Dircadb1.%d/PATIENT_DICOM' % i
        liver_slices = [pydicom.dcmread(label_path + '/' + s) for s in os.listdir(label_path)]
        # sorting is required: even if the folder looks ordered,
        # os.listdir returns the files in arbitrary order
        liver_slices.sort(key=lambda x: int(x.InstanceNumber))
        # s.pixel_array extracts the pixel values from a DICOM slice
        livers = np.stack([s.pixel_array for s in liver_slices])
        image_slices = [pydicom.dcmread(data_path + '/' + s) for s in os.listdir(data_path)]
        image_slices.sort(key=lambda x: int(x.InstanceNumber))
        
        """ preprocessing omitted here; see part2 """
        
        full_images.append(images)
        full_livers.append(livers)
        
        full_images = np.vstack(full_images)
        full_images = np.expand_dims(full_images, axis=-1)
        full_livers = np.vstack(full_livers)
        full_livers = np.expand_dims(full_livers, axis=-1)
    
    

2. Data Preprocessing

1. Convert the raw CT values to standard HU (Hounsfield units). The reasons are not elaborated here; see the get_pixels_hu function referenced above (a minimal sketch is also given after this list).
2. Windowing (CT contrast enhancement); see "Medical Image Preprocessing (3): Windowing (CT contrast enhancement)".
3. Histogram equalization: apply CLAHE to the slices to improve contrast.

    def clahe_equalized(imgs, start, end):
        assert (len(imgs.shape) == 3)  # expects a 3D array (slices, H, W)
        # create a CLAHE object (arguments are optional)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        imgs_equalized = np.empty(imgs.shape)
        for i in range(start, end + 1):
            imgs_equalized[i, :, :] = clahe.apply(np.array(imgs[i, :, :], dtype=np.uint8))
        return imgs_equalized

4. Normalization.
5. Keep only the abdominal slices that actually contain liver; discard those without any.
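The helper functions get_pixels_hu (step 1) and transform_ctdata (step 2) are called in part2 below but never defined in this post. Here is a minimal sketch, assuming standard CT DICOM slices with RescaleSlope/RescaleIntercept attributes and a (width, center) argument order matching the call transform_ctdata(images, 500, 150); the original implementations may differ:

    # Hedged sketch of the two undefined helpers.
    def get_pixels_hu(slices):
        # Stack the raw pixel arrays and apply the DICOM rescale
        # slope/intercept to obtain Hounsfield units.
        image = np.stack([s.pixel_array for s in slices]).astype(np.int16)
        for k, s in enumerate(slices):
            intercept = s.RescaleIntercept
            slope = s.RescaleSlope
            if slope != 1:
                image[k] = (slope * image[k].astype(np.float64)).astype(np.int16)
            image[k] += np.int16(intercept)
        return image
    
    def transform_ctdata(image, windowWidth, windowCenter, normal=False):
        # Window the HU values to [center - width/2, center + width/2]
        # and rescale that range to 0-255 (uint8).
        minWindow = float(windowCenter) - 0.5 * float(windowWidth)
        newimg = (image - minWindow) / float(windowWidth)
        newimg[newimg < 0] = 0
        newimg[newimg > 1] = 1
        if not normal:
            newimg = (newimg * 255).astype('uint8')
        return newimg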

    #part2
    # this continues inside the part1 loop
        images = get_pixels_hu(image_slices)
        
        images = transform_ctdata(images, 500, 150)
        
        start, end = getRangImageDepth(livers)
        images = clahe_equalized(images, start, end)
        
        images /= 255.
        # keep only the abdominal slices that contain liver; drop the rest
        total = (end - 4) - (start + 4) + 1
        print("%d person, total slices %d" % (i, total))
        # the liver region is tiny in the first and last few slices, so drop them
        images = images[start + 5:end - 5]
        print("%d person, images.shape:(%d,)" % (i, images.shape[0]))
        
        livers[livers > 0] = 1  # binarize the liver mask
        
        livers = livers[start + 5:end - 5]
    
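getRangImageDepth is likewise used without being defined in this post. A plausible sketch, under the assumption that it returns the first and last slice indices whose liver mask is non-empty:

    # Hedged sketch: first/last slice indices containing any liver pixel.
    def getRangImageDepth(image):
        z = np.any(image, axis=(1, 2))   # True for slices with foreground
        nonzero = np.where(z)[0]
        startposition, endposition = nonzero[0], nonzero[-1]
        return startposition, endposition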

3. Data Augmentation

Keras's data augmentation API can also serve segmentation. For classification, augmentation only transforms the images and leaves the labels untouched; for segmentation, the image and its mask must receive exactly the same transform. The implementation below achieves this by giving both generators the same random seed.

    # can be set up before part1 (i.e., outside the loop)
    seed = 1
    data_gen_args = dict(rotation_range=3,
                         width_shift_range=0.01,
                         height_shift_range=0.01,
                         shear_range=0.01,
                         zoom_range=0.01,
                         fill_mode='nearest')
    
    image_datagen = ImageDataGenerator(**data_gen_args)
    mask_datagen = ImageDataGenerator(**data_gen_args)
    
    
    #part3, continues from part2 (inside the loop)
    image_datagen.fit(full_images, augment=True, seed=seed)
    mask_datagen.fit(full_livers, augment=True, seed=seed)
    image_generator = image_datagen.flow(full_images, seed=seed)
    mask_generator = mask_datagen.flow(full_livers, seed=seed)
    
    train_generator = zip(image_generator, mask_generator)
    x = []
    y = []
    i = 0  # shadows the patient loop variable; range() reassigns it next iteration, so this is harmless here
    for x_batch, y_batch in train_generator:
        i += 1
        x.append(x_batch)
        y.append(y_batch)
        if i >= 2:  # I don't need that much extra data
            break
    x = np.vstack(x)
    y = np.vstack(y)
    
    

4. Data Storage

When working with a large dataset, it is common to convert the raw data into NumPy arrays or HDF5 files. This has two advantages: feeding data into the network during training becomes much more efficient (especially for large datasets), and the files are easier to share or upload to a server.

The experiment uses two helper classes: one writes HDF5 files and one reads them back.

    class HDF5DatasetWriter:
        """Writes data to an HDF5 file."""
    
    class HDF5DatasetGenerator:
        """Reads batches of data from an HDF5 file."""

Their implementation is just standard HDF5 file reading and writing in Python, and is omitted here for brevity.

Working with HDF5 files requires the h5py package: import h5py.
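For completeness, here is a minimal sketch of what the two classes might look like, assuming only the interface used in this post (an add() method, a close() method that returns the number of stored images, and a generator() that yields batches); the author's actual implementation may differ:

    import h5py
    
    class HDF5DatasetWriter:
        """Writes image/mask pairs incrementally to an HDF5 file (sketch)."""
        def __init__(self, image_dims, mask_dims, outputPath):
            self.db = h5py.File(outputPath, "w")
            self.images = self.db.create_dataset("images", image_dims, dtype="float")
            self.masks = self.db.create_dataset("masks", mask_dims, dtype="int")
            self.idx = 0
        
        def add(self, images, masks):
            end = self.idx + len(images)
            self.images[self.idx:end] = images
            self.masks[self.idx:end] = masks
            self.idx = end
        
        def close(self):
            total = self.idx
            self.db.close()
            return total
    
    class HDF5DatasetGenerator:
        """Yields (images, masks) batches from an HDF5 file, forever (sketch)."""
        def __init__(self, dbPath, batchSize):
            self.db = h5py.File(dbPath, "r")
            self.numImages = self.db["images"].shape[0]
            self.batchSize = batchSize
        
        def generator(self):
            while True:
                for k in range(0, self.numImages, self.batchSize):
                    yield (self.db["images"][k:k + self.batchSize],
                           self.db["masks"][k:k + self.batchSize])
        
        def close(self):
            self.db.close()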

    # can be set up before part1 (i.e., outside the loop)
    # The total count has to be known in advance, which feels clumsy: I first
    # ran the earlier code to count the images and then hard-coded the number
    # here. Clearly not a good solution, but I don't know how to improve it.
    dataset = HDF5DatasetWriter(image_dims=(2782, 512, 512, 1),
                                mask_dims=(2782, 512, 512, 1),
                                outputPath="data_train/train_liver.h5")
    #part4, continues from part3 (inside the loop)
        dataset.add(full_images, full_livers)
        dataset.add(x, y)
    # end of the loop over patients
    dataset.close()

Storing the Test Data

The test data go through largely the same pipeline as the training data, except that no augmentation is applied.

    full_images2 = []
    full_livers2 = []
    for i in range(18, 21):  # the last 3 patients form the test set
        label_path = '~/3Dircadb/3Dircadb1.%d/MASKS_DICOM/liver' % i
        data_path = '~/3Dircadb/3Dircadb1.%d/PATIENT_DICOM' % i
        liver_slices = [pydicom.dcmread(label_path + '/' + s) for s in os.listdir(label_path)]
        liver_slices.sort(key=lambda x: int(x.InstanceNumber))
        livers = np.stack([s.pixel_array for s in liver_slices])
        start, end = getRangImageDepth(livers)
        total = (end - 4) - (start + 4) + 1
        print("%d person, total slices %d" % (i, total))
        
        image_slices = [pydicom.dcmread(data_path + '/' + s) for s in os.listdir(data_path)]
        image_slices.sort(key=lambda x: int(x.InstanceNumber))
        
        images = get_pixels_hu(image_slices)
        images = transform_ctdata(images, 500, 150)
        images = clahe_equalized(images, start, end)
        images /= 255.
        images = images[start + 5:end - 5]
        print("%d person, images.shape:(%d,)" % (i, images.shape[0]))
        livers[livers > 0] = 1
        livers = livers[start + 5:end - 5]
        
        full_images2.append(images)
        full_livers2.append(livers)
    
    full_images2 = np.vstack(full_images2)
    full_images2 = np.expand_dims(full_images2, axis=-1)
    full_livers2 = np.vstack(full_livers2)
    full_livers2 = np.expand_dims(full_livers2, axis=-1)
    
    dataset = HDF5DatasetWriter(image_dims=(full_images2.shape[0], full_images2.shape[1], full_images2.shape[2], 1),
                                mask_dims=(full_images2.shape[0], full_images2.shape[1], full_images2.shape[2], 1),
                                outputPath="data_train/val_liver.h5")
    
    dataset.add(full_images2, full_livers2)
    
    print("total images in val ", dataset.close())
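A quick sanity check that the file was written as expected (a sketch; the dataset key names follow the writer sketch above and may differ in the real implementation):

    # Hedged sketch: verify the stored shapes by reading the file back.
    import h5py
    with h5py.File('data_train/val_liver.h5', 'r') as f:
        print(f['images'].shape, f['masks'].shape)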

5. Building the Network

Not much to elaborate here: an off-the-shelf U-Net architecture is used, with the main modification in the crop module. The problem it addresses: if the input height or width is not divisible by 2^m (where m is the number of downsampling levels in the contracting path), the pooling layers round the spatial sizes. When the expanding path then upsamples and concatenates with the contracting-path feature maps (the long skip connections), the sizes no longer match and concatenation fails. For example, a 500-pixel side shrinks over three poolings as 500 → 250 → 125 → 62 (rounded down), and upsampling 62 gives 124, which does not match the 125-wide skip feature map; cropping the larger map fixes this.

    # partA
    import os
    import sys
    import numpy as np
    import random
    import math
    import tensorflow as tf
    from HDF5DatasetGenerator import HDF5DatasetGenerator
    from keras.models import Model
    from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, Cropping2D
    from keras.optimizers import Adam
    from keras.callbacks import ModelCheckpoint
    from keras import backend as K
    from skimage import io
    
    K.set_image_data_format('channels_last')
    
    def dice_coef(y_true, y_pred):
        smooth = 1.
        y_true_f = K.flatten(y_true)
        y_pred_f = K.flatten(y_pred)
        intersection = K.sum(y_true_f * y_pred_f)
        return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    
    def dice_coef_loss(y_true, y_pred):
        return -dice_coef(y_true, y_pred)
    
    def get_crop_shape(target, refer):
        # width, the 3rd dimension
        print(target.shape)
        print(refer._keras_shape)
        cw = (target._keras_shape[2] - refer._keras_shape[2])
        assert (cw >= 0)
        if cw % 2 != 0:
            cw1, cw2 = int(cw/2), int(cw/2) + 1
        else:
            cw1, cw2 = int(cw/2), int(cw/2)
        # height, the 2nd dimension
        ch = (target._keras_shape[1] - refer._keras_shape[1])
        assert (ch >= 0)
        if ch % 2 != 0:
            ch1, ch2 = int(ch/2), int(ch/2) + 1
        else:
            ch1, ch2 = int(ch/2), int(ch/2)
    
        return (ch1, ch2), (cw1, cw2)
    
    def get_unet():
        inputs = Input((IMG_HEIGHT, IMG_WIDTH, 1))
        conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
        conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
        conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
        conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
        conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
        conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    
        conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
        conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
        pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    
        conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
        conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)
    
        up_conv5 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5)
    
        ch, cw = get_crop_shape(conv4, up_conv5)
        crop_conv4 = Cropping2D(cropping=(ch, cw), data_format="channels_last")(conv4)
        up6 = concatenate([up_conv5, crop_conv4], axis=3)
        conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
        conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
    
        up_conv6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6)
    
        ch, cw = get_crop_shape(conv3, up_conv6)
        crop_conv3 = Cropping2D(cropping=(ch, cw), data_format="channels_last")(conv3)
        up7 = concatenate([up_conv6, crop_conv3], axis=3)
        conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
        conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
    
        up_conv7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7)
        ch, cw = get_crop_shape(conv2, up_conv7)
        crop_conv2 = Cropping2D(cropping=(ch, cw), data_format="channels_last")(conv2)
        up8 = concatenate([up_conv7, crop_conv2], axis=3)
        conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
        conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
    
        up_conv8 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv8)
        ch, cw = get_crop_shape(conv1, up_conv8)
        crop_conv1 = Cropping2D(cropping=(ch, cw), data_format="channels_last")(conv1)
        up9 = concatenate([up_conv8, crop_conv1], axis=3)
        conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
        conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)
    
        conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
    
        model = Model(inputs=[inputs], outputs=[conv10])
        model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])
    
        return model

6. Training and Testing

This step reads the training and validation datasets from the HDF5 files, sets up a ModelCheckpoint callback, trains the model with fit_generator, runs prediction on a fixed test batch, and saves the predicted masks to disk.

    # partB, continues from partA
    IMG_WIDTH = 512
    IMG_HEIGHT = 512
    IMG_CHANNELS = 1
    TOTAL = 2782      # total number of training images
    TOTAL_VAL = 152   # total number of validation images
    # data files stored by the earlier parts
    outputPath = './data_train/train_liver.h5'  # training file
    val_outputPath = './data_train/val_liver.h5'
    #checkpoint_path = 'model.ckpt'
    BATCH_SIZE = 8  # adjust to your server's GPU memory
    
    class UnetModel:
        def train_and_predict(self):
            reader = HDF5DatasetGenerator(dbPath=outputPath, batchSize=BATCH_SIZE)
            train_iter = reader.generator()
            
            test_reader = HDF5DatasetGenerator(dbPath=val_outputPath, batchSize=BATCH_SIZE)
            test_iter = test_reader.generator()
            fixed_test_images, fixed_test_masks = test_iter.__next__()
            
            model = get_unet()
            model_checkpoint = ModelCheckpoint('weights.h5', monitor='val_loss', save_best_only=True)
            # note: I suspect this is not the right way to handle validation
            model.fit_generator(train_iter, steps_per_epoch=int(TOTAL/BATCH_SIZE), verbose=1, epochs=500,
                                shuffle=True, validation_data=(fixed_test_images, fixed_test_masks),
                                callbacks=[model_checkpoint])
            
            reader.close()
            test_reader.close()
            
            print('-'*30)
            print('Loading and preprocessing test data...')
            print('-'*30)
            
            print('-'*30)
            print('Loading saved weights...')
            print('-'*30)
            model.load_weights('weights.h5')
            
            print('-'*30)
            print('Predicting masks on test data...')
            print('-'*30)
            
            imgs_mask_test = model.predict(fixed_test_images, verbose=1)
            np.save('imgs_mask_test.npy', imgs_mask_test)
            
            print('-' * 30)
            print('Saving predicted masks to files...')
            print('-' * 30)
            pred_dir = 'preds'
            if not os.path.exists(pred_dir):
                os.mkdir(pred_dir)
            i = 0
            for image in imgs_mask_test:
                image = (image[:, :, 0] * 255.).astype(np.uint8)
                gt = (fixed_test_masks[i, :, :, 0] * 255.).astype(np.uint8)
                ini = (fixed_test_images[i, :, :, 0] * 255.).astype(np.uint8)
                io.imsave(os.path.join(pred_dir, str(i) + '_ini.png'), ini)
                io.imsave(os.path.join(pred_dir, str(i) + '_pred.png'), image)
                io.imsave(os.path.join(pred_dir, str(i) + '_gt.png'), gt)
                i += 1
    
    unet = UnetModel()
    unet.train_and_predict()
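On the author's own note that the validation setup may be wrong: a common pattern with fit_generator is to pass the validation generator itself together with validation_steps, rather than one fixed batch. A sketch, reusing the objects defined in partB above:

    # Hedged sketch: stream validation batches instead of validating on a
    # single fixed batch.
    model.fit_generator(train_iter,
                        steps_per_epoch=TOTAL // BATCH_SIZE,
                        epochs=500,
                        validation_data=test_iter,
                        validation_steps=TOTAL_VAL // BATCH_SIZE,
                        callbacks=[model_checkpoint])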

The training run looks like this:

(figure: training log)

Visualization of the prediction results:

(figure: predicted masks alongside the ground truth)
