Medical Image Segmentation: Liver Tumor Segmentation Based on Deep Learning, in Practice (Part 2)
In the first article of this hands-on series on deep-learning-based medical image segmentation, 《肝癌肿瘤分割》 (Liver Cancer Tumor Segmentation) (original link:
First attempt: the network's Dice coefficient hovered below 40% on the training set and reached only about 10% on the test set, far below expectations. Careful investigation pointed to data quality: in many of the original scans, the lesion boundaries are hard to delineate even for a trained radiologist. A series of experiments confirmed this suspicion, documented in the blog post 《医学图像预处理(五) 器官与病灶的直方图》 (Medical Image Preprocessing (5): Histograms of Organs and Lesions): across patients, the HU-value distributions of the liver and tumor regions overlap heavily.
With the suspicion confirmed, I switched to the LiTS2017 dataset, which is considerably larger. Based on the histogram distributions of liver and tumor within each patient, I manually sorted all 130 patients into three levels: level 1 has the highest liver-tumor contrast and level 3 has almost none. The training set was then built from the level 1 and level 2 patients (the level12 list in the code below).
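The leveling itself is easy to reproduce: for each patient, plot the HU histograms of the liver voxels and the tumor voxels and judge how much they overlap. Below is a minimal sketch of such a check, assuming the LiTS file layout used in the code later in this post (the actual LITS_check.py referenced there may differ):

import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt

def plot_liver_tumor_hist(patient_id, ct_dir='./LITS17/ct/', seg_dir='./LITS17/seg/'):
    # LiTS labels: 0 = background, 1 = liver, 2 = tumor
    ct = sitk.GetArrayFromImage(sitk.ReadImage(ct_dir + "volume-%d.nii" % patient_id))
    seg = sitk.GetArrayFromImage(sitk.ReadImage(seg_dir + "segmentation-%d.nii" % patient_id, sitk.sitkUInt8))
    plt.hist(ct[seg == 1], bins=100, range=(-100, 300), alpha=0.5, density=True, label='liver')
    plt.hist(ct[seg == 2], bins=100, range=(-100, 300), alpha=0.5, density=True, label='tumor')
    plt.xlabel('HU')
    plt.title('patient %d' % patient_id)
    plt.legend()
    plt.show()

The more the two histograms coincide, the worse the contrast, i.e. the higher the level number assigned to the patient.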
The data preprocessing and ROI-extraction stages used the same methods as before, with no notable differences. Yet under these identical settings, the same disappointing result appeared again.
Second attempt: on the training set the Dice coefficient first climbed to about 10%, then collapsed to nearly zero. I was puzzled and initially blamed vanishing or exploding gradients. On closer inspection, though, the real problem was simpler and more serious. The ROI step combines the ground-truth liver mask with the original image by a logical AND, blacking out everything outside the region of interest. But tumor tissue typically has lower gray values than the surrounding liver, so after windowing (which renders the liver gray-white while blacking out the non-target regions) the tumors also appear nearly black, almost indistinguishable from the blacked-out background. So I tried a bold experiment: invert the colors, so that the liver becomes gray and the tumors inside it become bright (white).
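To see the problem and the fix concretely, here is a toy one-slice illustration with made-up HU values (the real pipeline appears in the preprocessing code below):

import numpy as np

# toy 1D "slice": background, liver, tumor, liver, background (made-up HU values)
hu = np.array([-900., 90., 20., 90., -900.])
liver_mask = np.array([0., 1., 1., 1., 0.])

# ROI masking, then windowing to [-50, 200] (window width 250, center 75)
win = np.clip((hu * liver_mask - (-50.)) / 250., 0., 1.)
print(win)   # [0.2  0.56 0.28 0.56 0.2 ] -> the tumor (0.28) is nearly as dark as the background (0.2)

# the fix: invert first, then re-apply the mask so the background is black again
inv = (1. - win) * liver_mask
print(inv)   # [0.   0.44 0.72 0.44 0.  ] -> the tumor is now the brightest structure inside the liver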


With this change, the Dice coefficient reached about 90% on the training set and about 70% on the test set. A respectable result, and it confirmed the hypothesis. It also reinforced the principle that data is king: sow in the right direction, and the harvest follows.
The code comes in two parts. The first part handles data preprocessing and runs locally; the second part builds and trains the model on a server (Ubuntu 16.04 with a tensorflow-gpu environment).
Part 1:
(Note: the helper classes for reading and writing h5 files are also available on the blog.)
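For readers who don't want to dig through the earlier post, here is a minimal sketch of the interface those two helpers need to expose, assuming h5py and saved as HDF5DatasetWriter.py and HDF5DatasetGenerator.py respectively; the actual classes on the blog may differ in details such as buffered writes:

import h5py

class HDF5DatasetWriter:
    # minimal sketch: preallocate the image/mask datasets, append batches in order
    def __init__(self, image_dims, mask_dims, outputPath):
        self.db = h5py.File(outputPath, "w")
        self.images = self.db.create_dataset("images", image_dims, dtype="float")
        self.masks = self.db.create_dataset("masks", mask_dims, dtype="float")
        self.idx = 0

    def add(self, images, masks):
        n = len(images)
        self.images[self.idx:self.idx + n] = images
        self.masks[self.idx:self.idx + n] = masks
        self.idx += n

    def close(self):
        self.db.close()
        return self.idx  # number of samples written so far

class HDF5DatasetGenerator:
    # minimal sketch: yield (image, mask) batches from the HDF5 file, looping forever
    def __init__(self, dbPath, batchSize):
        self.db = h5py.File(dbPath, "r")
        self.batchSize = batchSize
        self.numImages = self.db["images"].shape[0]

    def generator(self):
        while True:
            for i in range(0, self.numImages, self.batchSize):
                yield (self.db["images"][i:i + self.batchSize],
                       self.db["masks"][i:i + self.batchSize])

    def close(self):
        self.db.close()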
# -*- coding: utf-8 -*-
"""
Based on the observations from LITS_check.py,
patients are split into 3 levels by liver-tumor contrast:
level 1: highest contrast (two picked at random for the validation set)
level 2: medium contrast (two picked at random for the validation set)
level 3: lowest contrast
"""
# threshold = 1e-3, total = 755 slices
# patients 81 and 125 serve as the validation set
level_1 = [0,1,22,23,25,26,27,31,37,46,49,50,55,57,58,59,61,62,
           63,64,66,78,79,82,83,90,92,95,99,109,112,124]
#level_1 = [63,64,66,78,79,81,82,83]
# threshold = 1e-3, total = 1345 slices
# patients 11 and 110 serve as the validation set
level_2 = [2,7,8,9,10,12,14,15,17,28,35,40,42,
           53,56,69,76,93,96,101,111,113,117]
level12 = level_1 + level_2
level12.sort()
test_list = [11,81,110,125]
a = [i for i in range(130)]
level_3 = list(set(a) - set(level_1) - set(level_2))
# sort() modifies the list in place and returns None
level_3.sort()
"""
将level_1的其余图片观察,确定是否对比度高
观察后确定窗口值为:[-50,200]
"""
onServer = False
if onServer:
niiSegPath = './LITS17/seg/'
niiImagePath = './LITS17/ct/'
else:
niiSegPath = '~/Documents/LITS17/seg/'
niiImagePath = '~/Documents/LITS17/ct/'
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt

def getRangeImageDepth(image):
    """Return the first and last slice indices (along depth) that contain nonzero labels."""
    z = np.any(image, axis=(1,2))  # z.shape: (depth,)
    #print("all index:", np.where(z)[0])
    if len(np.where(z)[0]) > 0:
        startposition, endposition = np.where(z)[0][[0,-1]]
    else:
        startposition = endposition = 0
    return startposition, endposition
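# Quick sanity check (toy example): only slices 1 and 2 contain labels, so
# the function should return (1, 2):
# toy = np.zeros((4, 2, 2)); toy[1, 0, 0] = 1; toy[2, 1, 1] = 1
# print(getRangeImageDepth(toy))   # -> (1, 2)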
def sample_stack(stack, name="images.png", rows=4, cols=2, start_with=0, show_every=1):
    fig, ax = plt.subplots(rows, cols, figsize=[5*cols, 5*rows])
    if rows == 1 or cols == 1:
        nums = rows * cols
        for i in range(nums):
            ind = start_with + i*show_every
            ax[int(i % nums)].set_title('slice %d' % ind)
            ax[int(i % nums)].imshow(stack[ind], cmap='gray')
            ax[int(i % nums)].axis('off')
    else:
        for i in range(rows*cols):
            ind = start_with + i*show_every
            ax[int(i/cols), int(i % cols)].set_title('slice %d' % ind)
            ax[int(i/cols), int(i % cols)].imshow(stack[ind], cmap='gray')
            ax[int(i/cols), int(i % cols)].axis('off')
    # savefig must come before show(); otherwise show() leaves a new blank figure
    # plt.savefig(name)
    plt.show()
"""
工具函数,左边原图,右边真实分割图
"""
def show_src_seg(srcimg, segimg,index, rows=3,start_with=0, show_every=1):
assert srcimg.shape == segimg.shape
rows = srcimg.shape[0]
plan_rows = start_with + rows*show_every - 1
print("rows=%d,planned_rows=%d"%(rows,plan_rows))
rows = plan_rows if (rows > plan_rows) else rows
cols = 2
print("final rows=%d"%rows)
fig,ax = plt.subplots(rows,cols,figsize=[5*cols,5*rows])
for i in range(rows):
ind = start_with + i*show_every
ax[i,0].set_title('src slice %d' % ind)
ax[i,0].imshow(srcimg[ind],cmap='gray')
ax[i,0].axis('off')
ax[i,1].set_title('truth seg slice %d' % ind)
ax[i,1].imshow(segimg[ind],cmap='gray')
ax[i,1].axis('off')
# 这句话一定要在show之前写,否则show函数之后会创建新的空白图
name = "../LITS/crop/"+str(index)+".png"
plt.savefig(name)
# plt.show()
def transform_ctdata(image, windowWidth, windowCenter, normal=False):
    """
    Note: `image` must be of float type, otherwise the windowing has no effect!
    return: image truncated according to the window center and window width
    """
    minWindow = float(windowCenter) - 0.5 * float(windowWidth)
    newimg = (image - minWindow) / float(windowWidth)
    newimg[newimg < 0] = 0
    newimg[newimg > 1] = 1
    if not normal:
        newimg = (newimg * 255).astype('uint8')
    return newimg
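# For example, the [-50, 200] window chosen above corresponds to
# windowWidth=250 and windowCenter=75, since 75 - 250/2 = -50 and 75 + 250/2 = 200:
# transform_ctdata(np.array([[-100., -50., 75., 200., 300.]]), 250, 75)
# -> [[0, 0, 127, 255, 255]]  (values clipped to the window, then scaled to uint8)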
import cv2

def clahe_equalized(imgs):
    assert (len(imgs.shape) == 3)  # expects a 3D array (depth, height, width)
    # create a CLAHE object (arguments are optional);
    # OpenCV's CLAHE operates on 8-bit images, hence the uint8 cast below
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    imgs_equalized = np.empty(imgs.shape)
    for i in range(len(imgs)):
        imgs_equalized[i,:,:] = clahe.apply(np.array(imgs[i,:,:], dtype=np.uint8))
    return imgs_equalized
# Using the ground-truth liver mask, crop each slice to a fixed
# width/height window centered on the liver
def crop_images_func(refer_images, target_images, target_tumors):
    assert refer_images.shape == target_images.shape == target_tumors.shape
    crop_images = []
    crop_tumors = []
    for i in range(refer_images.shape[0]):
        mask = refer_images[i]
        # find the coordinates of the liver
        coor = np.nonzero(mask)
        xmin = coor[0][0]   # axis 0 is the row index
        xmax = coor[0][-1]
        coor[1].sort()      # sorts in place, returns None
        ymin = coor[1][0]
        ymax = coor[1][-1]
        width_center = (ymax + ymin) // 2
        height_center = (xmax + xmin) // 2
        # previous parameters: height 266, width 334;
        # chosen as the maximum over a random sample, then enlarged a little
        height = 280
        width = 360
        # careful with the boundary logic!
        istart = int(height_center - height/2)
        if istart < 0:
            istart = 0
        iend = int(istart + height)
        if iend > 512:
            istart = 512 - height
            iend = 512
        jstart = int(width_center - width/2)
        if jstart < 0:
            jstart = 0
        jend = int(jstart + width)
        if jend > 512:
            jstart = 512 - width
            jend = 512
        # print("[%d:%d,%d:%d]" % (istart, iend, jstart, jend))
        mask_crop = target_images[i, istart:iend, jstart:jend]
        tumors_crop = target_tumors[i, istart:iend, jstart:jend]
        crop_images.append(mask_crop)
        crop_tumors.append(tumors_crop)
    crop_images = np.asarray(crop_images)
    crop_tumors = np.asarray(crop_tumors)
    return (crop_images, crop_tumors)
"""
训练数据
第一步:读取数据
第二步:找到具有肿瘤的切片(具有肿瘤的切片一定是肝脏也在的)
第三步:预处理
窗口化、自适应直方图均衡化、归一化、颜色翻转、ROI
第四步:裁剪
第五步:将数据写入文件
"""
# 工具类在博客里有写
from HDF5DatasetWriter import HDF5DatasetWriter
dataset = HDF5DatasetWriter(image_dims=(1967, 280, 360, 1),
mask_dims=(1967, 280, 360, 1),
outputPath="../data_train/LITS_train_tumor_crop.h5")
count = 0
for i in level12:
    seg = sitk.ReadImage(niiSegPath + "segmentation-" + str(i) + ".nii", sitk.sitkUInt8)
    segimg = sitk.GetArrayFromImage(seg)
    src = sitk.ReadImage(niiImagePath + "volume-" + str(i) + ".nii")
    srcimg = sitk.GetArrayFromImage(src)
    # LiTS labels: 0 = background, 1 = liver, 2 = tumor
    seg_liver = segimg.copy()
    seg_liver[seg_liver > 0] = 1
    seg_tumorimage = segimg.copy()
    seg_tumorimage[segimg == 1] = 0
    seg_tumorimage[segimg == 2] = 1
    # keep only the ROI region
    srcimg = srcimg * seg_liver
    start, end = getRangeImageDepth(seg_tumorimage)
    if start == 0 and end == 0:
        print("continue")
        continue
    print("start:", start, " end:", end)
    threshold = 1e-3  # minimum tumor fraction per slice
    filter_index = []
    for j in range(start, end+1):
        if np.mean(seg_tumorimage[j]) > threshold:
            filter_index.append(j)
    if len(filter_index) < 1:
        continue
    count += len(filter_index)
    srcimg = srcimg[filter_index]
    seg_liver = seg_liver[filter_index]
    seg_tumorimage = seg_tumorimage[filter_index]
    # window width 250, center 75, i.e. the [-50, 200] window chosen earlier
    srcimg = transform_ctdata(srcimg, 250, 75, normal=False)
    srcimg = clahe_equalized(srcimg)
    srcimg /= 255.
    # Note: the next two steps must stay in exactly this order,
    # otherwise the color inversion comes out wrong
    srcimg = 1 - srcimg
    # keep only the ROI region (re-black the background after inversion)
    srcimg = srcimg * seg_liver
    crop_images, crop_tumors = crop_images_func(seg_liver, srcimg, seg_tumorimage)
    # show_src_seg(crop_images, crop_tumors, index=i)
    crop_images = np.expand_dims(crop_images, axis=-1)
    crop_tumors = np.expand_dims(crop_tumors, axis=-1)
    dataset.add(crop_images, crop_tumors)
print(dataset.close())
dataset = HDF5DatasetWriter(image_dims=(133, 280, 360, 1),
                            mask_dims=(133, 280, 360, 1),
                            outputPath="../data_train/LITS_val_tumor_crop.h5")
count = 0
for i in test_list:
    seg = sitk.ReadImage(niiSegPath + "segmentation-" + str(i) + ".nii", sitk.sitkUInt8)
    segimg = sitk.GetArrayFromImage(seg)
    src = sitk.ReadImage(niiImagePath + "volume-" + str(i) + ".nii")
    srcimg = sitk.GetArrayFromImage(src)
    seg_liver = segimg.copy()
    seg_liver[seg_liver > 0] = 1
    seg_tumorimage = segimg.copy()
    seg_tumorimage[segimg == 1] = 0
    seg_tumorimage[segimg == 2] = 1
    start, end = getRangeImageDepth(seg_tumorimage)
    if start == 0 and end == 0:
        print("continue")
        continue
    print("start:", start, " end:", end)
    threshold = 1e-3  # minimum tumor fraction per slice
    filter_index = []
    for j in range(start, end+1):
        if np.mean(seg_tumorimage[j]) > threshold:
            filter_index.append(j)
    if len(filter_index) < 1:
        continue
    count += len(filter_index)
    srcimg = srcimg[filter_index]
    seg_liver = seg_liver[filter_index]
    seg_tumorimage = seg_tumorimage[filter_index]
    # note: unlike the training loop, there is no ROI multiplication before
    # windowing here; the mask applied after inversion still blacks out the
    # background, though the CLAHE statistics can differ slightly
    srcimg = transform_ctdata(srcimg, 250, 75, normal=False)
    srcimg = clahe_equalized(srcimg)
    srcimg /= 255.
    srcimg = 1 - srcimg
    # keep only the ROI region
    srcimg = srcimg * seg_liver
    crop_images, crop_tumors = crop_images_func(seg_liver, srcimg, seg_tumorimage)
    show_src_seg(crop_images, crop_tumors, index=i)
    crop_images = np.expand_dims(crop_images, axis=-1)
    crop_tumors = np.expand_dims(crop_tumors, axis=-1)
    dataset.add(crop_images, crop_tumors)
print(dataset.close())
"""
# quick test
from HDF5DatasetGenerator import HDF5DatasetGenerator
outputPath = '../data_train/LITS_train_tumor_crop.h5'
val_outputPath = '../data_train/LITS_val_tumor_crop.h5'
BATCH_SIZE = 8
reader = HDF5DatasetGenerator(dbPath=val_outputPath,batchSize=BATCH_SIZE)
train_iter = reader.generator()
src,seg = train_iter.__next__()
src = np.squeeze(src)
seg = np.squeeze(seg)
sample_stack(src)
sample_stack(seg)
"""

Part 2:
# -*- coding: utf-8 -*-
import os
import numpy as np
from HDF5DatasetGenerator import HDF5DatasetGenerator
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, ZeroPadding2D
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras import backend as K
from skimage import io
# Set some parameters
IMG_WIDTH = 360
IMG_HEIGHT = 280
IMG_CHANNELS = 1
TOTAL = 1967      # total number of training samples
TOTAL_VAL = 133   # total number of validation samples
outputPath = '../data_train/LITS_train_tumor_crop.h5'  # training file
val_outputPath = '../data_train/LITS_val_tumor_crop.h5'
#checkpoint_path = 'model.ckpt'
BATCH_SIZE = 4
K.set_image_data_format('channels_last')
def dice_coef(y_true, y_pred):
    print("in loss function, y_true shape:", y_true.shape)
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)
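# Sanity check (toy example): with y_true = [1,1,0,0] and y_pred = [1,0,0,0],
# intersection = 1, so dice = (2*1 + 1) / (2 + 1 + 1) = 0.75 (with smooth = 1):
# K.eval(dice_coef(K.variable([1.,1.,0.,0.]), K.variable([1.,0.,0.,0.])))  # -> 0.75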
# Open question: should the result be divided by n here, or does Keras average automatically?
def weighted_binary_cross_entropy_loss(y_true, y_pred):
    """
    # nearly matches the built-in result: 0.068760 vs. 0.0685122 here
    print("y_pred shape ", K.int_shape(y_pred))
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    ce = - K.mean(y_true*K.log(K.epsilon()+y_pred) + (1-y_true)*K.log(1-y_pred+K.epsilon()))
    return ce
    """
    """
    # matches the built-in result exactly
    b_ce = K.binary_crossentropy(y_true, y_pred)
    return b_ce
    """
    # not sure whether this weighting is correct
    # calculate the plain binary crossentropy
    b_ce = K.binary_crossentropy(y_true, y_pred)
    one_weight = K.mean(y_true)
    zero_weight = 1 - one_weight
    # weight = zero_weight / one_weight
    # apply the weights: positives get the (large) zero_weight,
    # negatives get the (small) one_weight
    weight_vector = y_true * zero_weight + (1. - y_true) * one_weight
    weighted_b_ce = weight_vector * b_ce
    # return the mean error
    return K.mean(weighted_b_ce)
# not sure whether this is correct?
def weighted_dice_loss(y_true, y_pred):
    mean = K.mean(y_true)
    w_1 = 1 / mean**2
    w_0 = 1 / (1-mean)**2
    y_true_f_1 = K.flatten(y_true)
    y_pred_f_1 = K.flatten(y_pred)
    y_true_f_0 = K.flatten(1-y_true)
    y_pred_f_0 = K.flatten(1-y_pred)
    intersection_0 = K.sum(y_true_f_0 * y_pred_f_0)
    intersection_1 = K.sum(y_true_f_1 * y_pred_f_1)
    return -2 * (w_0 * intersection_0 + w_1 * intersection_1) \
        / ((w_0 * (K.sum(y_true_f_0) + K.sum(y_pred_f_0)))
           + (w_1 * (K.sum(y_true_f_1) + K.sum(y_pred_f_1))))
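# Note: this has the form of the generalized Dice loss, where each class is
# weighted by the inverse square of its volume (1/mean**2 is proportional to
# 1/(class voxel count)**2); the heavy weight on the rare tumor class is what
# counteracts the foreground/background imbalance.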
def get_crop_shape(target, refer):
    # width, the 3rd dimension
    # print(target.shape)
    # print(refer._keras_shape)
    cw = (target._keras_shape[2] - refer._keras_shape[2])
    assert (cw >= 0)
    if cw % 2 != 0:
        cw1, cw2 = int(cw/2), int(cw/2) + 1
    else:
        cw1, cw2 = int(cw/2), int(cw/2)
    # height, the 2nd dimension
    ch = (target._keras_shape[1] - refer._keras_shape[1])
    assert (ch >= 0)
    if ch % 2 != 0:
        ch1, ch2 = int(ch/2), int(ch/2) + 1
    else:
        ch1, ch2 = int(ch/2), int(ch/2)
    return (ch1, ch2), (cw1, cw2)
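# Why the padding below is needed: 280x360 is not divisible by 16. Four 2x2
# poolings map 280 -> 140 -> 70 -> 35 -> 17 (floor division), while the skip
# connection conv4 still has height 35; upsampling 17 gives only 34, so
# get_crop_shape computes the one row/column of zero padding that restores
# matching shapes (likewise 45 vs. 44 for the width) before each concatenation.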
def get_unet():
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, 1))
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)
    up_conv5 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5)
    # pad the upsampled tensor so it matches the skip connection exactly
    ch, cw = get_crop_shape(conv4, up_conv5)
    up_conv5 = ZeroPadding2D(padding=(ch, cw), data_format="channels_last")(up_conv5)
    up6 = concatenate([up_conv5, conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
    up_conv6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6)
    ch, cw = get_crop_shape(conv3, up_conv6)
    up_conv6 = ZeroPadding2D(padding=(ch, cw), data_format="channels_last")(up_conv6)
    up7 = concatenate([up_conv6, conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
    up_conv7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7)
    ch, cw = get_crop_shape(conv2, up_conv7)
    up_conv7 = ZeroPadding2D(padding=(ch, cw), data_format="channels_last")(up_conv7)
    up8 = concatenate([up_conv7, conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
    up_conv8 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv8)
    ch, cw = get_crop_shape(conv1, up_conv8)
    up_conv8 = ZeroPadding2D(padding=(ch, cw), data_format="channels_last")(up_conv8)
    up9 = concatenate([up_conv8, conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)
    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
    model = Model(inputs=[inputs], outputs=[conv10])
    model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])
    return model
class UnetModel:
    def predict(self):
        model = get_unet()
        model.load_weights('weights2.h5')
        test_reader = HDF5DatasetGenerator(dbPath=outputPath, batchSize=30)
        test_iter = test_reader.generator()
        fixed_test_images, fixed_test_masks = test_iter.__next__()
        # print(model.evaluate(fixed_test_images, fixed_test_masks, BATCH_SIZE*5))
        imgs_mask_test = model.predict(fixed_test_images, verbose=1)
        test_reader.close()
        print('-' * 30)
        print('Saving predicted masks to files...')
        print('-' * 30)
        pred_dir = 'step2_train1'
        if not os.path.exists(pred_dir):
            os.mkdir(pred_dir)
        i = 0
        for image in imgs_mask_test:
            image = (image[:, :, 0] * 255.).astype(np.uint8)
            gt = (fixed_test_masks[i,:,:,0] * 255.).astype(np.uint8)
            ini = (fixed_test_images[i,:,:,0] * 255.).astype(np.uint8)
            io.imsave(os.path.join(pred_dir, str(i) + '_ini.png'), ini)
            io.imsave(os.path.join(pred_dir, str(i) + '_pred.png'), image)
            io.imsave(os.path.join(pred_dir, str(i) + '_gt.png'), gt)
            i += 1

    def train_and_predict(self):
        reader = HDF5DatasetGenerator(dbPath=outputPath, batchSize=BATCH_SIZE)
        train_iter = reader.generator()
        test_reader = HDF5DatasetGenerator(dbPath=val_outputPath, batchSize=BATCH_SIZE)
        test_iter = test_reader.generator()
        model = get_unet()
        # keep only the weights with the best validation loss
        model_checkpoint = ModelCheckpoint('weights2.h5', monitor='val_loss', save_best_only=True)
        model.fit_generator(train_iter, steps_per_epoch=int(TOTAL/BATCH_SIZE), verbose=1, epochs=500, shuffle=True,
                            validation_data=test_iter, validation_steps=int(TOTAL_VAL/BATCH_SIZE), callbacks=[model_checkpoint])
        reader.close()
        test_reader.close()

#model = get_unet()
#model.summary()
unet = UnetModel()
unet.train_and_predict()
# after training, unet.predict() writes sample predictions to the step2_train1 folder

