Advertisement

医学图像分割2 TransUnet:Transformers Make Strong Encoders for Medical Image Segmentation

阅读量:

The TransUnet architecture stands out as an effective solution within the domain of medical image segmentation, where it excels by leveraging transformer-based encoders to achieve precise and reliable segmentation results.

复制代码
    这篇文章中你可以找到一下内容:
    - Attention是怎么样在CNN中火起来的?-Non Local
    - Transformer结构带来了什么?-Multi Head Self Attention
    - Transformer结构为何在CV中如此流行?-Vision Transformer和SETR
    - TransUnet又是如何魔改Unet和Transformer?-ResNet50+VIT作为backbone\Encoder
    - TransUnet的pytorch代码实现
    - 作者吐槽以及偷懒的痕迹

引文

在医学图像分割领域中,U-shaped网络特别是Unet架构已经展现了显著的应用前景。然而,在建立远程信息连接方面(即实现长距离信息传递),CNN表现欠佳。其局限性在于感受野范围有限。尽管可以通过重复使用CNN结构并配合空洞卷积等方式来扩大感受野幅度,但这种方式会带来一些潜在问题(例如导致卷积核退化以及产生栅格化现象)。

基于self-attention机制的Transformer架构已在NLP领域展现出显著的效果,在当年Vision Transformer成功引入了该架构至计算机视觉领域并实现了卓越的效果。这一突破性进展使得该方法在计算机视觉领域得到了广泛应用并迅速流行开来。

话说回来,为什么Transformer结构能够在CV领域中获得不错的效果?

Attention is all you need?

在介绍Transformer之前,我们可以先探索一下CNN结构中有哪些有趣的部分. 对 Non Local结构 值得我们进行回顾.

Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k} } )V

从Non Local角度出发,在17至18年间多个重要会议上神经网络领域掀起了一股关注热潮

自然地来说,为什么会想到采用Non Local方法来计算Attention呢?这是因为Non Local的研究者们受到了Transformer这一模型结构的强烈启发。这促使他们再次回望起这篇奠定现代自注意力机制基础的经典论文——《Attention is all you need》。

该论文主要涉及两个核心工作:其一是提出了一种称为Transformer的技术;其二是采用多头注意力机制作为其关键组成部分的技术。即通过多头注意力机制替代传统的单头注意力技术。

Transformer架构简单,并且主要包括以下三个关键模块:Multi-Head Atten-tion、FeedForward Network(FFN)以及Norm(归一化)机制。值得注意的是,在这一部分机制中存在一些特殊的处理方式

Multi-Head Attention其实并不复杂其本质就是一种多头注意力机制。其中Multi-Head Attn属于Attention机制中的一种即通过多个头来处理信息每个头负责计算一组注意力即每个头都能独立地关注不同的信息点这使得整个模型能够从不同角度理解和捕捉数据特征通过将这些结果综合考虑进去从而能够提供更为全面的信息

多个头注意力机制通过将head_1到head_h进行连接并乘以W^O来计算。
每个头的信息由注意力机制作用于经过线性变换后的查询、键和值向量生成。

Vision Transformer - the pioneer from CNN to Transformer

Vision Transformer可分为引领发展的先驱者与关键助力者,在非Vit时代还需长时间依赖非局域机制感到有些无奈。然而随着Transformer技术逐渐成熟和完善中……

实现原理较为基础,在处理序列数据方面Transformer已经具备很强的能力。然而由于图像数据无法直接输入到Transformer中进行处理。因此提出者Vit创造了一种创新方法将图像划分为9个区域块当然同样地还可以将其划分为16个区域块等数量的块具体情况取决于你所选择的区域大小。随后将这些区域块按顺序排列后连接成一个序列在经过位置编码处理后就可以将这个序列输入到Transformer中进行计算了这里的位移编码主要用于帮助模型理解各个区域块之间的相对位置关系从而提升模型的学习能力。

在ImageNet上取得显著成功后,VIT给计算机视觉领域带来了新的期待.分割无疑是计算机视觉的核心问题之一,既然VIT具备了分类能力,那么它就可以效仿ResNet的角色,在分割任务中担任Backbone角色.

SERT Vit也能用于语义分割!

在另一篇顶级会议论文《Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers》中发表的SERT研究就最先将ViT作为主干网络用于实现这一任务。

SERT模型搭建不复杂,并主要采用主流的encoder-decoder架构以Vit为骨架构建了三种不同的Decoder结构,并用于进行语义分割实验。这个方案非常简单,实现起来就能获得不错的性能(感谢Vit在顶会中提供便利)。

正文

