Advertisement

语义分割论文:U-Net: Convolutional Networks for Biomedical Image Segmentation (MICCAI2015)

阅读量:

U-Net(全称:Convolutional Networks for Biomedical Image Segmentation):这是一种用于生物医学图像分割的卷积神经网络(MICCAI 2015)。PDF版本:《Deep Learning in Medical Imaging》。此外,请参阅官方网页以获取更多细节。PyTorch实现可通过GitHub仓库访问。

在这里插入图片描述

特点:

  1. 捕捉上下文的收缩路径(contracting path);
  2. 实现精确定位的对称扩张路径(symmetric expanding path):扩张路径由 $2\times2$ 的上卷积组成,上卷积的输出通道数(output channels)为原先的一半,再与对应的特征图(裁剪后)串联起来(得到和原先一样多的channels),这会导致模型更大,需要更多内存;
  3. 可以对非常少的图像端对端地进行训练;
  4. 适合超大图像分割,适合医学图像分割:医学图像一般比较大,分割时不可能将整张原图直接输入网络,所以必须切成一张一张的小patch;在切小patch的时候,U-Net由于网络结构的原因适合采用有overlap的切图方式(见Fig.2);
在这里插入图片描述

PyTorch代码:

复制代码
    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    # @Time : 2020/7/8 13:51
    # @Author : liumin
    # @File : Unet.py
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    def Conv3x3BNReLU(in_channels, out_channels, stride, groups=1):
        """Build a 3x3 convolution -> batch norm -> ReLU stack.

        Args:
            in_channels: Number of channels of the input feature map.
            out_channels: Number of channels produced by the convolution.
            stride: Stride of the 3x3 convolution.
            groups: Number of convolution groups (1 means an ordinary convolution).

        Returns:
            An ``nn.Sequential`` holding ``Conv2d``, ``BatchNorm2d`` and ``ReLU``
            in that order.
        """
        # padding=1 with a 3x3 kernel keeps the spatial size when stride == 1.
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=groups,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    
    
    def Conv1x1BNReLU(in_channels, out_channels):
        """Build a 1x1 convolution -> batch norm -> ReLU stack.

        Args:
            in_channels: Number of channels of the input feature map.
            out_channels: Number of channels produced by the convolution.

        Returns:
            An ``nn.Sequential`` holding ``Conv2d``, ``BatchNorm2d`` and ``ReLU``
            in that order.
        """
        # A 1x1 kernel with stride 1 only remaps channels; spatial size is unchanged.
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    
    
    def Conv1x1BN(in_channels, out_channels):
        """Build a 1x1 convolution -> batch norm stack (no activation).

        Args:
            in_channels: Number of channels of the input feature map.
            out_channels: Number of channels produced by the convolution.

        Returns:
            An ``nn.Sequential`` holding ``Conv2d`` followed by ``BatchNorm2d``.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
            ),
            nn.BatchNorm2d(out_channels),
        )
    
    
    class DoubleConv(nn.Module):
        """Two stacked 3x3 stages: (convolution => [BN] => ReLU) * 2.

        Both convolutions use stride 1, so the spatial size of the input is
        preserved; only the channel count changes (in_channels -> out_channels).

        Args:
            in_channels: Number of channels of the input feature map.
            out_channels: Number of channels produced by both convolutions.
        """

        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.double_conv = nn.Sequential(
                Conv3x3BNReLU(in_channels, out_channels, stride=1),
                Conv3x3BNReLU(out_channels, out_channels, stride=1),
            )

        def forward(self, x):
            """Apply both conv-BN-ReLU stages to ``x`` and return the result."""
            return self.double_conv(x)
    
    
    class DownConv(nn.Module):
        """Encoder stage: DoubleConv followed by a 2x2 max-pool.

        The convolutions run first, at the input resolution; the pooling is
        applied to their output, so the returned tensor is already downsampled.

        Args:
            in_channels: Number of channels of the input feature map.
            out_channels: Number of channels produced by the double convolution.
            stride: Stride of the max-pool (its kernel size is fixed at 2).
        """

        def __init__(self, in_channels, out_channels, stride=2):
            super().__init__()
            self.pool = nn.MaxPool2d(kernel_size=2, stride=stride)
            self.double_conv = DoubleConv(in_channels, out_channels)

        def forward(self, x):
            """Return the max-pooled output of the double convolution of ``x``."""
            return self.pool(self.double_conv(x))
    
    
    class UpConv(nn.Module):
        """Decoder stage: reduce channels, upsample x2, merge a skip tensor, DoubleConv.

        Args:
            in_channels: Channels of the deeper feature map passed as ``x1``.
                After the 1x1 reduction ``x1`` has ``in_channels // 2`` channels,
                so the skip tensor ``x2`` must supply the remaining
                ``in_channels - in_channels // 2`` channels for the concatenation
                to match the ``DoubleConv`` input width.
            out_channels: Channels produced by the final double convolution.
            bilinear: If True, upsample with parameter-free bilinear
                interpolation; otherwise use a learned 2x2 transposed convolution.
        """

        def __init__(self, in_channels, out_channels, bilinear=True):
            super().__init__()
            # The 1x1 conv halves the channel count before upsampling, in both modes.
            self.reduce = Conv1x1BNReLU(in_channels, in_channels // 2)
            if bilinear:
                self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            else:
                self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

        def forward(self, x1, x2):
            """Upsample ``x1``, align it to ``x2`` spatially, concatenate and convolve.

            Args:
                x1: Feature map from the deeper stage (NCHW).
                x2: Skip-connection feature map from the encoder (NCHW).

            Returns:
                The double-convolved concatenation ``[x2, x1]`` along the channel axis.
            """
            x1 = self.up(self.reduce(x1))

            # Tensors are NCHW: dim 2 is height, dim 3 is width.
            diff_y = x2.size(2) - x1.size(2)
            diff_x = x2.size(3) - x1.size(3)

            # F.pad takes (left, right, top, bottom) for the last two dims. The
            # difference is split between both sides, with the odd pixel going
            # to the right/bottom, so x1 ends up the same H x W as x2.
            x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2])
            return self.conv(torch.cat([x2, x1], dim=1))
    
    
    class UNet(nn.Module):
        """U-Net style encoder/decoder for per-pixel classification.

        The encoder widens 64 -> 128 -> 256 -> 512 -> 1024 channels over four
        downsampling stages; the decoder mirrors it back to 64 channels, merging
        the matching encoder output at every stage, and a final 1x1 convolution
        maps to ``num_classes`` channels.

        Args:
            num_classes: Number of output channels (one score map per class).
            in_channels: Number of channels of the input image. Defaults to 3,
                the value that was previously hard-coded.
            bilinear: Passed to every ``UpConv``; True selects bilinear
                upsampling, False a learned transposed convolution. Defaults to
                True, the value that was previously hard-coded.
        """

        def __init__(self, num_classes, in_channels=3, bilinear=True):
            super().__init__()

            self.conv = DoubleConv(in_channels, 64)
            self.down1 = DownConv(64, 128)
            self.down2 = DownConv(128, 256)
            self.down3 = DownConv(256, 512)
            self.down4 = DownConv(512, 1024)
            self.up1 = UpConv(1024, 512, bilinear)
            self.up2 = UpConv(512, 256, bilinear)
            self.up3 = UpConv(256, 128, bilinear)
            self.up4 = UpConv(128, 64, bilinear)
            self.outconv = nn.Conv2d(64, num_classes, kernel_size=1)

        def forward(self, x):
            """Run the network on ``x`` and return the ``num_classes``-channel output."""
            # Encoder: keep every stage's output for the skip connections.
            x1 = self.conv(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x4 = self.down3(x3)
            x5 = self.down4(x4)

            # Decoder: each stage upsamples and merges the matching encoder output.
            xx = self.up1(x5, x4)
            xx = self.up2(xx, x3)
            xx = self.up3(xx, x2)
            xx = self.up4(xx, x1)
            return self.outconv(xx)
    
    
    if __name__ == '__main__':
        # Smoke test: build the model and push one random image through it.
        model = UNet(19)
        print(model)

        # Named ``sample`` rather than ``input`` to avoid shadowing the builtin.
        sample = torch.randn(1, 3, 572, 572)
        # No gradients are needed just to inspect the output shape.
        with torch.no_grad():
            out = model(sample)
        print(out.shape)
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

全部评论 (0)

还没有任何评论哟~