Advertisement

医学图像分割 基于深度学习的肝脏肿瘤分割 实战(二)

阅读量:

在医学图像分割(基于深度学习)实战系列文章之一《肝癌肿瘤分割》中(原文链接:

第一次 实验中发现网络性能(dice系数)在训练集上始终在40%以下徘徊,在测试集中的准确率仅为约10%,这些结果与预期相差较大。经过仔细排查发现可能与数据质量有关,在查阅大量原始医学影像后发现许多病例即使在专业医师看来也难以区分病变区域边界。基于此猜想进行了一系列实验验证,并将相关结论详细记录于博客文章《医学图像预处理(五) 器官与病灶的直方图》:进一步证实了不同病人肝脏与肿瘤区域的hu值分布高度重合

通过实验验证了猜想;随后又进行了相关实验;于是转向使用LiTS2017数据库;其数据规模更为庞大;然后根据病人体内肝脏与肿瘤的直方图分布情况;手动完成了对130名病人的分类划分:其中level 1为肝脏与肿瘤对比度最高者;而level 3则为对比度几乎可测者(即几乎没有);并最终选取了level 1及以下进行训练数据集构建

当然,在本次实验中,在数据预处理阶段以及ROI识别阶段均采用了与前一致的方法,并未出现明显差异。
然而,在本次实验中,在相同的条件下设置下……仍然出现了同样的结果。

第二次 :首先在训练集上骰系数能够达到10%,但随后又骤降至接近于零的状态。对此感到困惑,并误以为可能是发生了梯度消失或梯度爆炸导致的问题。然而,在深入思考后发现了一个看似奇怪实则严肃的问题:即ROI操作通过逻辑与运算将真实肝脏分割结果与原图结合在一起后,在非感兴趣区域中实现了黑色化处理;然而由于肿瘤区域通常具有较低的灰度值这一特性,在这一区域内也呈现出黑色状态(相对于肝脏而言)。为了区分感兴趣区域与其他部分的目标差异性,则采用了窗口值等技术手段对非目标区域进行黑色化处理;最终使得肝脏区域呈现为灰白色(而其他器官则保持白色状态)。于是尝试进行了一项大胆的实验:将肝脏区域设置为灰色,并将其边界外的肿瘤设置为白色。(进行了颜色翻转)

在这里插入图片描述
在这里插入图片描述

本次实验中,在训练集中实现了dice系数约为90%,而在测试集上的dice系数约为70%。总体而言表现尚可,并且成功验证了猜想的可行性。这一结果进一步验证了数据为王的原则:正如所言,在正确的方向上播种必定会有相应的收获。

本实验采用的代码框架主要包含两大部分:第一阶段的任务主要是数据预处理工作,在本地环境下运行;第二阶段则涉及模型架构搭建与训练过程,在服务器环境下完成模型的构建与训练工作。其中,在服务器环境下完成模型的构建与训练工作时,请注意使用ubuntu16.04系统并配置好tensorflow-gpu环境变量

第一部分:
(注:有关h5文件读写的工具类也放在了博客里)

复制代码
    # -*- coding: utf-8 -*-
    
    """
    根据LITS_check.py,观察结果
    根据肝脏与肿瘤的对比度,将病人分成 3 level
    1 level:对比度最高(随机选出两个作为validation集)
    2 level: 对比度中等(随机选出两个作为validation集)
    3 level:对比度最低
    """
    # theshold = 1e-3, total=755
    # 81,125作为测试集
    level_1 = [0,1,22,23,25,26,27,31,37,46,49,50,55,57,58,59,61,62,
           63,64,66,78,79,82,83,90,92,95,99,109,112,124]
    
    #level_1 = [63,64,66,78,79,81,82,83]
    # theshold = 1e-3, total= 1345
    # 11,110作为测试集
    level_2 = [2,7,8,9,10,12,14,15,17,28,35,40,42,
           53,56,69,76,93,96,101,111,113,117]
    
    level12 = level_1 + level_2
    level12.sort()
    
    test_list = [11,81,110,125]
    
    a = [i for i in range(130)]
    level_3 =list(set(a)-set(level_1)-set(level_2))
    # sort方法直接改变原列表,无返回值
    level_3.sort()
    
    
    """
    将level_1的其余图片观察,确定是否对比度高
    观察后确定窗口值为:[-50,200]
    """
    onServer = False
    if onServer:
    niiSegPath = './LITS17/seg/'
    niiImagePath = './LITS17/ct/'
    else:
    niiSegPath = '~/Documents/LITS17/seg/'
    niiImagePath = '~/Documents/LITS17/ct/'
    
    
    import numpy as np
    import SimpleITK as sitk
    import matplotlib.pyplot as plt
    
    def getRangeImageDepth(image):
    z = np.any(image, axis=(1,2)) # z.shape:(depth,)
    #print("all index:",np.where(z)[0])
    if len(np.where(z)[0]) >0:
        startposition,endposition = np.where(z)[0][[0,-1]]
    else:
        startposition = endposition = 0
    
    return startposition, endposition
    
    def sample_stack(stack, name="images.png", rows=4, cols=2, start_with=0, show_every=1):
    fig,ax = plt.subplots(rows,cols,figsize=[5*cols,5*rows])
    if rows==1 or cols==1 :
        nums = rows*cols
        for i in range(nums):
            ind = start_with + i*show_every
            ax[int(i % nums)].set_title('slice %d' % ind)
            ax[int(i % nums)].imshow(stack[ind],cmap='gray')
            ax[int(i % nums)].axis('off')
    else:
        for i in range(rows*cols):
            ind = start_with + i*show_every
            ax[int(i/cols),int(i % cols)].set_title('slice %d' % ind)
            ax[int(i/cols),int(i % cols)].imshow(stack[ind],cmap='gray')
            ax[int(i/cols),int(i % cols)].axis('off')
    # 这句话一定要在show之前写,否则show函数之后会创建新的空白图
    #    plt.savefig(name)
    plt.show()
    
    """
    工具函数,左边原图,右边真实分割图
    """
    def show_src_seg(srcimg, segimg,index, rows=3,start_with=0, show_every=1):
    assert srcimg.shape == segimg.shape
    
    rows = srcimg.shape[0]
    plan_rows = start_with + rows*show_every - 1
    print("rows=%d,planned_rows=%d"%(rows,plan_rows))
    
    rows = plan_rows if (rows > plan_rows) else rows
    cols = 2
    print("final rows=%d"%rows)
    
    fig,ax = plt.subplots(rows,cols,figsize=[5*cols,5*rows])
    for i in range(rows):
        ind = start_with + i*show_every
        ax[i,0].set_title('src slice %d' % ind)
        ax[i,0].imshow(srcimg[ind],cmap='gray')
        ax[i,0].axis('off')
        
        ax[i,1].set_title('truth seg slice %d' % ind)
        ax[i,1].imshow(segimg[ind],cmap='gray')
        ax[i,1].axis('off')
    # 这句话一定要在show之前写,否则show函数之后会创建新的空白图
    name = "../LITS/crop/"+str(index)+".png"
    plt.savefig(name)
    #    plt.show()
    
    
    def transform_ctdata(image, windowWidth, windowCenter, normal=False):
        """
        注意,这个函数的self.image一定得是float类型的,否则就无效!
        return: trucated image according to window center and window width
        """
        minWindow = float(windowCenter) - 0.5*float(windowWidth)
        newimg = (image - minWindow) / float(windowWidth)
        newimg[newimg < 0] = 0
        newimg[newimg > 1] = 1
        if not normal:
            newimg = (newimg * 255).astype('uint8')
        return newimg
    
    import cv2   
    def clahe_equalized(imgs):
    assert (len(imgs.shape)==3)  #3D arrays
    #create a CLAHE object (Arguments are optional).
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    imgs_equalized = np.empty(imgs.shape)
    for i in range(len(imgs)):
        imgs_equalized[i,:,:] = clahe.apply(np.array(imgs[i,:,:], dtype = np.uint8))
    return imgs_equalized
    
    # 根据肝脏真实分割图,将原始图片进行裁剪为 以肝脏为中心,指定宽、高的图片
    def crop_images_func(refer_images, target_images, target_tumors):
    maxw=maxh=0
    assert refer_images.shape == target_images.shape == target_tumors.shape
    
    crop_images = []
    crop_tumors = []
    
    for i in range(refer_images.shape[0]):
        # Create figure and axes
    #        fig,ax = plt.subplots(1)
        
        mask = refer_images[i]
        
        # find coordinates of liver
        coor = np.nonzero(mask) 
        xmin = coor[0][0] # x代表了行
        xmax = coor[0][-1]
        coor[1].sort() # 直接改变原数组,没有返回值
        ymin = coor[1][0]
        ymax = coor[1][-1]
        
        width_center = (ymax + ymin) // 2
        height_center = (xmax + xmin) // 2
        
        # pre-parameter: height:266, width:334
        # 参数的选定:是之前随机后,挑出的最大值,然后适当扩大后的结果
        height = 280
        width = 360
        istart = int(height_center - height/2)
        
        #注意逻辑!
        if istart < 0:
            istart = 0
            iend = height
        else:
            iend = int(istart + height)
        if iend > 512:
            istart = 512 - height
            iend = 512
            
        jstart = int(width_center - width/2)
        if jstart < 0:
            jstart = 0
            jend = width
            
        jend = int(jstart + width)
        
        if jend > 512:
            jstart = 512 - width
            jend = 512
    
    #        print("[%d:%d,%d:%d]"%(istart,iend,jstart,jend))
    
        mask_crop = target_images[i,istart:iend,jstart:jend]   
        tumors_crop = target_tumors[i,istart:iend,jstart:jend]
        
    #        ax.imshow(mask_crop,cmap=plt.cm.gray)
        
        crop_images.append(mask_crop)
        crop_tumors.append(tumors_crop)
        
    crop_images = np.asarray(crop_images)
    crop_tumors = np.asarray(crop_tumors)
    return (crop_images,crop_tumors)
    
    """
    训练数据
    第一步:读取数据
    第二步:找到具有肿瘤的切片(具有肿瘤的切片一定是肝脏也在的)
    第三步:预处理
       窗口化、自适应直方图均衡化、归一化、颜色翻转、ROI
    第四步:裁剪
    第五步:将数据写入文件
    """
    # 工具类在博客里有写
    from HDF5DatasetWriter import HDF5DatasetWriter
    
    dataset = HDF5DatasetWriter(image_dims=(1967, 280, 360, 1),
                            mask_dims=(1967, 280, 360, 1),
                            outputPath="../data_train/LITS_train_tumor_crop.h5")
    
    
    count = 0
    for i in level12:
    seg = sitk.ReadImage(niiSegPath+ "segmentation-" + str(i) + ".nii", sitk.sitkUInt8)
    segimg = sitk.GetArrayFromImage(seg)
    src = sitk.ReadImage(niiImagePath+"volume-" + str(i) + ".nii")
    srcimg = sitk.GetArrayFromImage(src)
    
    seg_liver = segimg.copy()
    seg_liver[seg_liver>0] = 1
    
    seg_tumorimage = segimg.copy()
    seg_tumorimage[segimg == 1] = 0
    seg_tumorimage[segimg == 2] = 1
    
    # 只选择ROI区域
    srcimg = srcimg * seg_liver
    
    start,end = getRangeImageDepth(seg_tumorimage)
    if start==0 and end == 0:
        print("continue")
        continue
    print("start:",start," end:",end)
    
    theshold = 1e-3 # 最小阈值
    
    filter_index = []
    
    for j in range(start, end+1):
        if np.mean(seg_tumorimage[j]) > theshold:
            filter_index.append(j)
            
    if len(filter_index)<1:
        continue
    
    count += len(filter_index)
    
    #    print("picked index:",filter_index)
       
    srcimg = srcimg[filter_index]
    seg_liver = seg_liver[filter_index]
    seg_tumorimage = seg_tumorimage[filter_index]
    #    
    srcimg = transform_ctdata(srcimg, 250,75,normal=False)
    srcimg = clahe_equalized(srcimg)
    srcimg /= 255.
    
    # 注意,下面这两步顺序一定不能变,否则就不能达到正确的颜色翻转效果了
    srcimg = 1- srcimg
    # 只选择ROI区域
    srcimg = srcimg * seg_liver
    
    crop_images,crop_tumors = crop_images_func(seg_liver,srcimg,seg_tumorimage)
    
    #    show_src_seg(crop_images,crop_tumors,index=i)
    
    crop_images = np.expand_dims(crop_images,axis=-1)
    crop_tumors = np.expand_dims(crop_tumors,axis=-1)
    
    
    dataset.add(crop_images,crop_tumors)
    
    print(dataset.close())
    
    
    dataset = HDF5DatasetWriter(image_dims=(133, 280, 360, 1),
                            mask_dims=(133, 280, 360, 1),
                            outputPath="../data_train/LITS_val_tumor_crop.h5")
    
    
    count = 0
    for i in test_list:
    seg = sitk.ReadImage(niiSegPath+ "segmentation-" + str(i) + ".nii", sitk.sitkUInt8)
    segimg = sitk.GetArrayFromImage(seg)
    src = sitk.ReadImage(niiImagePath+"volume-" + str(i) + ".nii")
    srcimg = sitk.GetArrayFromImage(src)
    
    seg_liver = segimg.copy()
    seg_liver[seg_liver>0] = 1
    
    seg_tumorimage = segimg.copy()
    seg_tumorimage[segimg == 1] = 0
    seg_tumorimage[segimg == 2] = 1
    
    
    
    
    start,end = getRangeImageDepth(seg_tumorimage)
    if start==0 and end == 0:
        print("continue")
        continue
    print("start:",start," end:",end)
    
    theshold = 1e-3 # 最小阈值
    
    filter_index = []
    
    for j in range(start, end+1):
        if np.mean(seg_tumorimage[j]) > theshold:
            filter_index.append(j)
            
    if len(filter_index)<1:
        continue
    
    count += len(filter_index)
    
    
    #    print("picked index:",filter_index)
       
    srcimg = srcimg[filter_index]
    seg_liver = seg_liver[filter_index]
    seg_tumorimage = seg_tumorimage[filter_index]
    
    srcimg = transform_ctdata(srcimg, 250,75,normal=False)
    srcimg = clahe_equalized(srcimg)
    srcimg /= 255.
    
    srcimg = 1- srcimg
    # 只选择ROI区域
    srcimg = srcimg * seg_liver
    
    
    crop_images,crop_tumors = crop_images_func(seg_liver,srcimg,seg_tumorimage)
    
    show_src_seg(crop_images,crop_tumors,index=i)
    
    crop_images = np.expand_dims(crop_images,axis=-1)
    crop_tumors = np.expand_dims(crop_tumors,axis=-1)
    
    
    dataset.add(crop_images,crop_tumors)
    
    print(dataset.close())
    
    """ 
    # 测试
    from HDF5DatasetGenerator import HDF5DatasetGenerator
    
    outputPath = '../data_train/LITS_train_tumor_crop.h5'
    val_outputPath = '../data_train/LITS_val_tumor_crop.h5'
    BATCH_SIZE = 8
    
    reader = HDF5DatasetGenerator(dbPath=val_outputPath,batchSize=BATCH_SIZE)
    train_iter = reader.generator()
    
    src,seg = train_iter.__next__()
    
    src = np.squeeze(src)
    seg = np.squeeze(seg)
    
    sample_stack(src)
    sample_stack(seg)
    """
    
    
    AI生成项目python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/cO2dvaUbwg50Mio3yVx7PG4CIjZt.png)

