语义分割论文:U-Net: Convolutional Networks for Biomedical Image Segmentation (MICCAI2015)
发布时间
阅读量:
阅读量
U-Net(论文:U-Net: Convolutional Networks for Biomedical Image Segmentation,MICCAI 2015)是一种用于生物医学图像分割的卷积神经网络。论文 PDF 见 arXiv:1505.04597;更多细节请参阅作者的官方项目主页。下文给出一个 PyTorch 实现(GitHub 上也有相应的实现仓库)。

特点:
- 捕捉上下文的收缩路径(contracting path);
- 实现精确定位的对称扩展路径(symmetric expanding path):扩展路径的每一步先做 $2\times 2$ 的上卷积(up-convolution),其输出通道数为输入的一半;再与收缩路径中对应的(经过裁剪的)特征图拼接(concatenate),拼接后的通道数与上卷积之前相同。这种设计使模型更大,需要更多内存;
- 可以对非常少的图像端对端地进行训练;
- 适合超大图像分割,尤其适合医学图像分割。医学图像通常很大,分割时无法把整幅原图直接输入网络,也不宜把原图缩得太小,因此必须切成一个个小 patch。由于 U-Net 原论文使用不加 padding 的卷积,网络输出比输入小,切 patch 时适合采用有重叠(overlap)的切图方式,即论文中的 overlap-tile 策略(见 Fig.2);

PyTorch代码:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2020/7/8 13:51
# @Author : liumin
# @File : Unet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
def Conv3x3BNReLU(in_channels, out_channels, stride, groups=1):
    """Build a 3x3 convolution -> BatchNorm2d -> in-place ReLU stack.

    With padding 1, the spatial size is unchanged when ``stride`` is 1.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        stride: stride of the 3x3 convolution.
        groups: number of convolution groups (1 means an ordinary convolution).

    Returns:
        nn.Sequential whose children are, in order, Conv2d, BatchNorm2d, ReLU.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=groups,
    )
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)
def Conv1x1BNReLU(in_channels, out_channels):
    """Build a 1x1 convolution -> BatchNorm2d -> in-place ReLU stack.

    The 1x1 kernel with stride 1 only mixes channels; height and width are
    left unchanged.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.

    Returns:
        nn.Sequential whose children are, in order, Conv2d, BatchNorm2d, ReLU.
    """
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)]
    layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
def Conv1x1BN(in_channels, out_channels):
    """Build a 1x1 convolution followed by BatchNorm2d, with no activation.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.

    Returns:
        nn.Sequential whose children are, in order, Conv2d and BatchNorm2d.
    """
    projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
    return nn.Sequential(projection, nn.BatchNorm2d(out_channels))
class DoubleConv(nn.Module):
    """Two stacked stride-1 3x3 Conv-BN-ReLU units.

    The first unit maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Both units run at stride 1, so the spatial size is
    preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        first = Conv3x3BNReLU(in_channels, out_channels, stride=1)
        second = Conv3x3BNReLU(out_channels, out_channels, stride=1)
        self.double_conv = nn.Sequential(first, second)

    def forward(self, x):
        """Run ``x`` through both conv units and return the result."""
        return self.double_conv(x)
class DownConv(nn.Module):
    """Contracting-path step: 2x2 max pooling, then a DoubleConv.

    The U-Net paper downsamples with a 2x2 max pool and then applies the two
    3x3 convolutions at the reduced resolution, where the channel count
    changes. Pooling *first* gives the same output shape, parameters and
    state_dict keys as convolving first and pooling afterwards. Because the
    DoubleConv keeps the spatial size, both orders see the same pooled grid.
    With pooling first, however, the convolutions run on roughly
    1/stride**2 as many pixels. For the default stride of 2 that is about
    4x less conv compute and activation memory.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
        stride: stride of the 2x2 max-pool window (the default 2 halves H and W).
    """

    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=stride)
        self.double_conv = DoubleConv(in_channels, out_channels)

    def forward(self, x):
        """Downsample ``x`` with max pooling, then apply the DoubleConv."""
        return self.double_conv(self.pool(x))