在开头部分冗长地阐述了许多观点, 这些内容主要涉及深度学习中的关键技术和概念, 包括卷积神经网络(CNN)、注意力机制(Attention)、非局部操作(Non Local)以及变压器(Transformer)等技术点。随后, 我们转向讨论TransUnet模型的核心设计。值得注意的是, 许多现代计算机视觉(CV)论文中都存在大量文字性描述(尽管有时像TransUnet这样的架构也可能如此)。然而, 这种文字性描述本质上是一种技巧, 能够有效传达关键信息并引导读者理解核心算法流程。如图所示, TransUnet采用了经典的编码器-解码器架构框架

仍然是具有较强的代表性的Unet架构。然而与基于CNN的传统UNet相比,在这里其前三层采用的是基于CNN的技术,并且在最后一层采用了基于Transformer的技术。这相当于将UNet编码器的最后一层替换了Transformer模块。

为什么只有一层Transformer

TransUnet仅部分采用Transformer考虑到其独特优势。尽管Transformer能够捕捉到全局信息但在细节特征的捕捉上却存在不足。SegFormer:《Segmenter: Transformer for Semantic Segmentation》一文中探讨了不同Patch尺寸对模型预测结果的影响研究表明在计算速度方面较大尺寸的Patch表现更为突出但其在边缘分割方面的性能却明显不如较小尺寸的Patch因此在实际应用中较小尺寸的Patch往往能提供更为精确的结果

大量事实表明 该架构在精细粒度特征分割方面表现欠佳 相比之下 在利用有限的空间感受野进行特征提取方面具有明显优势 因此 TransUnet模型仅替换顶层编码器模块 该模块主要负责捕捉整体特征模式 这正是该架构在精细粒度特征分割方面的强项所在 而针对浅层细节识别 则由传统的卷积神经网络负责

TransUnet具体细节

  • decoder部分的设计相对简洁明了。
    • 对于encoder组件:
      • 作者采用了ResNet50网络的前三层来做CNN模块。
      • 这一设计非常出色。
      • 最后一层采用了Vision Transformer(Vit)架构。
      • 将编码器部分命名为R50-ViT模型。

关于Vit的相关介绍, 可参考另一篇文章: VIT+SETR, 本文有意略过.

建议提醒一下注意,在给定Vit的输入维度为(b,c,W,H),当设置patch size=P时,在patch size=P的情况下,则其输出结果为(b,c,W/P,H/P),具体来说,则是将宽高缩放比例分别为H/P和W/HP,并应将其上采样至(W,H)尺寸

TransUnet模型实现

Encoder部分

该编码器主要由ResNet50模块与Vision Transformer(Vit)组件构成。在ResNet50模块中, 通过移除stem_block结构中的4倍下采样操作, 保留前三层模型架构设计, 其中这三者均采用两倍的下采样策略, 并将最后一层输出直接传递给Vit模块使用作为输入数据。通过这一设计安排, 在保持特征图尺寸、通道数量与原始图像一致的基础上实现了有效的特征提取。

复制代码
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    class BasicBlock(nn.Module):
    expansion: int = 4
    def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1,
        base_width = 64, dilation = 1, norm_layer = None):
        
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, planes ,kernel_size=3, stride=stride, 
                               padding=dilation,groups=groups, bias=False,dilation=dilation)
        
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes ,kernel_size=3, stride=stride, 
                               padding=dilation,groups=groups, bias=False,dilation=dilation)
        
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
    
    def forward(self, x):
        identity = x
    
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
    
        out = self.conv2(out)
        out = self.bn2(out)
    
        if self.downsample is not None:
            identity = self.downsample(x)
    
        out += identity
        out = self.relu(out)
    
        return out
    
    
    class Bottleneck(nn.Module):
    expansion = 4
    
    def __init__(self, inplanes, planes, stride=1, downsample= None,
        groups = 1, base_width = 64, dilation = 1, norm_layer = None,):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, bias=False, padding=dilation, dilation=dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
    
    def forward(self, x):
        identity = x
    
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
    
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
    
        out = self.conv3(out)
        out = self.bn3(out)
    
        if self.downsample is not None:
            identity = self.downsample(x)
    
        out += identity
        out = self.relu(out)
        return out
    
    
    class ResNet(nn.Module):
    def __init__(
        self,block, layers,num_classes = 1000, zero_init_residual = False, groups = 1,
        width_per_group = 64, replace_stride_with_dilation = None, norm_layer = None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 2
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
            
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64//4, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128//4, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256//4, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512//4, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
    
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
    
    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride = 1,
        dilate = False,
    ):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = stride
            
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,  planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion))
    
        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)
    
    def _forward_impl(self, x):
        out = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        out.append(x)
        x = self.layer2(x)
        out.append(x)
        x = self.layer3(x)
        out.append(x)
        # 最后一层不输出
        # x = self.layer4(x)
        # out.append(x)
        return out
    
    def forward(self, x) :
        return self._forward_impl(x)
    
    def _resnet(block, layers, pretrained_path = None, **kwargs,):
        model = ResNet(block, layers, **kwargs)
        if pretrained_path is not None:
            model.load_state_dict(torch.load(pretrained_path),  strict=False)
        return model
    
    def resnet50(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 6, 3], pretrained_path,**kwargs)
    
    def resnet101(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 23, 3], pretrained_path,**kwargs)
    
    if __name__ == "__main__":
    v = ResNet.resnet50().cuda()
    img = torch.randn(1, 3, 512, 512).cuda()
    preds = v(img)
    # torch.Size([1, 64, 256, 256])
    print(preds[0].shape)
    # torch.Size([1, 128, 128, 128])
    print(preds[1].shape)
    # torch.Size([1, 256, 64, 64])
    print(preds[2].shape)

