Advertisement

医学图像算法之基于Unet++的息肉分割

阅读量:

第一步:准备数据

息肉分割数据,总共有1000张

39d92256800a4ad088367e709819cf16.png

第二步:搭建模型

UNet++是一种专为克服上述局限性而设计的新颖通用图像分割体系结构。如图所示, UNet++由不同深度的U-Net模块构成,其解码器通过重新设计的跳跃连接实现了相同分辨率下的密集连接。该架构引入了两项显著优势:首先, UNet++不便于明确选择网络深度参数,因为它整合了多种不同深度的U-Net模块作为其基础架构;这些模块共享部分编码器信息,而解码器则进行了有机整合。在经过深度监督训练后,UNet++能够同时训练所有层次上的U-Net模块,从而充分利用共享图像表示的优势,从而显著提升了整体分割性能的同时还减少了推理时所需的计算资源消耗;其次,该架构避免了传统方法中由于不必要的限制性跳跃连接所带来的性能损失,即传统方法只能融合来自编码器和解码器相同分辨率的特征图;而UNet++采用了经过重新设计后的跳跃连接系统,能够在解码节点处提供不同分辨率的特征图集合,从而让聚合层能够自主决定如何整合这些多来源特征图与当前解码器层面的特征表示;这种设计使得UNet++能够在相同分辨率下实现层次化解码操作;此外,在实验评估方面作者选取了六个分割数据集以及多个主干模型进行广泛的性能测试

faa9f10f23dd4216a86f9b0f81958a33.png

五个贡献:

在UNet++架构中加入了内置可变深度的U-Net集合组,在处理不同尺寸目标时展现出显著性能提升。相较于固定深度的传统U-Net模型而言这一改进更具灵活性。
解码器中的跳跃连接进行了优化以实现特征图的有效融合从而克服了传统方法仅依赖相同比例特征图融合的局限性。
本研究设计了一种高效的剪枝策略用于优化经过训练的UNet++模型成功实现了加速推理速度的同时不牺牲性能目标。
通过在同一架构下并行训练多深度版本的UNet++网络能够激发各组件网络之间的协同学习效果较于单独训练独立模型能显著提升整体性能表现。
该方法展示了对多种主干编码器的支持并将其扩展应用于CT MRI和电子显微镜等多种医学成像领域实现了广泛的适用性和实用价值。

第三步:代码

1)损失函数为:交叉熵损失函数+dice_loss

2)网络代码:

