Hands-On Medical Image Segmentation: Kidney CT Segmentation with U-Net

- Dataset Preparation
  - Data Source
  - Data Preprocessing
- Network Structure and Code
  - Network Structure
  - Training Code
- Training Process
  - Parameter Settings
  - Visualization
- Results Analysis
Dataset Preparation
Data Source
MICCAI KiTS19 Kidney Tumor Segmentation Challenge: https://kits19.grand-challenge.org
KiTS19 was one of the major challenges at MICCAI 2019. Its goal is the accurate segmentation of kidneys and kidney tumors from 3D CT volumes. The organizers released 210 labeled training cases and 90 test cases. The challenge attracted 845 participants from around the world, and 47 teams ultimately submitted valid results. It remains open, so anyone interested can still take part, and a follow-up edition with a larger dataset and broader scope is planned for 2021. If you work on medical image analysis, it is worth keeping an eye on.
Thanks to weixin_40621562 for sharing a dataset mirror in the comments! It is a Baidu Netdisk package from his blog; download the full files from the link below with extraction code d7jk: https://pan.baidu.com/s/1AOQDjPz9ye32DH-oDS0WDw
Data Preprocessing
The official repository provides the data as 3D CT volumes, while our goal here is to train the simplest possible 2D U-Net, so 2D slices must first be extracted from the 3D volumes. The kits19 GitHub repository ships handy visualization utilities that use the nibabel library to read the .nii volumes and save the corresponding 2D .png slices. The extracted slices were then screened and preprocessed before training. Note that the ground-truth segmentation distinguishes three classes: background, kidney, and tumor. Since we only want to separate kidney from background, rather than segment all three classes, the tumor label is merged into the kidney label during conversion to simplify the problem.
From the 210 3D CT scans, 10 slices were extracted per scan, yielding 2,100 2D png images in total.
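For reference, a minimal sketch of this slice-extraction step might look as follows. The case_XXXXX/imaging.nii.gz file layout follows the official kits19 repository; the CT windowing values, the slice-selection rule, and the output naming are my own illustrative choices, not the exact preprocessing used in this post.

import os
import nibabel as nib
import numpy as np
from PIL import Image

def save_case_slices(case_dir, out_img_dir, out_mask_dir, n_slices=10):
    # Load the 3D volume and its segmentation (kits19 layout: first axis is axial)
    vol = nib.load(os.path.join(case_dir, 'imaging.nii.gz')).get_fdata()
    seg = nib.load(os.path.join(case_dir, 'segmentation.nii.gz')).get_fdata()

    # Pick n_slices axial slices spread over the range that contains kidney voxels
    kidney_slices = np.where(seg.any(axis=(1, 2)))[0]
    chosen = np.linspace(kidney_slices[0], kidney_slices[-1], n_slices).astype(int)

    case = os.path.basename(case_dir)
    for i in chosen:
        # Window the CT slice (illustrative soft-tissue window) and normalize to 0-255
        img = np.clip(vol[i], -200, 300)
        img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(np.uint8)
        # Merge the tumor label (2) into the kidney label (1) -> binary mask
        mask = ((seg[i] > 0) * 255).astype(np.uint8)
        Image.fromarray(img).save(os.path.join(out_img_dir, f'{case}_{i}.png'))
        Image.fromarray(mask).save(os.path.join(out_mask_dir, f'{case}_{i}.png'))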

Network Structure and Code
Network Structure
U-Net is the classic encoder-decoder architecture, with skip connections linking the encoder and decoder paths. For the detailed architecture, see my other post: 《深度学习图像语义分割网络总结:U-Net与V-Net的Pytorch实现》.
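To make the skip connections concrete, here is a minimal sketch of one U-Net decoder step (a simplified illustration, not the exact code from the repo): the feature map is upsampled, concatenated with the matching encoder feature map, and passed through two convolutions.

import torch
import torch.nn as nn

class UpBlock(nn.Module):
    """One U-Net decoder step: upsample, concatenate skip features, convolve."""
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + skip_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)
        # Skip connection: concatenate encoder features along the channel dimension
        # (assumes the upsampled map matches the skip map's spatial size)
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)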
Training Code
The training code below is adapted from https://github.com/milesial/Pytorch-UNet on GitHub.
import argparse
import logging
import os
import sys

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from visdom import Visdom

from unet import UNet
from utils.dataset import BasicDataset
from dice_loss import dice_coeff  # provided by dice_loss.py in the milesial/Pytorch-UNet repo
from torch.utils.data import DataLoader, random_split

# Raw strings so the backslashes in the Windows paths are not treated as escape sequences
dir_img = r'D:\Dataset\CT-KiTS19\KiTS19\kits19-master\png_datasize\train_choose\slice_png'
dir_mask = r'D:\Dataset\CT-KiTS19\KiTS19\kits19-master\png_datasize\train_choose\mask_png'
dir_checkpoint = 'checkpoints/'


def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.2,
              save_cp=True,
              img_scale=1):
    # Split the dataset into training and validation subsets
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)

    # Visdom windows for the loss curve, the learning rate, and the validation Dice score
    viz = Visdom()
    viz.line([0.], [0.], win='train_loss', opts=dict(title='train_loss'))
    viz.line([0.], [0.], win='learning_rate', opts=dict(title='learning_rate'))
    viz.line([0.], [0.], win='Dice/test', opts=dict(title='Dice/test'))

    global_step = 0
    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # Maximize the Dice score in the binary case, minimize the loss otherwise
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                # viz.images (not viz.image) handles batched (B, C, H, W) tensors
                viz.images(imgs, win='imgs/train')
                viz.images(true_masks, win='masks/true/train')
                viz.images(masks_pred, win='masks/pred/train')

                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                viz.line([loss.item()], [global_step], win='train_loss', update='append')

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                # nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                # Run validation roughly ten times per epoch
                if global_step % (n_train // (10 * batch_size)) == 0:
                    val_score = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    viz.line([optimizer.param_groups[0]['lr']], [global_step], win='learning_rate', update='append')

                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        viz.line([val_score], [global_step], win='Dice/test', update='append')

                    viz.images(imgs, win='images')
                    if net.n_classes == 1:
                        viz.images(true_masks, win='masks/true')
                        # Binarize the sigmoid output at 0.5 before displaying the prediction
                        viz.images((torch.sigmoid(masks_pred) > 0.5), win='masks/pred')

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')


def eval_net(net, loader, device):
    """Evaluation without the densecrf with the dice coefficient"""
    net.eval()
    mask_type = torch.float32  # binary masks
    n_val = len(loader)  # the number of batches
    tot = 0

    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for batch in loader:
            imgs, true_masks = batch['image'], batch['mask']
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=mask_type)

            with torch.no_grad():
                mask_pred = net(imgs)

            # Threshold the sigmoid output and accumulate the Dice score
            pred = torch.sigmoid(mask_pred)
            pred = (pred > 0.5).float()
            tot += dice_coeff(pred, true_masks).item()
            pbar.update()

    net.train()
    return tot / n_val


def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,
                        help='Learning rate', dest='lr')
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-s', '--scale', dest='scale', type=float, default=1,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=20.0,
                        help='Percent of the data that is used as validation (0-100)')
    return parser.parse_args()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=1 for single-channel CT slices
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    net = UNet(n_channels=1, n_classes=1, bilinear=True)
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        net.load_state_dict(
            torch.load(args.load, map_location=device)
        )
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
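One piece not shown above is dice_coeff, which eval_net uses for the validation score; in the milesial/Pytorch-UNet repository it is provided by dice_loss.py. If you want a self-contained script, a minimal batch-averaged substitute (my sketch, not the repo's exact implementation) could look like this:

def dice_coeff(pred, target, eps=1e-6):
    """Batch-averaged Dice coefficient for binarized predictions."""
    # Flatten everything except the batch dimension
    pred = pred.contiguous().view(pred.shape[0], -1)
    target = target.contiguous().view(target.shape[0], -1)
    inter = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1)
    # eps keeps the ratio defined when both masks are empty
    return ((2 * inter + eps) / (union + eps)).mean()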
Training Process
Parameter settings:
- Train/validation split: 8:2 (1,680 / 420 images)
- Batch size: 2
- Learning-rate schedule: torch.optim.lr_scheduler.ReduceLROnPlateau, which lowers the learning rate once the monitored metric (here the validation Dice score) stops improving for a few validation rounds.
- Loss: BCEWithLogitsLoss, which measures the binary cross-entropy between the target masks and the network's raw outputs; it applies the sigmoid internally, so the logits are passed in directly (see the snippet below).
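As a quick illustration of what BCEWithLogitsLoss computes: it is numerically equivalent to applying a sigmoid and then BCELoss, just fused into one numerically stable operation, which is why the network's raw logits are fed to it directly.

import torch
import torch.nn as nn

logits = torch.randn(2, 1, 4, 4)                    # raw network outputs
target = torch.randint(0, 2, (2, 1, 4, 4)).float()  # binary ground-truth masks

fused = nn.BCEWithLogitsLoss()(logits, target)
split = nn.BCELoss()(torch.sigmoid(logits), target)
assert torch.allclose(fused, split, atol=1e-6)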
Visualization
Training is visualized with Visdom. At the very start of training, the true target masks are shown on the left and the network's predictions on the right; as expected, the early predictions are poor.
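Note that Visdom renders nothing unless its server is already running; start it in a separate terminal before launching the training script, then open http://localhost:8097 in a browser:

python -m visdom.server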

After the first epoch, the results are shown in the figure below: the red curve marks the training loss and the blue curve the validation score, and the image panels show the original images, the true masks (T), and the predicted masks (P). Note that how the predicted mask looks depends on whether the sigmoid output has been binarized at 0.5.

After the fourth epoch, the predicted masks are already very close to the true masks.

Results Analysis

The experiment achieves a final Dice coefficient of 0.832.
