Advertisement

医学图像算法之基于UNet3+(UNet+++)的肝脏CT分割

阅读量:

第一步:准备数据

肝脏CT分割数据集,总共有400张图像。
702a32fae8714b47aa4403218acae96d.png

第二步:搭建模型

UNet3+主要参考了UNet和UNet++两种网络结构。尽管UNet++采用了嵌套和密集跳跃连接的网络结构(见图1(b)红色三角区域),但它没有直接从多尺度特征中提取足够多的信息。就我的理解而言,UNet++虽然名义上通过嵌套和密集跳跃连接利用了多尺度信息,但本质上这些连接大多是短连接,且几乎都对解码特征进行了再次处理,再加上各个连接之间的融合,多尺度的原始特征并没有得到很好的利用,信息处理有些矫枉过正,甚至有所丢失。UNet3+利用了全尺度的跳跃连接(skip connection)和深度监督(deep supervision)。全尺度的跳跃连接把来自不同尺度特征图的高级语义与低级语义直接结合(当然需要必要的上采样或下采样操作);而深度监督则从多尺度聚合的特征图中学习层次化表示。注意:UNet++和UNet3+都用到了深度监督,但监督的位置完全不同,从图1(b)、(c)中的Sup部分可以清楚地看到两者的区别。
ea1d6c7e75eb4f1ca215aea4e6340671.png

在UNet3+中,可以从全尺度上同时捕获细粒度的细节和粗粒度的语义。为了进一步从全尺度的聚合特征图中学习层次化表示,每个侧边输出(side output)都与一个混合损失函数相连接。这有助于实现精确分割,特别是对于在医学图像体数据中以不同尺度出现的器官。从图1中也可以看出,UNet3+的参数量明显小于UNet++。

第三步:代码

1)损失函数为:交叉熵损失函数(网络最后一层输出已经过sigmoid,单通道输出时对应二值交叉熵,例如nn.BCELoss)

2)网络代码:

