基于U-net的肝脏肿瘤分割实战(Pytorch实现)
发布时间
阅读量:
阅读量
这是我的本科毕业设计中的一个重要组成部分。我去年基于TensorFlow开发了一种模型,在最近几天我又尝试使用PyTorch实现了类似功能。肝脏肿瘤分割作为医学影像分割中的核心内容具有重要意义——它能够从CT或MRI影像中准确提取肝脏区域及其肿瘤区域,在临床治疗方案制定中仍然发挥着关键作用。基于最经典的U-Net架构进行设计,在这一领域仍具重要价值。
网络结构

跳跃连接部分主要涉及Encoder网络提取的有效特征图与Decoder网络解码器端产生的特征图之间的信息传递机制,在通道维度上进行融合处理
数据集介绍
由于医学影像的独特属性,在大多数情况下其详细数据不会被广泛分享。现将一个公开获取的数据库——3D-IRCADB介绍如下:该数据库包含来自20名患者的CT扫描图像序列;具体而言,每个患者大约拥有几百张左右的扫描图像。这些扫描通常以 DICOM 格式存储;此外,在进行深度学习分析时,默认也会使用对应的掩膜(mask),这些掩膜同样采用 DICOM 格式存储。为了便于后续处理和分析,在此将原始 DICOM 格式的图像转换为 PNG 格式保存;长图:image:

label:

转换的过程会用到一个第三方库,不详细介绍,会重新写一篇博客。
整个训练流程及其网络架构与我之前的两篇博客几乎一模一样;仅对dataloader进行了稍作改动;这里不做详细说明;完整的源代码可以在我的GitHub仓库中获取;基于2000张图像经过80个epoch的训练所得结果看上去相当出色。
Dataset.py
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/9 下午1:28
"""
使用的是视网膜血管分割的数据集,训练集就二十张图像
"""
import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import random
class _3Dircadb_Dataset(Dataset):
def __init__(self,data_path):
self.data_path = data_path
self.image_path = glob.glob(os.path.join(data_path,'image/*.png'))
self.label_path = glob.glob(os.path.join(data_path,'label/*.png'))
def augment(self,image,mode):
"""
:param image:
:param mode: 1 :水平翻转 0 : 垂直翻转 -1 水平+垂直翻转
:return:
"""
file = cv2.flip(image,mode)
return file
def __len__(self):
return len(self.image_path)
def __getitem__(self,index):
image_path = self.image_path[index]
label_path = self.label_path[index]
#读取
image = cv2.imread(image_path)
label = cv2.imread(label_path)
#转为灰度图
image = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
label = cv2.cvtColor(label,cv2.COLOR_BGR2GRAY)
# 随机进行数据增强,2时不做数据增强
mode = random.choice([-1, 0, 1, 2])
if mode != 2:
image = self.augment(image, mode)
label = self.augment(label, mode)
image = image.reshape(1, image.shape[0], image.shape[1])
label = label.reshape(1, label.shape[0], label.shape[1])
# 标签二值化 ,将255 -> 1
label = label / 255
return image, label
# isbi = _3Dircadb_Dataset(r"Dataset/train/")
# print(len(isbi))
# train_loader = torch.utils.data.DataLoader(isbi,
# batch_size=2,
# shuffle=True)
# for image ,label in train_loader:
# print(image.shape)
test
image

label

segment result

全部评论 (0)
还没有任何评论哟~
