Advertisement

2D UNet 3+ 脑肿瘤分割的 PyTorch 实现

阅读量:

一、网络介绍

建议先研读 UNet 3+ 原论文(Huang et al., "UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation", ICASSP 2020),其中的全尺度跳跃连接等技术细节及其优化方案值得深入研究。

原代码链接:链接

二、BraTs数据预处理

本文的训练集与验证集均来源于 BraTS2018 的训练数据集(其中 HGG 210 例、LGG 75 例)。

然而,BraTS 只提供了训练数据集,缺乏带标注的测试数据。如果从训练集中预留一部分作为测试集进行评估,会减少训练数据量;而样本数量本就有限,模型容易过拟合,即在训练阶段表现出色、在实际应用中效果不佳。经过一番思考,我设计了如下解决方案。

BraTS2019 的训练集是在 BraTS2018 的基础上扩展而来的:HGG 样本增加了 49 例,LGG 样本仅增加了 1 例。因此,我将这些新增样本用作测试数据。

下面提供原始数据的百度网盘下载链接:

BraTs18数据集下载地址(不包含测试集,提供的验证集无GT)

链接:https://pan.baidu.com/s/1Ry41OVl9VLOMzhQQR9qXuA 提取码:qvmo

BraTs19数据集下载地址如下(不包含测试集,提供的验证集无GT)

链接: https://pan.baidu.com/s/1S5XGTdHkwFnagKS-5vWYBg 提取码: 2333

数据的预处理以及实现代码

下载上面两年的数据后,按以下方法进行预处理。

完整的 Python 代码(在 Jupyter Notebook 中运行)见 GitHub 仓库:https://github.com/Merofine/BraTS2Dpreprocessing 。其中 GetTrainingSets.ipynb 用于生成训练集与验证集:https://github.com/Merofine/BraTS2Dpreprocessing/blob/master/GetTrainingSets.ipynb

GetTestingSetsFrom2019.ipynb → 生成测试集(即 BraTS2019 相对 BraTS2018 新增的病例)

代码执行完毕后,即可得到 .npy 格式的数据(处理好的数据也可直接从下面的链接下载):

链接:https://pan.baidu.com/s/1W3rcl9I-Y8DwWu5p4o--cw 密码:hfe7

三、运行环境的安装

1、系统环境:Windows 10 + CUDA 9.2 + cuDNN 7 + Anaconda

2、使用 Anaconda 命令快速配置环境(需先下载下面的环境配置文件)

四、核心代码

# -*- coding: utf-8 -*-

import torch

import torch.nn as nn

import torch.nn.functional as F

from layers import unetConv2

from init_weights import init_weights

'''

UNet 3+

'''