class UpConv(nn.Module):
    """Expanding-path step: reduce, upsample, align, concatenate with skip, DoubleConv.

    Args:
        in_channels: channels of the deeper (lower-resolution) input ``x1``.
            A 1x1 Conv-BN-ReLU first halves this to ``in_channels // 2``.
        out_channels: channels produced by the final DoubleConv.
        bilinear: if True, upsample 2x with bilinear interpolation
            (``align_corners=True``); otherwise use a learned 2x2 transposed
            convolution with stride 2. In both cases the 1x1 ``reduce`` conv
            does the channel halving.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        reduced = in_channels // 2
        self.reduce = Conv1x1BNReLU(in_channels, reduced)
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(reduced, reduced, kernel_size=2, stride=2)
        # The concatenation [x2, x1] must total in_channels channels, so x2 is
        # expected to carry in_channels - in_channels // 2 channels.
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to the skip tensor ``x2``, fuse, and convolve.

        Args:
            x1: deeper 4-D (N, C, H, W) feature map with ``in_channels`` channels.
            x2: 4-D skip-connection feature map from the contracting path.

        Returns:
            Feature map with ``out_channels`` channels at ``x2``'s height/width.
        """
        x1 = self.up(self.reduce(x1))
        # Tensors are NCHW; the unpacking also enforces 4-D inputs.
        _, _, height1, width1 = x1.size()
        _, _, height2, width2 = x2.size()
        diff_y = height2 - height1
        diff_x = width2 - width1
        # F.pad takes (left, right, top, bottom) for the last two dims. Splitting
        # each difference handles odd sizes (e.g. 70 -> 71). A negative
        # difference crops instead of padding.
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class UNet(nn.Module):
    """U-Net for per-pixel classification (semantic segmentation).

    Encoder: a DoubleConv stem followed by four DownConv stages
    (64 -> 128 -> 256 -> 512 -> 1024 channels, each halving H and W).
    Decoder: four UpConv stages that upsample and fuse the encoder feature map
    of matching resolution (skip connection). A final 1x1 conv then produces
    ``num_classes`` score maps. UpConv pads or crops to the skip tensor's size,
    so the output has the same height and width as the input.

    Args:
        num_classes: number of output channels (one score map per class).
        in_channels: channels of the input image (default 3, e.g. RGB).
        bilinear: if True, UpConv upsamples with bilinear interpolation;
            otherwise it uses a learned 2x2 transposed convolution.
    """

    def __init__(self, num_classes, *, in_channels=3, bilinear=True):
        super().__init__()
        # Encoder (contracting path).
        self.conv = DoubleConv(in_channels, 64)
        self.down1 = DownConv(64, 128)
        self.down2 = DownConv(128, 256)
        self.down3 = DownConv(256, 512)
        self.down4 = DownConv(512, 1024)
        # Decoder (expanding path); up_k fuses the encoder output at its resolution.
        self.up1 = UpConv(1024, 512, bilinear)
        self.up2 = UpConv(512, 256, bilinear)
        self.up3 = UpConv(256, 128, bilinear)
        self.up4 = UpConv(128, 64, bilinear)
        # 1x1 conv mapping the 64 decoder channels to one score map per class.
        self.outconv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class scores of shape (N, num_classes, H, W).

        Args:
            x: input batch of shape (N, in_channels, H, W).
        """
        x1 = self.conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        xx = self.up1(x5, x4)
        xx = self.up2(xx, x3)
        xx = self.up3(xx, x2)
        xx = self.up4(xx, x1)
        return self.outconv(xx)
if __name__ == '__main__':
    # Smoke test: build a 19-class U-Net and push one random 3x572x572 image
    # through it, printing the architecture and the output shape.
    model = UNet(19)
    print(model)
    dummy_input = torch.randn(1, 3, 572, 572)
    # Only the output shape is checked, so skip building the autograd graph;
    # this keeps peak memory down for such a large input.
    with torch.no_grad():
        out = model(dummy_input)
    print(out.shape)  # expected: torch.Size([1, 19, 572, 572])
代码解读
全部评论 (0)
还没有任何评论哟~
