
ResUNet++: A Keras Implementation of an Advanced Medical Image Segmentation Architecture


The ResUNet++ architecture is an advanced medical image segmentation model based on the Deep Residual U-Net (ResUNet). It incorporates residual blocks, squeeze-and-excitation blocks, ASPP (Atrous Spatial Pyramid Pooling) blocks, and attention mechanisms to enhance performance. The architecture consists of an encoder for feature extraction, a bridge for context aggregation, and a decoder that upsamples and concatenates features through skip connections. The implementation below uses Keras and TensorFlow, with code for each component: the stem block, residual blocks, attention blocks, and ASPP blocks. Complete implementations are available on GitHub for both Keras/TensorFlow and PyTorch.

ResUNet++ is built on the Deep Residual U-Net (ResUNet) and takes full advantage of both deep residual learning and the U-Net design. It achieves a significant performance gain by combining residual blocks, squeeze-and-excitation blocks, ASPP, and attention mechanisms. For a detailed description of the architecture, see the paper "ResUNet++: An Advanced Architecture for Medical Image Segmentation" (https://arxiv.org/pdf/1911.07067.pdf).

"""
ResUNet++ architecture in Keras TensorFlow
"""
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

def squeeze_excite_block(inputs, ratio=8):
    ## Squeeze-and-excitation: global average pooling followed by two dense layers
    ## produces per-channel weights that rescale the input feature map.
    init = inputs
    channel_axis = -1
    filters = init.shape[channel_axis]  # TF 2.x: shape entries are plain ints, no .value needed
    se_shape = (1, 1, filters)

    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

    x = Multiply()([init, se])
    return x

def stem_block(x, n_filter, strides):
    x_init = x

    ## Conv 1
    x = Conv2D(n_filter, (3, 3), padding="same", strides=strides)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(n_filter, (3, 3), padding="same")(x)

    ## Shortcut
    s = Conv2D(n_filter, (1, 1), padding="same", strides=strides)(x_init)
    s = BatchNormalization()(s)

    ## Add
    x = Add()([x, s])
    x = squeeze_excite_block(x)
    return x

def resnet_block(x, n_filter, strides=1):
    x_init = x

    ## Conv 1
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(n_filter, (3, 3), padding="same", strides=strides)(x)
    ## Conv 2
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(n_filter, (3, 3), padding="same", strides=1)(x)

    ## Shortcut
    s = Conv2D(n_filter, (1, 1), padding="same", strides=strides)(x_init)
    s = BatchNormalization()(s)

    ## Add
    x = Add()([x, s])
    x = squeeze_excite_block(x)
    return x

def aspp_block(x, num_filters, rate_scale=1):
    ## Atrous Spatial Pyramid Pooling: parallel dilated convolutions (rates 6, 12, 18)
    ## plus a plain convolution capture multi-scale context, then are fused by addition.
    x1 = Conv2D(num_filters, (3, 3), dilation_rate=(6 * rate_scale, 6 * rate_scale), padding="SAME")(x)
    x1 = BatchNormalization()(x1)

    x2 = Conv2D(num_filters, (3, 3), dilation_rate=(12 * rate_scale, 12 * rate_scale), padding="SAME")(x)
    x2 = BatchNormalization()(x2)

    x3 = Conv2D(num_filters, (3, 3), dilation_rate=(18 * rate_scale, 18 * rate_scale), padding="SAME")(x)
    x3 = BatchNormalization()(x3)

    x4 = Conv2D(num_filters, (3, 3), padding="SAME")(x)
    x4 = BatchNormalization()(x4)

    y = Add()([x1, x2, x3, x4])
    y = Conv2D(num_filters, (1, 1), padding="SAME")(y)
    return y

def attention_block(g, x):
    """
    g: Output of Parallel Encoder block
    x: Output of Previous Decoder block
    """
    filters = x.shape[-1]  # TF 2.x: shape entries are plain ints, no .value needed

    g_conv = BatchNormalization()(g)
    g_conv = Activation("relu")(g_conv)
    g_conv = Conv2D(filters, (3, 3), padding="SAME")(g_conv)

    ## Downsample the encoder features to the spatial size of the decoder features
    g_pool = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(g_conv)

    x_conv = BatchNormalization()(x)
    x_conv = Activation("relu")(x_conv)
    x_conv = Conv2D(filters, (3, 3), padding="SAME")(x_conv)

    gc_sum = Add()([g_pool, x_conv])

    gc_conv = BatchNormalization()(gc_sum)
    gc_conv = Activation("relu")(gc_conv)
    gc_conv = Conv2D(filters, (3, 3), padding="SAME")(gc_conv)

    gc_mul = Multiply()([gc_conv, x])
    return gc_mul

class ResUnetPlusPlus:
    def __init__(self, input_size=256):
        self.input_size = input_size

    def build_model(self):
        n_filters = [16, 32, 64, 128, 256]
        inputs = Input((self.input_size, self.input_size, 3))

        c0 = inputs
        c1 = stem_block(c0, n_filters[0], strides=1)

        ## Encoder
        c2 = resnet_block(c1, n_filters[1], strides=2)
        c3 = resnet_block(c2, n_filters[2], strides=2)
        c4 = resnet_block(c3, n_filters[3], strides=2)

        ## Bridge
        b1 = aspp_block(c4, n_filters[4])

        ## Decoder
        d1 = attention_block(c3, b1)
        d1 = UpSampling2D((2, 2))(d1)
        d1 = Concatenate()([d1, c3])
        d1 = resnet_block(d1, n_filters[3])

        d2 = attention_block(c2, d1)
        d2 = UpSampling2D((2, 2))(d2)
        d2 = Concatenate()([d2, c2])
        d2 = resnet_block(d2, n_filters[2])

        d3 = attention_block(c1, d2)
        d3 = UpSampling2D((2, 2))(d3)
        d3 = Concatenate()([d3, c1])
        d3 = resnet_block(d3, n_filters[1])

        ## Output
        outputs = aspp_block(d3, n_filters[0])
        outputs = Conv2D(1, (1, 1), padding="same")(outputs)
        outputs = Activation("sigmoid")(outputs)

        ## Model
        model = Model(inputs, outputs)
        return model
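The class above only builds the network graph. Continuing the listing, a minimal usage sketch is shown below; the Adam optimizer, binary cross-entropy loss, and random dummy batch are illustrative assumptions rather than settings taken from the original repository:

model = ResUnetPlusPlus(input_size=256).build_model()
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()

# Sanity check: one random 256x256 RGB image in, one per-pixel probability mask out.
dummy = np.random.rand(1, 256, 256, 3).astype("float32")
pred = model.predict(dummy)
print(pred.shape)  # (1, 256, 256, 1)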

The Keras implementation of ResUNet++ can be found here:

https://github.com/DebeshJha/ResUNetPlusPlus

The PyTorch implementation of ResUNet++ can be found here:

https://github.com/rishikksh20/ResUnet/blob/master/core/res_unet_plus.py
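For training and evaluation, the ResUNet++ paper reports the Dice coefficient among its metrics. A common Keras-backend formulation is sketched below; it is a generic implementation for a binary mask, not code taken from either repository:

from tensorflow.keras import backend as K

def dice_coef(y_true, y_pred, smooth=1.0):
    # Overlap between predicted and ground-truth masks, smoothed to avoid division by zero.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)

# Example: model.compile(optimizer="adam", loss=dice_loss, metrics=[dice_coef])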