复制代码
 class UNet3Plus(nn.Module):

    
     def __init__(self, n_channels=3, n_classes=1, bilinear=True, feature_scale=4,
    
              is_deconv=True, is_batchnorm=True):
    
     super(UNet3Plus, self).__init__()
    
     self.n_channels = n_channels
    
     self.n_classes = n_classes
    
     self.bilinear = bilinear
    
     self.feature_scale = feature_scale
    
     self.is_deconv = is_deconv
    
     self.is_batchnorm = is_batchnorm
    
     filters = [16, 32, 64, 128, 256]
    
  
    
     ## -------------Encoder--------------
    
     self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm)
    
     self.maxpool1 = nn.MaxPool2d(kernel_size=2)
    
  
    
     self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
    
     self.maxpool2 = nn.MaxPool2d(kernel_size=2)
    
  
    
     self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
    
     self.maxpool3 = nn.MaxPool2d(kernel_size=2)
    
  
    
     self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
    
     self.maxpool4 = nn.MaxPool2d(kernel_size=2)
    
  
    
     self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)
    
  
    
     ## -------------Decoder--------------
    
     self.CatChannels = filters[0]
    
     self.CatBlocks = 5
    
     self.UpChannels = self.CatChannels * self.CatBlocks
    
  
    
     '''stage 4d'''
    
     # h1->320*320, hd4->40*40, Pooling 8 times
    
     self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
    
     self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
    
     self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.h1_PT_hd4_relu = nn.ReLU(inplace=True)
    
  
    
     # h2->160*160, hd4->40*40, Pooling 4 times
    
     self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
    
     self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
    
     self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.h2_PT_hd4_relu = nn.ReLU(inplace=True)
    
  
    
     # h3->80*80, hd4->40*40, Pooling 2 times
    
     self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
    
     self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
    
     self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.h3_PT_hd4_relu = nn.ReLU(inplace=True)
    
  
    
     # h4->40*40, hd4->40*40, Concatenation
    
     self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
    
     self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)
    
  
    
     # hd5->20*20, hd4->40*40, Upsample 2 times
    
     self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
    
     self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
    
     self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)
    
  
    
     # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
    
     self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
    
     self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
    
     self.relu4d_1 = nn.ReLU(inplace=True)
    
  
    
     '''stage 3d'''
    
     # h1->320*320, hd3->80*80, Pooling 4 times
    
     self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
    
     self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
    
     self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.h1_PT_hd3_relu = nn.ReLU(inplace=True)
    
  
    
     # h2->160*160, hd3->80*80, Pooling 2 times
    
     self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
    
     self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
    
     self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.h2_PT_hd3_relu = nn.ReLU(inplace=True)
    
  
    
     # h3->80*80, hd3->80*80, Concatenation
    
     self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
    
     self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)
    
  
    
     # hd4->40*40, hd4->80*80, Upsample 2 times
    
     self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
    
     self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    
     self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)
    
  
    
     # hd5->20*20, hd4->80*80, Upsample 4 times
    
     self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
    
     self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
    
     self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)
    
  
    
     # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
    
     self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
    
     self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
    
     self.relu3d_1 = nn.ReLU(inplace=True)
    
  
    
     '''stage 2d '''
    
     # h1->320*320, hd2->160*160, Pooling 2 times
    
     self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
    
     self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
    
     self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.h1_PT_hd2_relu = nn.ReLU(inplace=True)
    
  
    
     # h2->160*160, hd2->160*160, Concatenation
    
     self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
    
     self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)
    
  
    
     # hd3->80*80, hd2->160*160, Upsample 2 times
    
     self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
    
     self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    
     self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)
    
  
    
     # hd4->40*40, hd2->160*160, Upsample 4 times
    
     self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
    
     self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    
     self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)
    
  
    
     # hd5->20*20, hd2->160*160, Upsample 8 times
    
     self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
    
     self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
    
     self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)
    
  
    
     # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
    
     self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
    
     self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
    
     self.relu2d_1 = nn.ReLU(inplace=True)
    
  
    
     '''stage 1d'''
    
     # h1->320*320, hd1->320*320, Concatenation
    
     self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
    
     self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)
    
  
    
     # hd2->160*160, hd1->320*320, Upsample 2 times
    
     self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
    
     self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    
     self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)
    
  
    
     # hd3->80*80, hd1->320*320, Upsample 4 times
    
     self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
    
     self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    
     self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)
    
  
    
     # hd4->40*40, hd1->320*320, Upsample 8 times
    
     self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
    
     self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    
     self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)
    
  
    
     # hd5->20*20, hd1->320*320, Upsample 16 times
    
     self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear')  # 14*14
    
     self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
    
     self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
    
     self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)
    
  
    
     # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
    
     self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
    
     self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
    
     self.relu1d_1 = nn.ReLU(inplace=True)
    
  
    
     # output
    
     self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
    
  
    
     # initialise weights
    
     for m in self.modules():
    
         if isinstance(m, nn.Conv2d):
    
             init_weights(m, init_type='kaiming')
    
         elif isinstance(m, nn.BatchNorm2d):
    
             init_weights(m, init_type='kaiming')
    
  
    
     def forward(self, inputs):
    
     ## -------------Encoder-------------
    
     h1 = self.conv1(inputs)  # h1->320*320*64
    
  
    
     h2 = self.maxpool1(h1)
    
     h2 = self.conv2(h2)  # h2->160*160*128
    
  
    
     h3 = self.maxpool2(h2)
    
     h3 = self.conv3(h3)  # h3->80*80*256
    
  
    
     h4 = self.maxpool3(h3)
    
     h4 = self.conv4(h4)  # h4->40*40*512
    
  
    
     h5 = self.maxpool4(h4)
    
     hd5 = self.conv5(h5)  # h5->20*20*1024
    
  
    
     ## -------------Decoder-------------
    
     h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
    
     h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
    
     h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
    
     h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
    
     hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
    
     hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(
    
         torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1))))  # hd4->40*40*UpChannels
    
  
    
     h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
    
     h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
    
     h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
    
     hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
    
     hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
    
     hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(
    
         torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1))))  # hd3->80*80*UpChannels
    
  
    
     h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
    
     h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
    
     hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
    
     hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
    
     hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
    
     hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(
    
         torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1))))  # hd2->160*160*UpChannels
    
  
    
     h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
    
     hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
    
     hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
    
     hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
    
     hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
    
     hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(
    
         torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1))))  # hd1->320*320*UpChannels
    
  
    
     d1 = self.outconv1(hd1)  # d1->320*320*n_classes
    
     return F.sigmoid(d1)
    
    
    
    

第四步:统计一些指标(训练过程中的loss和mIoU)
f041a8103b384b6ab04ab18b2e24f1b6.png
c8dc6caedb2943c5be0877322048fbc0.png

第五步:搭建GUI界面
0609e787228b4248998127bdba9cfd76.png
5337f0a706a6494e918432414358d9da.png

第六步:整个工程的内容
bab81ebbb6fb40f5b6378345224448a0.png

【源码下载】GitCode,关键词【肝脏CT分割】

整套项目源码内容包含

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

全部评论 (0)

还没有任何评论哟~