接着是Vit部分,Vit接受ResNet50的第三个输出。

复制代码
    import torch
    from torch import nn
    from einops import rearrange, repeat
    from einops.layers.torch import Rearrange
    
    
    def pair(t):
    return t if isinstance(t, tuple) else (t, t)
    
    class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
    
    class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
    
    class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)
    
        self.heads = heads
        self.scale = dim_head ** -0.5
    
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
    
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
    
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()
    
    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
    
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
    
        attn = self.attend(dots)
        attn = self.dropout(attn)
    
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
    
    class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
    
    class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 512, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
    
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
    
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
    
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )
    
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
    
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
    
        self.out = Rearrange("b (h w) c->b c h w", h=image_height//patch_height, w=image_width//patch_width)
        
    		# 这里上采样倍数为8倍。为了保持和图中的feature size一样
        self.upsample = nn.UpsamplingBilinear2d(scale_factor = patch_size//2)
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU())
    
    def forward(self, img):
    	# 这里对应了图中的Linear Projection,主要是将图片分块嵌入,成为一个序列
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        # 为图像切片序列加上索引
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        # 输入到Transformer中处理
        x = self.transformer(x)
    
        # delete cls_tokens, 输出前需要删除掉索引
        output = x[:,1:,:]
        output = self.out(output)
    
        # Transformer输出后,上采样到原始尺寸
        output = self.upsample(output)
        output = self.conv(output)
    
        return output
    
    
    import torch
    if __name__ == "__main__":
    v = ViT(image_size = (64, 64), patch_size = 16, channels = 256, dim = 512, depth = 12, heads = 16, mlp_dim = 1024, dropout = 0.1, emb_dropout = 0.1).cpu()
    # 假设ResNet50第三层输出大小是 1, 256, 64, 64 也就是b, c, W/8, H/8
    img = torch.randn(1, 256, 64, 64).cpu()
    preds = v(img)
    # 输出是 b, c, W/16, H/16
    # preds:  torch.Size([1, 512, 32, 32])
    print("preds: ",preds.size())

再把两个部分合并一下,包装成TransUnetEncoder类。

复制代码
    class TransUnetEncoder(nn.Module):
    def __init__(self, **kwargs):
        super(TransUnetEncoder, self).__init__()
        self.R50 = ResNet.resnet50()
        self.Vit = ViT(image_size = (64, 64), patch_size = 16, channels = 256, dim = 512, depth = 12, heads = 16, mlp_dim = 1024, dropout = 0.1, emb_dropout = 0.1)
    
    def forward(self, x):
        x1, x2, x3 = self.R50(x)
        x4 = self.Vit(x3)
        return [x1, x2, x3, x4]
    
    if __name__ == "__main__":
    x = torch.randn(1, 3, 512, 512).cuda()
    net = TransUnetEncoder().cuda()
    out = net(x)
    # torch.Size([1, 64, 256, 256])
    print(out[0].shape)
    # torch.Size([1, 128, 128, 128])
    print(out[1].shape)
    # torch.Size([1, 256, 64, 64])
    print(out[2].shape)
    # torch.Size([1, 512, 32, 32])
    print(out[3].shape)

Decoder部分

Decoder部分属于经典的Unet decoder模块,在处理过程中被用来接收skip connections,并依次应用卷积操作和上采样过程。为了便于整合到更大的架构中,同样地将整个结构封装为TransUnetDecoder类。

复制代码
    class TransUnetDecoder(nn.Module):
    def __init__(self, out_channels=64, **kwargs):
        super(TransUnetDecoder, self).__init__()
        self.decoder1 = nn.Sequential(
            nn.Conv2d(out_channels//4, out_channels//4, 3, padding=1), 
            nn.BatchNorm2d(out_channels//4),
            nn.ReLU()            
        )
        self.upsample1 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels, out_channels//4, 3, padding=1),
            nn.BatchNorm2d(out_channels//4),
            nn.ReLU()     
        )
    
        self.decoder2 = nn.Sequential(
            nn.Conv2d(out_channels*2, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()            
        )
        self.upsample2 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels*2, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()     
        )
    
        self.decoder3 = nn.Sequential(
            nn.Conv2d(out_channels*4, out_channels*2, 3, padding=1),
            nn.BatchNorm2d(out_channels*2),
            nn.ReLU()            
        )        
        self.upsample3 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels*4, out_channels*2, 3, padding=1),
            nn.BatchNorm2d(out_channels*2),
            nn.ReLU()     
        )
    
        self.decoder4 = nn.Sequential(
            nn.Conv2d(out_channels*8, out_channels*4, 3, padding=1),
            nn.BatchNorm2d(out_channels*4),
            nn.ReLU()                           
        )
        self.upsample4 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels*8, out_channels*4, 3, padding=1),
            nn.BatchNorm2d(out_channels*4),
            nn.ReLU()     
        )
    
    def forward(self, inputs):
        x1, x2, x3, x4 = inputs
        # b 512 H/8 W/8
        
        x4 = self.upsample4(x4)
        x = self.decoder4(torch.cat([x4, x3], dim=1))        
        
        x = self.upsample3(x)
        x = self.decoder3(torch.cat([x, x2], dim=1))
    
        x = self.upsample2(x)
        x = self.decoder2(torch.cat([x, x1], dim=1))
    
        x = self.upsample1(x)
        x = self.decoder1(x)
    
        return x
    
    if __name__ == "__main__":
    x1 = torch.randn([1, 64, 256, 256]).cuda()
    x2 = torch.randn([1, 128, 128, 128]).cuda()
    x3 = torch.randn([1, 256, 64, 64]).cuda()
    x4 = torch.randn([1, 512, 32, 32]).cuda()
    net = TransUnetDecoder().cuda()
    out = net([x1,x2,x3,x4])
    # out: torch.Size([1, 16, 512, 512])
    print(out.shape)