class UNet_3Plus(nn.Module):
    """UNet 3+ with full-scale skip connections for 2D segmentation.

    The encoder has five stages (``conv1`` .. ``conv5``) separated by 2x2 max
    pooling. The decoder has four stages (``hd4`` .. ``hd1``). Every decoder
    stage fuses five feature maps taken from *all* scales:

    * shallower encoder maps, max-pooled down to the stage's resolution (``PT``);
    * the same-resolution encoder map, used directly (``Cat``);
    * deeper decoder maps and the bottleneck ``hd5``, bilinearly upsampled (``UT``).

    Each of the five maps goes through a 3x3 conv + BN + ReLU that projects it
    to ``CatChannels`` channels. The results are concatenated into
    ``UpChannels = 5 * CatChannels`` channels and fused by another
    3x3 conv + BN + ReLU. A final 3x3 conv maps ``hd1`` to ``n_classes`` channels.

    Resolutions relative to an ``H x W`` input:
    h1/hd1 = H, h2/hd2 = H/2, h3/hd3 = H/4, h4/hd4 = H/8, hd5 = H/16.
    H and W should be divisible by 16. Otherwise the pooled (``ceil_mode=True``)
    and upsampled branches do not line up and ``torch.cat`` fails.

    Args:
        args: Run configuration. It is stored as ``self.args`` but is not
            otherwise read by this class.
        in_channels: Number of input channels. The default of 4 presumably
            matches the four BraTS MRI modalities; confirm against the
            data pipeline.
        n_classes: Number of output channels produced by ``outconv1``.
        feature_scale: Stored as ``self.feature_scale``; not used here.
        is_deconv: Stored as ``self.is_deconv``; not used here, because all
            upsampling is bilinear.
        is_batchnorm: Forwarded to every encoder ``unetConv2`` block.
    """

    def __init__(self, args, *, in_channels=4, n_classes=3, feature_scale=4,
                 is_deconv=True, is_batchnorm=True):
        # The dunder name matters: a method called ``init`` would never be
        # invoked by ``UNet_3Plus(args)``, so no layers would be built.
        super().__init__()

        self.args = args
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale
        self.n_classes = n_classes

        filters = [64, 128, 256, 512, 1024]

        # ------------- Encoder -------------
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)
        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)

        # ------------- Decoder -------------
        # Every branch is projected to CatChannels before concatenation, so
        # each fused stage has CatChannels * CatBlocks channels.
        self.CatChannels = filters[0]
        self.CatBlocks = 5
        self.UpChannels = self.CatChannels * self.CatBlocks

        # NOTE: align_corners=False is what PyTorch uses for bilinear mode
        # when it is left unset. Passing it explicitly keeps the behaviour
        # unchanged and silences the UserWarning.

        # ---- stage 4d (output resolution H/8) ----
        # h1 (H) -> H/8: max-pool by 8
        self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
        self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd4_relu = nn.ReLU(inplace=True)

        # h2 (H/2) -> H/8: max-pool by 4
        self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_PT_hd4_relu = nn.ReLU(inplace=True)

        # h3 (H/4) -> H/8: max-pool by 2
        self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
        self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h3_PT_hd4_relu = nn.ReLU(inplace=True)

        # h4 (H/8) -> H/8: same resolution, used directly
        self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
        self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)

        # hd5 (H/16) -> H/8: upsample by 2
        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
        self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
        self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu4d_1 = nn.ReLU(inplace=True)

        # ---- stage 3d (output resolution H/4) ----
        # h1 (H) -> H/4: max-pool by 4
        self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd3_relu = nn.ReLU(inplace=True)

        # h2 (H/2) -> H/4: max-pool by 2
        self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_PT_hd3_relu = nn.ReLU(inplace=True)

        # h3 (H/4) -> H/4: same resolution, used directly
        self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
        self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)

        # hd4 (H/8, UpChannels) -> H/4: upsample by 2
        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)

        # hd5 (H/16) -> H/4: upsample by 4
        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
        self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
        self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu3d_1 = nn.ReLU(inplace=True)

        # ---- stage 2d (output resolution H/2) ----
        # h1 (H) -> H/2: max-pool by 2
        self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd2_relu = nn.ReLU(inplace=True)

        # h2 (H/2) -> H/2: same resolution, used directly
        self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)

        # hd3 (H/4, UpChannels) -> H/2: upsample by 2
        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)

        # hd4 (H/8, UpChannels) -> H/2: upsample by 4
        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)

        # hd5 (H/16) -> H/2: upsample by 8
        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)
        self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
        self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
        self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu2d_1 = nn.ReLU(inplace=True)

        # ---- stage 1d (output resolution H) ----
        # h1 (H) -> H: same resolution, used directly
        self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)

        # hd2 (H/2, UpChannels) -> H: upsample by 2
        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd3 (H/4, UpChannels) -> H: upsample by 4
        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd4 (H/8, UpChannels) -> H: upsample by 8
        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)
        self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd5 (H/16) -> H: upsample by 16
        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False)
        self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)

        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
        self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
        self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu1d_1 = nn.ReLU(inplace=True)

        # ---- output ----
        self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)

        # ---- initialise weights ----
        # Conv and BatchNorm layers are both handed to the project's
        # init_weights helper with init_type='kaiming'.
        # NOTE(review): how that helper treats BatchNorm is defined outside
        # this file; confirm there.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        """Run the network and return per-pixel logits.

        Args:
            inputs: A batch with ``in_channels`` channels; assumed to be an
                ``(N, in_channels, H, W)`` tensor with H and W divisible by 16.

        Returns:
            The output of ``outconv1``: ``n_classes`` channels at the input
            resolution. No activation is applied.
        """
        # ------------- Encoder -------------
        h1 = self.conv1(inputs)                    # H,    64 ch
        h2 = self.conv2(self.maxpool1(h1))         # H/2,  128 ch
        h3 = self.conv3(self.maxpool2(h2))         # H/4,  256 ch
        h4 = self.conv4(self.maxpool3(h3))         # H/8,  512 ch
        hd5 = self.conv5(self.maxpool4(h4))        # H/16, 1024 ch

        # ------------- Decoder: stage 4d (H/8) -------------
        h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(
            torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1))))  # H/8, UpChannels

        # ------------- Decoder: stage 3d (H/4) -------------
        h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(
            torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1))))  # H/4, UpChannels

        # ------------- Decoder: stage 2d (H/2) -------------
        h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
        hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(
            torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1))))  # H/2, UpChannels

        # ------------- Decoder: stage 1d (H) -------------
        h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(
            torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1))))  # H, UpChannels

        d1 = self.outconv1(hd1)  # H, n_classes
        return d1

完整代码请私聊博主~(QQ:704783475、博主想恰杯奶茶)

五、训练

python train.py --arch="UNet_3Plus" --dataset="Jiu0Monkey"

六、测试

python test.py --name="jiu0Monkey_UNet_3Plus_woDS"

七、与其它模型对比

全部评论 (0)

还没有任何评论哟~