复制代码
 class UnetPlusPlus(nn.Module):

    
     def __init__(self, num_classes, deep_supervision=False):
    
     super(UnetPlusPlus, self).__init__()
    
     self.num_classes = num_classes
    
     self.deep_supervision = deep_supervision
    
     self.filters = [64, 128, 256, 512, 1024]
    
  
    
     self.CONV3_1 = ContinusParalleConv(512 * 2, 512, pre_Batch_Norm=True)
    
  
    
     self.CONV2_2 = ContinusParalleConv(256 * 3, 256, pre_Batch_Norm=True)
    
     self.CONV2_1 = ContinusParalleConv(256 * 2, 256, pre_Batch_Norm=True)
    
  
    
     self.CONV1_1 = ContinusParalleConv(128 * 2, 128, pre_Batch_Norm=True)
    
     self.CONV1_2 = ContinusParalleConv(128 * 3, 128, pre_Batch_Norm=True)
    
     self.CONV1_3 = ContinusParalleConv(128 * 4, 128, pre_Batch_Norm=True)
    
  
    
     self.CONV0_1 = ContinusParalleConv(64 * 2, 64, pre_Batch_Norm=True)
    
     self.CONV0_2 = ContinusParalleConv(64 * 3, 64, pre_Batch_Norm=True)
    
     self.CONV0_3 = ContinusParalleConv(64 * 4, 64, pre_Batch_Norm=True)
    
     self.CONV0_4 = ContinusParalleConv(64 * 5, 64, pre_Batch_Norm=True)
    
  
    
     self.stage_0 = ContinusParalleConv(3, 64, pre_Batch_Norm=False)
    
     self.stage_1 = ContinusParalleConv(64, 128, pre_Batch_Norm=False)
    
     self.stage_2 = ContinusParalleConv(128, 256, pre_Batch_Norm=False)
    
     self.stage_3 = ContinusParalleConv(256, 512, pre_Batch_Norm=False)
    
     self.stage_4 = ContinusParalleConv(512, 1024, pre_Batch_Norm=False)
    
  
    
     self.pool = nn.MaxPool2d(2)
    
  
    
     self.upsample_3_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1)
    
  
    
     self.upsample_2_1 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1)
    
     self.upsample_2_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1)
    
  
    
     self.upsample_1_1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
    
     self.upsample_1_2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
    
     self.upsample_1_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
    
  
    
     self.upsample_0_1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
    
     self.upsample_0_2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
    
     self.upsample_0_3 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
    
     self.upsample_0_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
    
  
    
     # 分割头
    
     self.final_super_0_1 = nn.Sequential(
    
         nn.BatchNorm2d(64),
    
         nn.ReLU(),
    
         nn.Conv2d(64, self.num_classes, 3, padding=1),
    
     )
    
     self.final_super_0_2 = nn.Sequential(
    
         nn.BatchNorm2d(64),
    
         nn.ReLU(),
    
         nn.Conv2d(64, self.num_classes, 3, padding=1),
    
     )
    
     self.final_super_0_3 = nn.Sequential(
    
         nn.BatchNorm2d(64),
    
         nn.ReLU(),
    
         nn.Conv2d(64, self.num_classes, 3, padding=1),
    
     )
    
     self.final_super_0_4 = nn.Sequential(
    
         nn.BatchNorm2d(64),
    
         nn.ReLU(),
    
         nn.Conv2d(64, self.num_classes, 3, padding=1),
    
     )
    
  
    
     def forward(self, x):
    
     x_0_0 = self.stage_0(x)
    
     x_1_0 = self.stage_1(self.pool(x_0_0))
    
     x_2_0 = self.stage_2(self.pool(x_1_0))
    
     x_3_0 = self.stage_3(self.pool(x_2_0))
    
     x_4_0 = self.stage_4(self.pool(x_3_0))
    
  
    
     x_0_1 = torch.cat([self.upsample_0_1(x_1_0), x_0_0], 1)
    
     x_0_1 = self.CONV0_1(x_0_1)
    
  
    
     x_1_1 = torch.cat([self.upsample_1_1(x_2_0), x_1_0], 1)
    
     x_1_1 = self.CONV1_1(x_1_1)
    
  
    
     x_2_1 = torch.cat([self.upsample_2_1(x_3_0), x_2_0], 1)
    
     x_2_1 = self.CONV2_1(x_2_1)
    
  
    
     x_3_1 = torch.cat([self.upsample_3_1(x_4_0), x_3_0], 1)
    
     x_3_1 = self.CONV3_1(x_3_1)
    
  
    
     x_2_2 = torch.cat([self.upsample_2_2(x_3_1), x_2_0, x_2_1], 1)
    
     x_2_2 = self.CONV2_2(x_2_2)
    
  
    
     x_1_2 = torch.cat([self.upsample_1_2(x_2_1), x_1_0, x_1_1], 1)
    
     x_1_2 = self.CONV1_2(x_1_2)
    
  
    
     x_1_3 = torch.cat([self.upsample_1_3(x_2_2), x_1_0, x_1_1, x_1_2], 1)
    
     x_1_3 = self.CONV1_3(x_1_3)
    
  
    
     x_0_2 = torch.cat([self.upsample_0_2(x_1_1), x_0_0, x_0_1], 1)
    
     x_0_2 = self.CONV0_2(x_0_2)
    
  
    
     x_0_3 = torch.cat([self.upsample_0_3(x_1_2), x_0_0, x_0_1, x_0_2], 1)
    
     x_0_3 = self.CONV0_3(x_0_3)
    
  
    
     x_0_4 = torch.cat([self.upsample_0_4(x_1_3), x_0_0, x_0_1, x_0_2, x_0_3], 1)
    
     x_0_4 = self.CONV0_4(x_0_4)
    
  
    
     if self.deep_supervision:
    
         out_put1 = self.final_super_0_1(x_0_1)
    
         out_put2 = self.final_super_0_2(x_0_2)
    
         out_put3 = self.final_super_0_3(x_0_3)
    
         out_put4 = self.final_super_0_4(x_0_4)
    
         return [out_put1, out_put2, out_put3, out_put4]
    
     else:
    
         return self.final_super_0_4(x_0_4)

第四步:统计一些指标(训练过程中的loss和miou)

f14b605f721b4056a0f2dbae42eae1b5.png
25f01622a64d425c847efb6acab29ce4.png

第五步:搭建GUI界面

d285bda62f9a4e269f75e34f4a37bb65.png
9dc1039bd3f64610a4bdb83eaae330b6.png

第六步:整个工程的内容

392a7e51862c435eb8e9dfcfec8f4206.png

源码下载

在线获取源代码

在线获取源代码

整套项目源码内容包含

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

全部评论 (0)

还没有任何评论哟~