医学图像分割--U-net变种
本文参考医学图像分割领域的研究综述,整理U-Net系列网络及其变种在医学图像分割中的应用,并给出对应的PyTorch实现。
2D Unet
-
收缩分支(编码器):每个模块包含两个连续的 $3 \times 3$ 卷积,每个卷积后接一个ReLU激活函数,随后接一个 $2 \times 2$ 最大池化层进行下采样。
-
扩展分支(解码器):每个模块先用一个 $2 \times 2$ 转置卷积进行上采样,与编码器对应层的特征图拼接(跳跃连接)后,再接两个连续的 $3 \times 3$ 卷积,每个卷积后接一个ReLU激活函数。注意:下文代码中的上采样使用的是双线性插值(nn.Upsample),而不是转置卷积。
-
nn.BatchNorm2d(out_channels):批归一化层,对每个通道的特征做标准化。该层不是必需的(下文代码通过with_bn参数控制是否使用);它有助于缓解梯度消失与梯度爆炸问题,加快训练收敛,并可能提升模型的泛化性能。
class DoubleConv(nn.Module):
    """Two 3x3 convolutions (padding=1), each followed by ReLU.

    With ``with_bn=True`` a ``BatchNorm2d`` is inserted after each
    convolution. The spatial size is preserved; only the channel count
    changes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        with_bn: Whether to add batch normalization after each convolution.
    """

    def __init__(self, in_channels, out_channels, with_bn=False):
        super().__init__()
        layers = []
        # First conv maps in_channels -> out_channels, second keeps out_channels.
        for conv_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1))
            if with_bn:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
        # Same layer order as two hand-written nn.Sequential variants, so the
        # state_dict keys (step.0, step.1, ...) are unchanged.
        self.step = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv(+BN)+ReLU stages to ``x``."""
        return self.step(x)
>
>
>
>
> python
>
>
>
> 
定义了整个UNet网络结构,包括编码(下采样)和解码(上采样)部分。
class UNet(nn.Module):
    """Small 2D U-Net with three down-/up-sampling levels.

    Encoder: four ``DoubleConv`` stages with 32, 64, 128 and 256 channels,
    separated by 2x2 max pooling. Decoder: bilinear upsampling of the deeper
    map, concatenation with the encoder map of the same level (skip
    connection), then a ``DoubleConv``. A final 1x1 convolution maps to
    ``out_channels``; no activation is applied to the output.

    H and W need not be multiples of 8 (they must be at least 8): each
    upsampled map is resized to exactly match its skip connection.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of output channels (e.g. number of classes).
        with_bn: Passed to every ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels, with_bn=False):
        super().__init__()
        init_channels = 32
        self.out_channels = out_channels

        # Encoder: channels double at each level.
        self.en_1 = DoubleConv(in_channels, init_channels, with_bn)
        self.en_2 = DoubleConv(1 * init_channels, 2 * init_channels, with_bn)
        self.en_3 = DoubleConv(2 * init_channels, 4 * init_channels, with_bn)
        self.en_4 = DoubleConv(4 * init_channels, 8 * init_channels, with_bn)

        # Decoder: input = skip map concatenated with the upsampled deeper map.
        self.de_1 = DoubleConv((4 + 8) * init_channels, 4 * init_channels, with_bn)
        self.de_2 = DoubleConv((2 + 4) * init_channels, 2 * init_channels, with_bn)
        self.de_3 = DoubleConv((1 + 2) * init_channels, 1 * init_channels, with_bn)
        self.de_4 = nn.Conv2d(init_channels, out_channels, 1)

        self.maxpool = nn.MaxPool2d(kernel_size=2)
        # align_corners=False is PyTorch's default for bilinear mode; spelled
        # out so that _up_like reproduces it exactly.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def _up_like(self, x, ref):
        """Upsample ``x`` (same mode as ``self.upsample``) to the spatial size of ``ref``.

        When ``ref`` is exactly twice as large as ``x`` this equals
        ``self.upsample(x)``; for odd sizes (where max pooling floors the
        size) it still yields a map that can be concatenated with ``ref``.
        """
        return nn.functional.interpolate(
            x,
            size=ref.shape[-2:],
            mode=self.upsample.mode,
            align_corners=self.upsample.align_corners,
        )

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, out_channels, H, W)."""
        e1 = self.en_1(x)
        e2 = self.en_2(self.maxpool(e1))
        e3 = self.en_3(self.maxpool(e2))
        e4 = self.en_4(self.maxpool(e3))

        d1 = self.de_1(torch.cat([self._up_like(e4, e3), e3], dim=1))
        d2 = self.de_2(torch.cat([self._up_like(d1, e2), e2], dim=1))
        d3 = self.de_3(torch.cat([self._up_like(d2, e1), e1], dim=1))
        d4 = self.de_4(d3)

        return d4
>
>
>
>
> python
>
>
>
> 
下面以 $3 \times 256 \times 256$ 的输入图片为例,手动推导并验证了各层特征图的尺寸。如有疑问,建议读者对照代码自行推导验证。

跳过连接增强
- 跳跃连接将深层的低分辨率语义信息与浅层的高分辨率局部信息融合,从而更好地刻画细节特征。
- 在U-Net中,局部高分辨率细节在编码器的下采样(压缩)阶段会丢失,仅靠解码器的上采样难以完全恢复,因此需要跳跃连接把这些细节直接传递给解码器。
增加跳过连接的数量
改进一:采用双向跳跃连接的U-Net架构

- 改进二(U-net++)
- 参考:博客《网络模型(U-net,U-net++, U-net+++)》(原文链接缺失)
- 流程图:

- 整体过程为:

VGGBlock是常规的双卷积模块:两个 $3 \times 3$ 卷积,每个卷积后接批归一化和ReLU。其结构与上文的DoubleConv(with_bn=True)基本相同,区别在于中间通道数(middle_channels)可以与输出通道数不同。
class VGGBlock(nn.Module):
    """Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN -> ReLU, with padding=1.

    The spatial size is preserved. The first convolution outputs
    ``middle_channels`` and the second outputs ``out_channels``.

    Args:
        in_channels: Number of channels of the input feature map.
        middle_channels: Number of channels after the first convolution.
        out_channels: Number of channels after the second convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # One ReLU module is shared by both stages (it has no parameters).
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply the two conv+BN+ReLU stages to ``x``."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out
>
>
>
>
> python
>
>
>
> 
下面先用VGGBlock搭建一个普通的U-Net(编码器结构与VGG类似)。建议读者自己推导一遍各层的通道数与特征图尺寸,这样学习后面的U-Net++会更加顺畅。
class UNet(nn.Module):
    """Plain U-Net with four down-/up-sampling levels built from ``VGGBlock``.

    Nodes follow the U-Net++ naming ``conv{depth}_{column}``. The encoder is
    ``conv0_0`` ... ``conv4_0`` (channels ``nb_filter``). Each decoder node
    (``conv3_1``, ``conv2_2``, ``conv1_3``, ``conv0_4``) receives a single
    skip connection concatenated with the upsampled deeper map. A final 1x1
    convolution maps to ``num_classes`` channels; no activation is applied.

    H and W need not be multiples of 16 (they must be at least 16): each
    upsampled map is resized to exactly match its skip connection.

    Args:
        num_classes: Number of output channels.
        input_channels: Number of channels of the input image.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder (backbone column).
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Decoder: input = skip map + upsampled deeper map (concatenated).
        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def _up_like(self, x, ref):
        """Upsample ``x`` (same mode as ``self.up``) to the spatial size of ``ref``.

        When ``ref`` is exactly twice as large as ``x`` this equals
        ``self.up(x)``; for odd sizes it still yields a map that can be
        concatenated with ``ref``.
        """
        return nn.functional.interpolate(
            x,
            size=ref.shape[-2:],
            mode=self.up.mode,
            align_corners=self.up.align_corners,
        )

    def forward(self, input):
        """Map ``input`` (N, input_channels, H, W) to (N, num_classes, H, W)."""
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        x3_1 = self.conv3_1(torch.cat([x3_0, self._up_like(x4_0, x3_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self._up_like(x3_1, x2_0)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self._up_like(x2_2, x1_0)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self._up_like(x1_3, x0_0)], 1))

        output = self.final(x0_4)
        return output
>
>
>
>
> python
>
>
>
> 
U-Net++的核心是NestedUNet,用来替代上面的UNet,整体流程见上文的流程图。
class NestedUNet(nn.Module):
    """U-Net++ (nested U-Net) built from ``VGGBlock``.

    Node ``conv{i}_{j}`` sits at depth ``i`` (channels ``nb_filter[i]``) and
    column ``j``. It receives all earlier nodes of the same depth
    (``x{i}_0`` ... ``x{i}_{j-1}``, dense skip connections) concatenated with
    the upsampled node ``x{i+1}_{j-1}``.

    Without deep supervision, returns ``final(x0_4)``. With deep
    supervision, returns a list of four maps, computed from ``x0_1``,
    ``x0_2``, ``x0_3`` and ``x0_4`` by separate 1x1 convolutions. No
    activation is applied to any output.

    H and W need not be multiples of 16 (they must be at least 16): each
    upsampled map is resized to exactly match the maps it is concatenated
    with.

    Args:
        num_classes: Number of output channels of each head.
        input_channels: Number of channels of the input image.
        deep_supervision: Whether to output predictions from all four
            top-row decoder nodes.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder (backbone column j = 0).
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Nested decoder nodes.
        self.conv0_1 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0] * 2 + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1] * 2 + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2] * 2 + nb_filter[3], nb_filter[2], nb_filter[2])

        # E.g. conv0_3: "* 3" for the skip inputs x0_0, x0_1, x0_2, and
        # "+ nb_filter[1]" for the upsampled x1_2.
        self.conv0_3 = VGGBlock(nb_filter[0] * 3 + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1] * 3 + nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0] * 4 + nb_filter[1], nb_filter[0], nb_filter[0])

        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def _up_like(self, x, ref):
        """Upsample ``x`` (same mode as ``self.up``) to the spatial size of ``ref``.

        When ``ref`` is exactly twice as large as ``x`` this equals
        ``self.up(x)``; for odd sizes it still yields a map that can be
        concatenated with ``ref``.
        """
        return nn.functional.interpolate(
            x,
            size=ref.shape[-2:],
            mode=self.up.mode,
            align_corners=self.up.align_corners,
        )

    def forward(self, input):
        """Run U-Net++ on ``input`` (N, input_channels, H, W).

        Returns:
            A (N, num_classes, H, W) tensor, or a list of four such tensors
            when ``deep_supervision`` is enabled.
        """
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self._up_like(x1_0, x0_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self._up_like(x2_0, x1_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self._up_like(x1_1, x0_0)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self._up_like(x3_0, x2_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self._up_like(x2_1, x1_0)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self._up_like(x1_2, x0_0)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self._up_like(x4_0, x3_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self._up_like(x3_1, x2_0)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self._up_like(x2_2, x1_0)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self._up_like(x1_3, x0_0)], 1))

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]

        output = self.final(x0_4)
        return output
