Advertisement

UNet图像分割

阅读量:
什么是 UNet?

UNet是一种用于图像分割任务的深度学习模型(CNN),其架构设计灵感来源于人体血液循环系统中的体中心导管系统。如该模型所示,在2015年首次提出。这种架构能够有效应对多种复杂度的图像分割问题。从基本概念出发,图像是指将一张或多张数字图像分解为若干有意义的部分。这一技术在医学影像诊断和自动驾驶系统中具有重要意义,并且通过结合多尺度特征提取与空间注意力机制的设计理念,在提升检测精度的同时也显著降低了计算开销。显式的多尺度特征融合方法不仅能够有效减少计算复杂度还能使后续分类或回归过程更加鲁棒。显式的多尺度特征融合方法不仅能够有效减少计算复杂度还能使后续分类或回归过程更加鲁棒

UNet 的架构:编码器与解码器

UNet的主要组成部分包括两个关键模块:编码模块(Contraction Path)解码模块(Expansion Path) ,它们通过跳跃连接进行通信。

编码器 :这部分类似于典型的卷积神经网络架构,在计算机视觉领域具有重要地位。它主要负责从输入图片中提取关键且有意义的信息,并通过多层处理逐渐优化这些表征。具体而言,在每经过一层处理都会依次执行以下步骤:首先进行二维卷积运算,在此基础之上应用ReLU激活函数以引入非线性特性,并结合Max Pooling操作以降低空间分辨率并增强特征表示能力。最终能够有效地提取出关键且有意义的图像特徵

解码器部分的功能是将编码器提取出的特征逐步恢复成与输入图片相同尺寸的分割结果。这一过程采用了上采样技术,并通过跳跃连接将编码器相应层级的特征整合到解码器中以保持更多的细节信息。

  1. 跳跃连接 :它通过跳跃连接,在解码器进行上采样时捕获编码器提取的关键特征,避免细节信息的丢失,并显著提升了分割的质量。

f3f3c137273502a48205f026b463ec22.png
UNet 的优势

UNet 之所以如此受欢迎,主要得益于以下几大优势:

  • 小数据集友好 :UNet 最早被设计用于医学图像分割,针对小样本数据集有很好的处理能力,这使得它在一些数据稀缺的场景中表现尤为突出。

清晰且精确的分割效果:由于UNet借助跳跃连接机制,在处理高分辨率图像时能够有效维持图像细节特征;其分割精度通常显著优于其他方法。

  • 适应性强 :UNet采用模块化设计,在保持简单性和高效性的同时具备较强的可扩展性特点。该算法可以根据不同类型的分割任务进行优化配置,在保证性能的前提下实现了对复杂场景的适应能力。这种方法不仅仅局限于医学领域,在计算机视觉领域的图像分割问题中也取得了良好的应用效果。
使用 PyTorch 实现 UNet

为了更直观地认识 UNet, 我们可以通过代码示例, 借助 PyTorch 开发了一个简化的 UNet 模型

复制代码
 import torch

    
 import torch.nn as nn
    
 import torch.nn.functional as F
    
  
    
 class UNet(nn.Module):
    
     def __init__(self, in_channels, out_channels):
    
         super(UNet, self).__init__()
    
         # 编码器部分(下采样路径)
    
         self.encoder1 = self.double_conv(in_channels, 64)
    
         self.encoder2 = self.double_conv(64, 128)
    
         self.encoder3 = self.double_conv(128, 256)
    
         self.encoder4 = self.double_conv(256, 512)
    
         
    
         # Bottleneck(网络最深处)
    
         self.bottleneck = self.double_conv(512, 1024)
    
         
    
         # 解码器部分(上采样路径)
    
         self.upconv4 = self.upconv(1024, 512)
    
         self.decoder4 = self.double_conv(1024, 512)  # 1024 是因为有跳跃连接
    
         self.upconv3 = self.upconv(512, 256)
    
         self.decoder3 = self.double_conv(512, 256)
    
         self.upconv2 = self.upconv(256, 128)
    
         self.decoder2 = self.double_conv(256, 128)
    
         self.upconv1 = self.upconv(128, 64)
    
         self.decoder1 = self.double_conv(128, 64)
    
         
    
         # 最终输出层
    
         self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
    
     
    
     def double_conv(self, in_channels, out_channels):
    
         """两次3x3卷积+批归一化+ReLU激活"""
    
         return nn.Sequential(
    
             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
    
             nn.BatchNorm2d(out_channels),
    
             nn.ReLU(inplace=True),
    
             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
    
             nn.BatchNorm2d(out_channels),
    
             nn.ReLU(inplace=True)
    
         )
    
     
    
     def upconv(self, in_channels, out_channels):
    
         """上采样:使用2x2的转置卷积"""
    
         return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    
  
    
     def forward(self, x):
    
         # 编码器路径
    
         enc1 = self.encoder1(x)
    
         enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2))
    
         enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2))
    
         enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2))
    
         
    
         # Bottleneck
    
         bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2))
    
         
    
         # 解码器路径
    
         dec4 = self.upconv4(bottleneck)
    
         dec4 = torch.cat((dec4, enc4), dim=1)  # 跳跃连接
    
         dec4 = self.decoder4(dec4)
    
         
    
         dec3 = self.upconv3(dec4)
    
         dec3 = torch.cat((dec3, enc3), dim=1)  # 跳跃连接
    
         dec3 = self.decoder3(dec3)
    
         
    
         dec2 = self.upconv2(dec3)
    
         dec2 = torch.cat((dec2, enc2), dim=1)  # 跳跃连接
    
         dec2 = self.decoder2(dec2)
    
         
    
         dec1 = self.upconv1(dec2)
    
         dec1 = torch.cat((dec1, enc1), dim=1)  # 跳跃连接
    
         dec1 = self.decoder1(dec1)
    
         
    
         return self.final_conv(dec1)
    
  
    
 # 使用示例
    
 model = UNet(in_channels=1, out_channels=1)
    
 input_image = torch.randn(1, 1, 512, 512)  # 批大小为1,单通道,512x512图像
    
 output = model(input_image)
    
 print(output.shape)  # 应该输出 torch.Size([1, 1, 512, 512])
    
    
    
    
    go
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/4Vv6gedAZkrz9Uu8PC3EyqYIt1Sn.png)

通过上述代码,便可轻松搭建一个简化的 UNet 模型,助力图像分割任务。

d643e321575451854ad395864364b474.gif

全部评论 (0)

还没有任何评论哟~