第二部分:

复制代码
    # -*- coding: utf-8 -*-
    import os
    import sys
    import numpy as np
    import random
    import math
    import tensorflow as tf
    from HDF5DatasetGenerator import HDF5DatasetGenerator
    from keras.models import Model
    from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose,Cropping2D,ZeroPadding2D
    from keras.optimizers import Adam
    from keras.callbacks import ModelCheckpoint
    from keras import backend as K
    from skimage import io
    from keras import losses
    
    # Set some parameters
    IMG_WIDTH = 360
    IMG_HEIGHT = 280
    IMG_CHANNELS = 1
    TOTAL = 1967 # 总共的训练数据
    TOTAL_VAL = 133 # 总共的validation数据
    outputPath = '../data_train/LITS_train_tumor_crop.h5' # 训练文件
    val_outputPath = '../data_train/LITS_val_tumor_crop.h5'
    #checkpoint_path = 'model.ckpt'
    BATCH_SIZE = 4
    
    K.set_image_data_format('channels_last')
    
    def dice_coef(y_true, y_pred):
    print("in loss function, y_true shape:",y_true.shape)
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    
    
    def dice_coef_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)
    
    # 疑问,不知道除n的操作是否该写?还是说keras会自动取平均
    def weighted_binary_cross_entropy_loss(y_true, y_pred):
    """
    # 跟标准的结果差不多 0.068760,该结果:0.0685122
    print("y_pred shape ",K.int_shape(y_pred))
    
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    ce = - K.mean(y_true*K.log(K.epsilon()+y_pred) + (1-y_true)*K.log(1-y_pred+K.epsilon()))
    return ce
    """
    
    """
    # 跟标准结果一样
    b_ce = K.binary_crossentropy(y_true, y_pred)
    return b_ce
    """
    # 不确定是否正确
    
    # Calculate the binary crossentropy
    b_ce = K.binary_crossentropy(y_true, y_pred)
    one_weight = K.mean(y_true)
    zero_weight = 1 - one_weight
    #    weight = zero_weight / one_weight
    # Apply the weights
    weight_vector = y_true * zero_weight  + (1. - y_true) * one_weight
    weighted_b_ce = weight_vector * b_ce
    
    # Return the mean error
    return K.mean(weighted_b_ce)
    
    # 不确定是否正确?
    def weighted_dice_loss(y_true, y_pred):
    mean = K.mean(y_true)
    w_1 = 1/mean**2
    w_0 = 1/(1-mean)**2
    y_true_f_1 = K.flatten(y_true)
    y_pred_f_1 = K.flatten(y_pred)
    y_true_f_0 = K.flatten(1-y_true)
    y_pred_f_0 = K.flatten(1-y_pred)
    
    intersection_0 = K.sum(y_true_f_0 * y_pred_f_0)
    intersection_1 = K.sum(y_true_f_1 * y_pred_f_1)
    
    return -2 * (w_0 * intersection_0 +w_1 * intersection_1)\
          / ((w_0 * (K.sum(y_true_f_0) + K.sum(y_pred_f_0))) \
             + (w_1 * (K.sum(y_true_f_1) + K.sum(y_pred_f_1))))
    
    def get_crop_shape(target, refer):
        # width, the 3rd dimension
    #        print(target.shape)
    #        print(refer._keras_shape)
        cw = (target._keras_shape[2] - refer._keras_shape[2])
        assert (cw >= 0)
        if cw % 2 != 0:
            cw1, cw2 = int(cw/2), int(cw/2) + 1
        else:
            cw1, cw2 = int(cw/2), int(cw/2)
        # height, the 2nd dimension
        ch = (target._keras_shape[1] - refer._keras_shape[1])
        assert (ch >= 0)
        if ch % 2 != 0:
            ch1, ch2 = int(ch/2), int(ch/2) + 1
        else:
            ch1, ch2 = int(ch/2), int(ch/2)
    
        return (ch1, ch2), (cw1, cw2)
    
    def get_unet():
    inputs = Input((IMG_HEIGHT, IMG_WIDTH , 1))
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)
    
    up_conv5 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5)
    
    ch, cw = get_crop_shape(conv4, up_conv5)
    #    print("ch,cw",ch,cw)
    #    
    up_conv5 = ZeroPadding2D(padding=(ch,cw), data_format="channels_last")(up_conv5)
    up6 = concatenate([up_conv5, conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
    
    up_conv6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6)
    
    ch, cw = get_crop_shape(conv3, up_conv6)
    up_conv6 = ZeroPadding2D(padding=(ch,cw), data_format="channels_last")(up_conv6)
    #    
    up7 = concatenate([up_conv6, conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
    
    up_conv7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7)
    ch, cw = get_crop_shape(conv2, up_conv7)
    up_conv7 = ZeroPadding2D(padding=(ch,cw), data_format="channels_last")(up_conv7)
    
    up8 = concatenate([up_conv7, conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
    
    up_conv8 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv8)
    ch, cw = get_crop_shape(conv1, up_conv8)
    up_conv8 = ZeroPadding2D(padding=(ch,cw), data_format="channels_last")(up_conv8)
    
    up9 = concatenate([up_conv8, conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)
    
    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
    
    model = Model(inputs=[inputs], outputs=[conv10])
    
    model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])
    
    return model
    
    
    
    class UnetModel:
    
    def predict(self):
        model = get_unet()
        model.load_weights('weights2.h5')
        test_reader = HDF5DatasetGenerator(dbPath=outputPath,batchSize=30)
        test_iter = test_reader.generator()
        fixed_test_images, fixed_test_masks = test_iter.__next__()
    #        print(model.evaluate(fixed_test_images, fixed_test_masks,BATCH_SIZE*5))
        
        imgs_mask_test = model.predict(fixed_test_images, verbose=1)
        test_reader.close()
        print('-' * 30)
        print('Saving predicted masks to files...')
        print('-' * 30)
        pred_dir = 'step2_train1'
        if not os.path.exists(pred_dir):
            os.mkdir(pred_dir)
        i = 0
        
        
        for image in imgs_mask_test:
            image = (image[:, :, 0] * 255.).astype(np.uint8)
            gt = (fixed_test_masks[i,:,:,0] * 255.).astype(np.uint8)
            ini = (fixed_test_images[i,:,:,0] *255.).astype(np.uint8)
            io.imsave(os.path.join(pred_dir, str(i) + '_ini.png'), ini)
            io.imsave(os.path.join(pred_dir, str(i) + '_pred.png'), image)
            io.imsave(os.path.join(pred_dir, str(i) + '_gt.png'), gt)
            i += 1
        
    
    def train_and_predict(self):
        
        reader = HDF5DatasetGenerator(dbPath=outputPath,batchSize=BATCH_SIZE)
        train_iter = reader.generator()
        
        test_reader = HDF5DatasetGenerator(dbPath=val_outputPath,batchSize=BATCH_SIZE)
        test_iter = test_reader.generator()
        
    #   
        
        model = get_unet()
        model_checkpoint = ModelCheckpoint('weights2.h5', monitor='val_loss', save_best_only=True)
        model.fit_generator(train_iter,steps_per_epoch=int(TOTAL/BATCH_SIZE),verbose=1,epochs=500,shuffle=True,
                            validation_data=test_iter, validation_steps=int(TOTAL_VAL/BATCH_SIZE) ,callbacks=[model_checkpoint])
    #        
        reader.close()
        test_reader.close()
        
        
    #        print('-'*30)
    #        print('Loading and preprocessing test data...')
    #        print('-'*30)
    #        
    #        print('-'*30)
    #        print('Loading saved weights...')
    #        print('-'*30)
    #        model.load_weights('weights.h5')
    #    
    #        print('-'*30)
    #        print('Predicting masks on test data...')
    #        print('-'*30)
    #        
    #        
    #        
    #        # 不懂这儿为什么会是np格式
    #        imgs_mask_test = model.predict(fixed_test_images, verbose=1)
    #        np.save('imgs_mask_test.npy', imgs_mask_test)
    #    
    #        print('-' * 30)
    #        print('Saving predicted masks to files...')
    #        print('-' * 30)
    #        pred_dir = 'preds'
    #        if not os.path.exists(pred_dir):
    #            os.mkdir(pred_dir)
    #        i = 0
    #        
    #        
    #        for image in imgs_mask_test:
    #            image = (image[:, :, 0] * 255.).astype(np.uint8)
    #            gt = (fixed_test_masks[i,:,:,0] * 255.).astype(np.uint8)
    #            ini = (fixed_test_images[i,:,:,0] *255.).astype(np.uint8)
    #            io.imsave(os.path.join(pred_dir, str(i) + '_ini.png'), ini)
    #            io.imsave(os.path.join(pred_dir, str(i) + '_pred.png'), image)
    #            io.imsave(os.path.join(pred_dir, str(i) + '_gt.png'), gt)
    #            i += 1
    
    #model = get_unet()
    #model.summary()
    unet = UnetModel()
    #unet.train_and_predict()
    unet.train_and_predict()
    #print("test")
       
        
    
    
    AI生成项目python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/oVmSkXa8nOICbuMphJtT1ByZfF2E.png)

全部评论 (0)

还没有任何评论哟~