
Medical Image Segmentation: U-Net Variants


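Surveys of medical image segmentation give a central place to the U-Net family and its applications. This post walks through the main U-Net variants together with their PyTorch implementations.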

2D U-Net

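  • Contracting path: each stage applies two consecutive 3 × 3 convolutions, each followed by a ReLU. A 2 × 2 max-pooling layer (stride 2) then downsamples the result.

  • Expanding path: each stage starts with upsampling (a 2 × 2 transposed convolution). The result is concatenated with the corresponding encoder feature map and passed through two consecutive 3 × 3 convolutions, each followed by a ReLU.

  • nn.BatchNorm2d(out_channels): batch normalization standardizes each channel using statistics from the current batch. The layer is optional. It usually helps mitigate vanishing and exploding gradients and speeds up training, and it can also improve generalization.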
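Both paths size their feature maps with the usual output-size rule for convolution and pooling: floor((n + 2p − k) / s) + 1, for input size n, kernel k, stride s and padding p. The plain-Python sketch below applies this rule to one contracting stage. The helper names `out_size` and `contracting_stage` are our own. The stage is computed twice: once with the unpadded convolutions of the original paper, and once with padding=1 as in the code that follows.

```python
def out_size(n, k, s=1, p=0):
    """Output length along one spatial axis of a conv or pooling layer
    with kernel k, stride s and padding p (no dilation)."""
    return (n + 2 * p - k) // s + 1


def contracting_stage(n, padding):
    """One contracting stage: two 3x3 convs, then 2x2 max pooling with stride 2.
    Returns (size of the map sent over the skip connection, size after pooling)."""
    n = out_size(n, k=3, p=padding)
    n = out_size(n, k=3, p=padding)
    skip = n
    n = out_size(n, k=2, s=2)  # max pooling halves the size
    return skip, n


# Unpadded convs (original paper): 572 -> 570 -> 568, then pooling -> 284
print(contracting_stage(572, padding=0))
# padding=1 (as in DoubleConv): the convs keep the size, only pooling halves it
print(contracting_stage(256, padding=1))
```

The unpadded version shrinks the map at every convolution, which is why the original paper crops encoder maps before concatenating them. With padding=1, the skip map and the upsampled decoder map already have the same size.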

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    # Two stages of 3x3 conv -> (optional BatchNorm) -> ReLU.
    # padding=1 keeps the spatial size unchanged.
    def __init__(self, in_channels, out_channels, with_bn=False):
        super().__init__()
        if with_bn:
            self.step = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        else:
            self.step = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
            )

    def forward(self, x):
        return self.step(x)
```

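The following class defines the complete UNet, including the encoder (downsampling) and decoder (upsampling) paths. Note that this implementation upsamples with bilinear interpolation rather than a transposed convolution.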

```python
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels, with_bn=False):
        super().__init__()
        init_channels = 32
        self.out_channels = out_channels

        # Encoder: 32 -> 64 -> 128 -> 256 channels
        self.en_1 = DoubleConv(in_channels, init_channels, with_bn)
        self.en_2 = DoubleConv(1*init_channels, 2*init_channels, with_bn)
        self.en_3 = DoubleConv(2*init_channels, 4*init_channels, with_bn)
        self.en_4 = DoubleConv(4*init_channels, 8*init_channels, with_bn)

        # Decoder: input channels = skip channels + upsampled deeper channels
        self.de_1 = DoubleConv((4 + 8)*init_channels, 4*init_channels, with_bn)
        self.de_2 = DoubleConv((2 + 4)*init_channels, 2*init_channels, with_bn)
        self.de_3 = DoubleConv((1 + 2)*init_channels, 1*init_channels, with_bn)
        self.de_4 = nn.Conv2d(init_channels, out_channels, 1)  # 1x1 conv -> per-pixel logits

        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x):
        e1 = self.en_1(x)
        e2 = self.en_2(self.maxpool(e1))
        e3 = self.en_3(self.maxpool(e2))
        e4 = self.en_4(self.maxpool(e3))

        # Upsample, concatenate with the skip feature map along channels, then convolve
        d1 = self.de_1(torch.cat([self.upsample(e4), e3], dim=1))
        d2 = self.de_2(torch.cat([self.upsample(d1), e2], dim=1))
        d3 = self.de_3(torch.cat([self.upsample(d2), e1], dim=1))
        d4 = self.de_4(d3)

        return d4
```

The following example uses a 3 × 256 × 256 input image; the tensor shapes were worked out by hand. If anything is unclear, verify it against the code yourself.
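To cross-check those hand calculations, the plain-Python sketch below recomputes every (C, H, W) shape of the UNet above for a 3 × 256 × 256 input. It only does shape bookkeeping and creates no tensors. `trace_unet` is a hypothetical helper. `out_ch=1` (a single foreground class) is an assumption, since the post leaves out_channels open.

```python
def trace_unet(in_ch=3, size=256, out_ch=1, init_channels=32, depth=4):
    """Return [(name, (C, H, W)), ...] for the 4-level UNet above."""
    trace = [("input", (in_ch, size, size))]
    enc = []
    s = size
    for level in range(depth):
        if level > 0:
            s //= 2                          # MaxPool2d(kernel_size=2) halves H and W
        ch = init_channels * 2 ** level      # 32, 64, 128, 256
        enc.append((ch, s))
        trace.append((f"e{level + 1}", (ch, s, s)))

    ch, s = enc[-1]
    for i, (skip_ch, skip_s) in enumerate(reversed(enc[:-1]), start=1):
        s *= 2                               # Upsample(scale_factor=2)
        assert s == skip_s                   # sizes must match for torch.cat
        trace.append((f"cat{i}", (ch + skip_ch, s, s)))  # concat along channels
        ch = skip_ch                         # DoubleConv maps back to the skip width
        trace.append((f"d{i}", (ch, s, s)))
    trace.append((f"d{depth}", (out_ch, s, s)))  # final 1x1 conv
    return trace


for name, shape in trace_unet():
    print(name, shape)
```

The concatenated widths come out as 384, 192 and 96, matching (4 + 8) × 32, (2 + 4) × 32 and (1 + 2) × 32 in the constructor.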

Enhancing skip connections

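  • Skip connections combine deep, low-resolution semantic information with shallow, high-resolution local information, which lets the network analyze fine details.
  • In U-Net, high-resolution local detail lost during the contracting (downsampling) stage is hard to recover fully during the later upsampling stage.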
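The toy example below shows why detail lost in the contracting path cannot be recovered by upsampling alone. It uses plain Python and made-up 4 × 4 inputs. Max pooling maps different inputs to the same output, so nothing computed from the pooled map alone can tell them apart. The skip connection restores the decoder's access to the original high-resolution map.

```python
def max_pool_2x2(img):
    """2x2 max pooling with stride 2 on a list-of-lists image (even H and W)."""
    return [
        [max(img[r][c], img[r][c + 1], img[r + 1][c], img[r + 1][c + 1])
         for c in range(0, len(img[0]), 2)]
        for r in range(0, len(img), 2)
    ]


# Two different 4x4 inputs: single bright pixels at different positions
a = [[9, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 9],
     [0, 0, 0, 0]]
b = [[0, 0, 0, 0],
     [0, 9, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 9, 0]]

print(max_pool_2x2(a))  # [[9, 0], [0, 9]]
print(max_pool_2x2(b))  # [[9, 0], [0, 9]]: the exact positions are gone
```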

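Increasing the number of skip connections

Improvement 1: a U-Net design with bidirectional skip connections

  • Improvement 2 (U-Net++)
  • Reference: [Network models (U-net, U-net++, U-net+++), blog post]( "Network models (U-net, U-net++, U-net+++), blog post")
  • Flowchart:
  • Overall process: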

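VGGBlock is the standard double-convolution module: two 3 × 3 convolutions, each followed by batch normalization and ReLU. It is essentially the DoubleConv above with BN enabled, except that the first convolution can have its own width (middle_channels).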

```python
class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out
```

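This U-Net is built from VGGBlocks. It follows the same design as the one above but has five levels (32 to 512 channels) instead of four. Work through its structure and implementation details yourself; the variants in the following sections will then be much easier to follow.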

```python
class UNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))

        output = self.final(x0_4)
        return output
```

NestedUNet is the core of U-Net++ and replaces UNet; the overall flow is shown in the two figures above.
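All channel counts in the NestedUNet constructor follow one rule, which generalizes the conv0_3 comment in the code. Each decoder node conv{i}_{j} (j ≥ 1) concatenates the j earlier nodes on its own row, each with nb_filter[i] channels. It also takes the upsampled output of node (i+1, j−1), which has nb_filter[i+1] channels. The plain-Python sketch below reproduces all ten values. `nested_in_channels` is our own helper, not part of any library.

```python
NB_FILTER = [32, 64, 128, 256, 512]  # same widths as in the NestedUNet code


def nested_in_channels(nb_filter=NB_FILTER):
    """Input channels of every decoder node conv{i}_{j} in U-Net++.

    Node (i, j) concatenates x{i}_0 ... x{i}_{j-1} (j maps of nb_filter[i]
    channels) with the upsampled x{i+1}_{j-1} (nb_filter[i+1] channels).
    """
    depth = len(nb_filter) - 1            # 4 downsampling levels
    channels = {}
    for j in range(1, depth + 1):
        for i in range(depth - j + 1):    # nodes exist only where i + j <= depth
            channels[f"conv{i}_{j}"] = nb_filter[i] * j + nb_filter[i + 1]
    return channels


for name, c in sorted(nested_in_channels().items()):
    print(name, c)
```

For example, conv0_3 gets 32 × 3 + 64 = 160 channels, which matches `nb_filter[0]*3+nb_filter[1]` in the constructor.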

```python
class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Encoder
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
        # Decoder (nested skip pathways)
        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
        # Example, conv0_3: *3 for the skip inputs from nodes (0,0), (0,1), (0,2);
        # +nb_filter[1] for the upsampled output of node (1,2)
        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        if self.deep_supervision:
            # One 1x1 output head per top-row node x0_1 ... x0_4
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]

        else:
            output = self.final(x0_4)
            return output
```
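With deep_supervision=True, NestedUNet returns four full-resolution predictions, one from each of x0_1 to x0_4. One common way to train it is to compute the loss for every head and average the losses with equal weights. This scheme is an assumption here; the post does not specify the training loss. The plain-Python sketch below uses mean squared error as a stand-in for a real segmentation loss (for example BCE + Dice), and the predictions are made up.

```python
def mse(pred, target):
    """Mean squared error between equal-length lists (stand-in for a real loss)."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def deep_supervision_loss(outputs, target, loss_fn=mse):
    """Average the loss over all supervised heads (equal weights assumed)."""
    return sum(loss_fn(out, target) for out in outputs) / len(outputs)


def model_loss(outputs, target, deep_supervision, loss_fn=mse):
    # Mirrors NestedUNet.forward: a list of four outputs with deep
    # supervision, a single output otherwise
    if deep_supervision:
        return deep_supervision_loss(outputs, target, loss_fn)
    return loss_fn(outputs, target)


target = [0.0, 1.0, 1.0, 0.0]
heads = [  # made-up predictions from output1 ... output4 (shallow to deep)
    [0.4, 0.6, 0.5, 0.3],
    [0.3, 0.7, 0.6, 0.2],
    [0.2, 0.8, 0.8, 0.1],
    [0.1, 0.9, 0.9, 0.1],
]
print(model_loss(heads, target, deep_supervision=True))
```

Because every head is trained to produce a usable mask, a shallower head such as output2 can also be used on its own at inference time when speed matters.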

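Processing feature maps inside the skip connections

  • Improvement 1
    • This work proposes a new method for ultrasound image analysis that targets the challenging task of segmenting the ovary and follicles. It reports a clear improvement in overall performance.
    • The spatial relationships between neighboring filters are important for model performance; spatial recurrent neural networks (RNNs) are used to capture these dependencies.
  • Max pooling in the original U-Net architecture can discard important edge information, which may bias the model when it captures features of small regions.
  • Improvement 2 (AttUNet)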

Weight initialization: init_weights

```python
from torch.nn import init


def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        # Conv and Linear layers: weights according to init_type, biases set to 0
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        # BatchNorm2d: scale drawn from N(1, gain), shift set to 0
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # calls init_func on every submodule recursively
```
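The three Gaussian options in init_weights differ only in the standard deviation of the distribution they sample from. The plain-Python sketch below computes that standard deviation for a 3 × 3 Conv2d with 64 input and 128 output channels, using the formulas documented for torch.nn.init. `init_std` and `conv_fans` are our own helper names. Orthogonal initialization is left out because a single standard deviation does not describe it.

```python
import math


def conv_fans(c_in, c_out, k):
    """fan_in and fan_out of a Conv2d weight of shape (c_out, c_in, k, k)."""
    receptive = k * k
    return c_in * receptive, c_out * receptive


def init_std(init_type, c_in, c_out, k, gain=0.02):
    """Std of the normal distribution each init_type in init_weights samples from."""
    fan_in, fan_out = conv_fans(c_in, c_out, k)
    if init_type == "normal":
        return gain  # init.normal_(w, 0.0, gain)
    if init_type == "xavier":
        # init.xavier_normal_(w, gain=gain): gain * sqrt(2 / (fan_in + fan_out))
        return gain * math.sqrt(2.0 / (fan_in + fan_out))
    if init_type == "kaiming":
        # init.kaiming_normal_(w, a=0, mode='fan_in'): the default nonlinearity
        # 'leaky_relu' with a=0 has gain sqrt(2), so std = sqrt(2) / sqrt(fan_in)
        return math.sqrt(2.0) / math.sqrt(fan_in)
    raise ValueError("unsupported init_type: %s" % init_type)


for t in ("normal", "xavier", "kaiming"):
    print(t, init_std(t, 64, 128, 3))
```

Note that init_weights passes the same gain=0.02 to xavier_normal_ and orthogonal_. For those two schemes, this makes the weights 50 times smaller than they would be with gain=1.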

The conv_block class: two consecutive stages of 3 × 3 convolution, batch normalization, and ReLU.

```python
class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x
```

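The up_conv class: an upsampling convolution block (2× upsampling, then a 3 × 3 convolution, batch normalization, and ReLU).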

```python
class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x
```
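nn.Upsample(scale_factor=2) defaults to mode='nearest', so the first step of up_conv copies each pixel into a 2 × 2 block. The following 3 × 3 convolution then smooths the blocky result and changes the channel count from ch_in to ch_out. The sketch below shows the upsampling step in plain Python on a single channel; the helper name is our own.

```python
def upsample_nearest_2x(img):
    """2x nearest-neighbour upsampling of a list-of-lists image:
    every pixel is copied into a 2x2 block."""
    out = []
    for row in img:
        wide = [v for v in row for _ in range(2)]  # repeat each value along W
        out.append(wide)
        out.append(list(wide))                     # repeat the row along H
    return out


for row in upsample_nearest_2x([[1, 2], [3, 4]]):
    print(row)
# [1, 1, 2, 2]
# [1, 1, 2, 2]
# [3, 3, 4, 4]
# [3, 3, 4, 4]
```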

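Recurrent convolution block (Recurrent_block)

Applies the same weight-shared convolution several times, feeding the block's input back in on each pass, to enrich the extracted features.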

```python
class Recurrent_block(nn.Module):
    def __init__(self, ch_out, t=2):
        super(Recurrent_block, self).__init__()
        self.t = t
        self.ch_out = ch_out
        # One conv -> BN -> ReLU stack, reused on every recurrent step
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # x1 = conv(x) once, then t refinement steps x1 = conv(x + x1)
        for i in range(self.t):
            if i == 0:
                x1 = self.conv(x)

            x1 = self.conv(x + x1)
        return x1
```
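To see what t controls, the sketch below replays the forward loop in plain Python. A scalar stands in for the feature map, and a toy function stands in for the shared conv → BN → ReLU stack; both are illustrative stand-ins, not the real layers. With t = 2, the same weights are applied t + 1 = 3 times.

```python
def recurrent_block(x, conv, t=2):
    """Same control flow as Recurrent_block.forward, with `conv` as a plain function."""
    for i in range(t):
        if i == 0:
            x1 = conv(x)
        x1 = conv(x + x1)
    return x1


calls = []


def toy_conv(v):
    """Hypothetical stand-in for conv -> BN -> ReLU: records the call, returns ReLU(0.5 * v)."""
    calls.append(v)
    return max(0.0, 0.5 * v)


y = recurrent_block(1.0, toy_conv, t=2)
print(len(calls), y)  # 3 0.875 (inputs seen by conv: 1.0, 1.5, 1.75)
```

An RRCNN_block stacks two of these blocks after a 1 × 1 convolution. With t = 2, each RRCNN_block therefore runs 1 + 2 × 3 = 7 convolutions per forward pass.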

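RRCNN_block: further strengthens feature extraction. A 1 × 1 convolution first maps ch_in to ch_out. Two stacked Recurrent_blocks then process the result, and a residual connection adds the 1 × 1 output back.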

```python
class RRCNN_block(nn.Module):
    def __init__(self, ch_in, ch_out, t=2):
        super(RRCNN_block, self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out, t=t),
            Recurrent_block(ch_out, t=t)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.Conv_1x1(x)
        x1 = self.RCNN(x)
        return x + x1
```

Implementation: replace the convolution blocks of U-Net with RRCNN_block to obtain R2U_Net:

```python
class R2U_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=1, t=2):
        super(R2U_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Upsample = nn.Upsample(scale_factor=2)

        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.RRCNN1(x)

        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_RRCNN5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_RRCNN2(d2)

        d1 = self.Conv_1x1(d2)

        return d1
```
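R2U_Net keeps the channel layout of the classic U-Net (64 to 1024 channels) and only swaps the blocks. The plain-Python sketch below tracks the (C, H, W) shapes. It assumes a 3 × 256 × 256 input and output_ch=1; `trace_r2unet` is a hypothetical helper. The trace shows why each Up_RRCNN block takes twice the channels it outputs. The upsampled decoder features and the skip feature map have equal widths and are concatenated. The same layout holds for AttU_Net and R2AttU_Net, because an attention gate returns x * psi and therefore keeps the channel count of x.

```python
def trace_r2unet(img_ch=3, size=256, output_ch=1, widths=(64, 128, 256, 512, 1024)):
    """[(name, (C, H, W)), ...] for every stage of R2U_Net; assumes size divisible by 16."""
    sizes = [size // 2 ** i for i in range(len(widths))]  # each Maxpool halves H and W
    trace = [("x", (img_ch, size, size))]
    for i, (c, s) in enumerate(zip(widths, sizes), start=1):
        trace.append((f"x{i}", (c, s, s)))                  # RRCNN1 ... RRCNN5
    for i in range(len(widths) - 1, 0, -1):                 # decoder levels 5 ... 2
        c, s = widths[i - 1], sizes[i - 1]                  # width and size of skip x{i}
        trace.append((f"Up{i + 1}", (c, s, s)))             # up_conv: halves C, doubles H and W
        trace.append((f"cat{i + 1}", (2 * c, s, s)))        # torch.cat((x{i}, d), dim=1)
        trace.append((f"d{i + 1}", (c, s, s)))              # Up_RRCNN{i+1}: 2c -> c
    trace.append(("d1", (output_ch, size, size)))           # final Conv_1x1
    return trace


for name, shape in trace_r2unet():
    print(name, shape)
```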

Attention_block (arguably the centerpiece of this section)

```python
class Attention_block(nn.Module):
    # F_g: number of channels of the decoder feature map (gating signal g)
    # F_l: number of channels of the encoder feature map (skip connection x)
    # F_int: number of channels of the intermediate feature map
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        # 1x1 conv + BN on the decoder features, mapping them to F_int channels
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # 1x1 conv + BN on the encoder features, mapping them to F_int channels
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # Compress the F_int channels to one channel with 1x1 conv, BN and Sigmoid,
        # producing per-pixel attention weights in (0, 1)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)

        # psi has shape (N, 1, H, W) and broadcasts over the channels of x
        return x * psi
```
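Every layer in Attention_block is a 1 × 1 convolution, so the gate works independently at each pixel. It projects the decoder feature g and the encoder feature x to F_int channels, adds the two projections, and applies ReLU. It then squeezes the result to one channel and applies a sigmoid to get a weight α in (0, 1) that rescales x. The plain-Python sketch below follows these steps for a single pixel. It uses toy weights, and BatchNorm and the conv biases are left out for brevity. The example shows the same skip feature being kept or suppressed depending on the gating signal.

```python
import math


def matvec(W, v):
    """Matrix-vector product: what a 1x1 convolution computes at each pixel."""
    return [sum(w * a for w, a in zip(row, v)) for row in W]


def attention_gate_pixel(g, x, W_g, W_x, w_psi, b_psi=0.0):
    """Additive attention gate at one pixel (BatchNorm and conv biases omitted).

    g: decoder (gating) features, length F_g; x: encoder (skip) features, length F_l.
    Returns (alpha, gated x) with alpha in (0, 1).
    """
    g1 = matvec(W_g, g)                                # W_g: F_int x F_g
    x1 = matvec(W_x, x)                                # W_x: F_int x F_l
    q = [max(0.0, a + b) for a, b in zip(g1, x1)]      # ReLU(g1 + x1)
    s = sum(w * v for w, v in zip(w_psi, q)) + b_psi   # 1x1 conv down to one channel
    alpha = 1.0 / (1.0 + math.exp(-s))                 # Sigmoid
    return alpha, [alpha * v for v in x]               # x * psi


I2 = [[1.0, 0.0], [0.0, 1.0]]   # toy weights with F_g = F_l = F_int = 2
skip = [1.0, 1.0]               # identical encoder features at both pixels
for g in ([1.0, 1.0], [-1.0, -1.0]):
    alpha, gated = attention_gate_pixel(g, skip, I2, I2, w_psi=[1.0, 1.0], b_psi=-2.0)
    print(g, round(alpha, 3), [round(v, 3) for v in gated])
# [1.0, 1.0] 0.881 [0.881, 0.881]
# [-1.0, -1.0] 0.119 [0.119, 0.119]
```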

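Implementation (AttU_Net): an Attention_block gates each skip connection before it is concatenated with the upsampled decoder features.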

```python
class AttU_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=1):
        super(AttU_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        return d1
```

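This is my own hand-drawn diagram of the derivation. If it is not fully clear, first review the Q/K/V components of attention mechanisms, then come back to the detailed derivation.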


Finally, combining RRCNN_block and Attention_block gives the R2AttU_Net model:

```python
class R2AttU_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=1, t=2):
        super(R2AttU_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Upsample = nn.Upsample(scale_factor=2)

        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.RRCNN1(x)

        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_RRCNN5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_RRCNN2(d2)

        d1 = self.Conv_1x1(d2)

        return d1
```

ResUNet

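Compared with a plain serial (stacked) architecture, the residual unit adds a skip connection that links the input directly to the output. This compensates for feature information lost in the convolutions.
The idea differs noticeably from the skip connections in U-Net.
ResNet merges the two paths with a simple element-wise Add,
whereas U-Net concatenates them along the channel axis (Concatenate), which is more expensive.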
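The difference in channel bookkeeping is easy to see on toy data. In the plain-Python sketch below, a feature map is a list of channels, and each channel is a flattened 2 × 2 map; the values are made up. Addition (ResNet) requires matching shapes and keeps the channel count. Concatenation (U-Net) keeps both sources side by side and doubles the number of channels the next convolution must accept, which is why the decoder blocks above take twice the channels they output.

```python
def add_merge(a, b):
    """ResNet-style merge: element-wise sum; channel count unchanged (shapes must match)."""
    if len(a) != len(b):
        raise ValueError("Add needs the same number of channels")
    return [[x + y for x, y in zip(ca, cb)] for ca, cb in zip(a, b)]


def concat_merge(a, b):
    """U-Net-style merge: stack along the channel axis, like torch.cat(..., dim=1)."""
    return a + b


skip = [[1, 2, 3, 4], [5, 6, 7, 8]]       # 2 channels
deep = [[10, 10, 10, 10], [0, 0, 0, 0]]   # 2 channels

print(len(add_merge(skip, deep)), add_merge(skip, deep))
# 2 [[11, 12, 13, 14], [5, 6, 7, 8]]
print(len(concat_merge(skip, deep)))
# 4
```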