>
>
>
>
> python
>
>
>
> 
在跳过连接中处理特征映射
- 改进一
- 本研究提出了一种新型算法,在超声医学图像分析领域取得了显著成果;该算法特别关注并解决卵巢及卵泡分割这一极具挑战性的任务,并显著提升了模型的整体性能。
- 相邻滤波的空间相互关系对模型性能至关重要;采用的空间循环神经网络(RNNs)能够有效捕捉这种相互关联性。
- 原有U-Net架构中的最大池化操作可能导致重要的边缘信息丢失,进而使模型难以准确分割小区域目标。
- 改进二(Attention U-Net,代码中为AttU_Net)


初始化层init_weights
def init_weights(net, init_type='normal', gain=0.02):
    """Initialize, in place, every Conv*/Linear and BatchNorm2d layer of ``net``.

    Conv*/Linear weights (matched by class name) are drawn according to
    ``init_type``, and their biases, if any, are set to 0. BatchNorm2d
    layers that have affine parameters get weights drawn from
    N(1, ``gain``) and biases set to 0. Non-affine BatchNorm2d layers
    (weight is None) are skipped.

    Args:
        net: Module whose submodules are visited via ``net.apply``.
        init_type: One of ``'normal'``, ``'xavier'``, ``'kaiming'``,
            ``'orthogonal'``.
        gain: Std for ``'normal'`` and for BatchNorm weights; gain for
            ``'xavier'`` and ``'orthogonal'``. Ignored by ``'kaiming'``.

    Raises:
        NotImplementedError: If ``init_type`` is not supported. This is
            checked before any layer is touched, so it is raised even when
            ``net`` has no Conv/Linear layer.
    """
    weight_inits = {
        'normal': lambda w: init.normal_(w, 0.0, gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=gain),
    }
    if init_type not in weight_inits:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    init_weight = weight_inits[init_type]

    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        # Class-name matching: 'Conv' also matches ConvTranspose*, Conv1d/3d.
        if weight is not None and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            init_weight(weight.data)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1 and weight is not None:
            # BatchNorm2d(affine=False) has weight/bias == None and is skipped.
            init.normal_(weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
>
>
>
>
> python
>
>
>
> 
conv_block类:连续两组“ $3 \times 3$ 卷积 + 批量归一化 + ReLU激活”。
class conv_block(nn.Module):
    """Two stacked Conv3x3 -> BN -> ReLU stages (stride 1, padding 1).

    The spatial size is preserved; channels go ``ch_in`` -> ``ch_out``.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels of both convolutions.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv+BN+ReLU stages to ``x``."""
        x = self.conv(x)
        return x
>
>
>
>
> python
>
>
>
> 
up_conv类——上采样卷积块:先进行2倍最近邻上采样,再接 $3 \times 3$ 卷积、批量归一化和ReLU激活。
class up_conv(nn.Module):
    """2x nearest-neighbour upsampling followed by Conv3x3 -> BN -> ReLU.

    Doubles H and W and maps ``ch_in`` channels to ``ch_out``.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),  # default mode is 'nearest'
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Upsample ``x`` by 2 and apply conv+BN+ReLU."""
        x = self.up(x)
        return x
>
>
>
>
> python
>
>
>
> 
循环卷积块(Recurrent_block)
同一组卷积权重被反复使用:第一次作用于输入x,之后每一步都作用于“x + 上一步输出”,在不增加参数量的情况下丰富提取到的特征。
class Recurrent_block(nn.Module):
    """Recurrent convolution unit used by R2U-Net.

    One Conv3x3 -> BN -> ReLU layer (``self.conv``, weights shared across
    steps) is applied ``t + 1`` times: first to ``x``, then ``t`` more times
    to ``x + previous_output``. Channel count and spatial size are preserved.

    Args:
        ch_out: Number of input and output channels.
        t: Number of recurrent steps. ``t = 0`` applies ``self.conv`` once.

    Raises:
        ValueError: If ``t`` is negative.
    """

    def __init__(self, ch_out, t=2):
        super().__init__()
        if t < 0:
            raise ValueError('t must be >= 0, got %r' % (t,))
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the output after ``self.t`` recurrent steps on ``x``."""
        out = self.conv(x)
        for _ in range(self.t):
            # The block input is re-injected at every step.
            out = self.conv(x + out)
        return out
>
>
>
>
> python
>
>
>
> 
RRCNN_block:先用 $1 \times 1$ 卷积把通道数调整为ch_out,再串联两个Recurrent_block,最后与 $1 \times 1$ 卷积的输出做残差相加,进一步增强特征提取。
class RRCNN_block(nn.Module):
    """Recurrent residual convolution block (R2U-Net).

    A 1x1 convolution first maps ``ch_in`` to ``ch_out`` channels. Its output
    goes through two stacked ``Recurrent_block``s and is added back as a
    residual. Spatial size is preserved.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        t: Number of recurrent steps of each ``Recurrent_block``.
    """

    def __init__(self, ch_in, ch_out, t=2):
        super().__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out, t=t),
            Recurrent_block(ch_out, t=t),
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``conv1x1(x) + RCNN(conv1x1(x))``."""
        x = self.Conv_1x1(x)
        x1 = self.RCNN(x)
        return x + x1
>
>
>
>
> python
>
>
>
> 
R2U_Net的实现方式:把U-Net中的卷积块换成RRCNN_block。
class R2U_Net(nn.Module):
    """R2U-Net: a U-Net whose convolution blocks are ``RRCNN_block``s.

    Encoder channels are 64, 128, 256, 512, 1024, with 2x2 max pooling
    between levels. The decoder uses ``up_conv`` (2x upsampling + conv),
    concatenates with the encoder map of the same level, and applies an
    ``RRCNN_block``. A final 1x1 convolution maps to ``output_ch`` channels;
    no activation is applied.

    Four 2x max-pooling steps are undone by four 2x upsamplings before each
    concatenation, so the input H and W must be divisible by 16.

    Args:
        img_ch: Number of input channels.
        output_ch: Number of output channels.
        t: Number of recurrent steps in every ``Recurrent_block``.
    """

    def __init__(self, img_ch=3, output_ch=1, t=2):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Not used in forward(); upsampling is done inside the up_conv blocks.
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder.
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        # Decoder: each Up_RRCNN* sees skip + upsampled map (2x channels).
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Map ``x`` (N, img_ch, H, W) to (N, output_ch, H, W)."""
        # Encoding path.
        x1 = self.RRCNN1(x)

        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)

        # Decoding + concat path.
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_RRCNN5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_RRCNN2(d2)

        d1 = self.Conv_1x1(d2)

        return d1
>
>
>
>
> python
>
>
>
> 
Attention_block(注意力门,是Attention U-Net的核心)
class Attention_block(nn.Module):
    """Additive attention gate (Attention U-Net).

    The gating signal ``g`` (decoder features) and the skip features ``x``
    (encoder features) are each projected to ``F_int`` channels with a 1x1
    conv + BN. The projections are added, passed through ReLU, then through
    a 1x1 conv + BN + sigmoid. This gives a single-channel map ``psi`` with
    values in [0, 1]. The output is ``x * psi``: ``x`` re-weighted per
    spatial position, broadcast over channels.

    ``g`` and ``x`` must have the same spatial size.

    Args:
        F_g: Number of channels of ``g`` (decoder feature map).
        F_l: Number of channels of ``x`` (encoder feature map).
        F_int: Number of channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # 1x1 conv + BN on the decoder features: F_g -> F_int channels.
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # 1x1 conv + BN on the encoder features: F_l -> F_int channels.
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # Squeeze F_int channels to one attention map: 1x1 conv, BN, sigmoid.
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention map computed from ``g`` and ``x``."""
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)

        return x * psi
>
>
>
>
> python
>
>
>
> 
AttU_Net的实现方式:在每个跳跃连接处,先用Attention_block以解码器特征为门控信号对编码器特征加权,再把加权结果与上采样后的解码器特征拼接。
class AttU_Net(nn.Module):
    """Attention U-Net: a U-Net with ``Attention_block`` gates on the skips.

    Encoder: ``conv_block``s with 64, 128, 256, 512, 1024 channels and 2x2
    max pooling. At each decoder level the deeper map is upsampled with
    ``up_conv``. It then gates the encoder map of that level through an
    ``Attention_block``. The gated map is concatenated with the upsampled
    map and refined by a ``conv_block``. A final 1x1 convolution maps to
    ``output_ch`` channels; no activation is applied.

    Four 2x max-pooling steps are undone by four 2x upsamplings before each
    gate and concatenation, so the input H and W must be divisible by 16.

    Args:
        img_ch: Number of input channels.
        output_ch: Number of output channels.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # Decoder: upsample, gate the skip, concatenate (2x channels), refine.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Map ``x`` (N, img_ch, H, W) to (N, output_ch, H, W)."""
        # Encoding path.
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # Decoding + concat path; the upsampled map is the gating signal g.
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        return d1
>
>
>
>
> python
>
>
>
> 
这是我手绘的Attention_block推导示意图。如果还没有完全理解,建议先补充学习注意力机制中的Q/K/V概念,再回来看推导过程。

最后,将RRCNN_block与Attention_block结合,得到R2AttU_Net:
class R2AttU_Net(nn.Module):
    """R2U-Net with attention gates: ``RRCNN_block``s plus ``Attention_block``s.

    Encoder: ``RRCNN_block``s with 64, 128, 256, 512, 1024 channels and 2x2
    max pooling. At each decoder level the deeper map is upsampled with
    ``up_conv``. It then gates the encoder map of that level through an
    ``Attention_block``. The gated map is concatenated with the upsampled
    map and refined by an ``RRCNN_block``. A final 1x1 convolution maps to
    ``output_ch`` channels; no activation is applied.

    Four 2x max-pooling steps are undone by four 2x upsamplings before each
    gate and concatenation, so the input H and W must be divisible by 16.

    Args:
        img_ch: Number of input channels.
        output_ch: Number of output channels.
        t: Number of recurrent steps in every ``Recurrent_block``.
    """

    def __init__(self, img_ch=3, output_ch=1, t=2):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Not used in forward(); upsampling is done inside the up_conv blocks.
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder.
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        # Decoder: upsample, gate the skip, concatenate (2x channels), refine.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Map ``x`` (N, img_ch, H, W) to (N, output_ch, H, W)."""
        # Encoding path.
        x1 = self.RRCNN1(x)

        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)

        # Decoding + concat path; the upsampled map is the gating signal g.
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_RRCNN5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_RRCNN2(d2)

        d1 = self.Conv_1x1(d2)

        return d1
>
>
>
>
> python
>
>
>
> 
ResUNet
与传统的串行网络结构相比,残差单元引入了跳跃连接,将输入与输出直接相连,有效弥补了卷积操作中丢失的重要特征信息。这一设计与U-Net的跳跃连接在融合方式上存在明显差异:
ResNet采用简单的逐元素相加(Add),
而U-Net采用通道拼接(Concatenate)。