TransUnet类

最后将Encoder和Decoder包装成TransUnet。

复制代码
    class TransUnet(nn.Module):
    	# 主要是修改num_classes 
    def __init__(self, num_classes=4, **kwargs):
        super(TransUnet, self).__init__()
        self.TransUnetEncoder = TransUnetEncoder()
        self.TransUnetDecoder = TransUnetDecoder()
        self.cls_head = nn.Conv2d(16, num_classes, 1)
    def forward(self, x):
        x = self.TransUnetEncoder(x)
        x = self.TransUnetDecoder(x)
        x = self.cls_head(x)
        return x
    
    if __name__ == "__main__":
    	# 输入的图像尺寸 [1, 3, 512, 512]
    x1 = torch.randn([1, 3, 512, 512]).cuda()
    net = TransUnet().cuda()
    out = net(x1)
    # 输出的结果[batch, num_classes, 512, 512]
    print(out.shape)

在Camvid测试集上测试一下

由于缺乏专业的医学图像数据集进行测试,在我的电脑上使用一个通用的数据集进行分割效果测试会更加高效一些

复制代码
    # 导入库
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    import warnings
    warnings.filterwarnings("ignore")
    from PIL import Image
    import numpy as np
    import albumentations as A
    from albumentations.pytorch.transforms import ToTensorV2
     
    torch.manual_seed(17)
    # 自定义数据集CamVidDataset
    class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)
    """
    
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(512, 512),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ]) 
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
     
    
    def __getitem__(self, i):
        # read data
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image,mask=mask)
        
        return image['image'], image['mask'][:,:,0]
        
    def __len__(self):
        return len(self.ids)
    
    
    # 设置数据集路径
    DATA_DIR = r'../blork_file/dataset//camvid/' # 根据自己的路径来设置
    x_train_dir = os.path.join(DATA_DIR, 'train_images')
    y_train_dir = os.path.join(DATA_DIR, 'train_labels')
    x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
    y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
    
    train_dataset = CamVidDataset(
    x_train_dir, 
    y_train_dir, 
    )
    val_dataset = CamVidDataset(
    x_valid_dir, 
    y_valid_dir, 
    )
     
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=True, drop_last=True)

一些模型和训练过程设置

复制代码
    from d2l import torch as d2l
    from tqdm import tqdm
    import pandas as pd
    import monai
    # model
    model = TransUnet(num_classes=33).cuda()
    # training loop 100 epochs
    epochs_num = 100
    # 选用SGD优化器来训练
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.5)
    
    # 损失函数选用多分类交叉熵损失函数
    lossf = nn.CrossEntropyLoss(ignore_index=255)
    
    def evaluate_accuracy_gpu(net, data_iter, device=None):
    if isinstance(net, nn.Module):
        net.eval()  # Set the model to evaluation mode
        if not device:
            device = next(iter(net.parameters())).device
    # No. of correct predictions, no. of predictions
    metric = d2l.Accumulator(2)
    
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # Required for BERT Fine-tuning (to be covered later)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            output = net(X)
            metric.add(d2l.accuracy(output, y), d2l.size(y))
    return metric[0] / metric[1]
    
    
    # 训练函数
    def train_ch13(net, train_iter, test_iter, loss, optimizer, num_epochs, schedule, devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1], legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # 用来保存一些训练参数
    
    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []
    lr_list = []
    
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (X, labels) in enumerate(train_iter):
            timer.start()
    
            if isinstance(X, list):
                X = [x.to(devices[0]) for x in X]
            else:
                X = X.to(devices[0])
            gt = labels.long().to(devices[0])
    
            net.train()
            optimizer.zero_grad()
            result = net(X)
            loss_sum = loss(result, gt)
            loss_sum.sum().backward()
            optimizer.step()
    
            acc = d2l.accuracy(result, gt)
            metric.add(loss_sum, acc, labels.shape[0], labels.numel())
    
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[3], None))
                
        schedule.step()
    
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        print(f"epoch {epoch+1}/{epochs_num} --- loss {metric[0] / metric[2]:.3f} --- train acc {metric[1] / metric[3]:.3f} --- test acc {test_acc:.3f} --- lr {optimizer.state_dict()['param_groups'][0]['lr']} --- cost time {timer.sum()}")
        
        #---------保存训练数据---------------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch+1)
        time_list.append(timer.sum())
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df["lr"] = lr_list
        df['time'] = time_list
        
        df.to_excel("../blork_file/savefile/TransUnet_camvid.xlsx")
        #----------------保存模型------------------- 
        if np.mod(epoch+1, 5) == 0:
            torch.save(net.state_dict(), f'../blork_file/checkpoints/TransUnet_{epoch+1}.pth')
    
    # 保存下最后的model
    torch.save(net.state_dict(), f'../blork_file/checkpoints/TransUnet_last.pth')
    
    # 开始训练
    train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num, schedule)

训练结果:

在这里插入图片描述

说在最后

文章中的代码较为简陋;大致上与TransUnet源图相匹配。如果你希望获取不同尺寸的模型;仅需改动各层通道数量设置;你可以分别在ResNet50架构部分;Vision Transformer模块以及解码器部分进行相关设置。如果希望将TransUnet应用于其他数据集;即可完成对分类任务的学习目标设定

作者注

  • num_classes的构成主要为:background+类别1+类别2+类别n。
  • 作者比较懒,还在自我批评中。如果作者不懒的话,可以把通道数的关系连接一下,这样只需要改一处就可以修改模型规模了,不像现在需要改好几个地方,还需要进行验证。
  • 不过,验证的过程也是学习的过程,所以,多看一看代码改一改对小白来说是有很大的好处的。
  • 因此,作者在这里为自己偷懒找了一个不错的借口。
  • 这篇文章写完了TransUnet,应某位读者的要求,下一篇文章会写SwinUnet。
  • 个人认为,Transformer效果不一定会很好。至少作者在自己的细胞数据集上测试情况来讲,Swin Transformer的结果不如传统的CNN模型来得更好。Transformer存在的缺陷很明显,同时GPU资源消耗很大。但是在大物体上的分割效果会很不错,这也是注意力机制的强大之处。但其在细小物体和边界的处理上,明显来的不那么好。这种情况下,使用deformable-DETR中提到的multi-scale Deformable Attention或许会达到一个不错的效果,毕竟可以更关注局部信息。不过2022年的各大顶会已经也都开始了对Transformer的魔改,融合CNN到Transformer中,从而达到局部全局两手抓的效果,像什么MixFormer、MaxVit啊等等。
  • 总之呢,个人认为,CV快到瓶颈期了,期待下一匹黑马诞生,干翻Transformer和CNN。

全部评论 (0)

还没有任何评论哟~