Deep learning paper: ICNet for Real-Time Semantic Segmentation on High-Resolution Images and its PyTorch implementation
ICNet: Real-Time Semantic Segmentation on High-Resolution Images (2018)
Paper (PDF): https://arxiv.org/pdf/1704.08545.pdf
PyTorch implementation: https://github.com/shanglianlm0525/PyTorch-Networks
1 Overview
ICNet is a real-time semantic segmentation network built on PSPNet. It cuts inference time substantially while keeping segmentation accuracy high; its headline property is sustained inference at about 30 frames per second on 1024 × 2048 images.

2 ICNet
ICNet combines the speed of processing low-resolution images with the quality of high-resolution inference: a cascade framework produces a coarse prediction from the low-resolution input and progressively refines it with features from higher-resolution branches.
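The cascade's multi-resolution inputs are simply downsampled copies of the same image. A minimal sketch (scale factors follow the paper; the tensor shape and variable names are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 256, 512)  # full-resolution input image
# half- and quarter-resolution copies feed the medium and low branches
x_half = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
x_quarter = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True)
```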
2-1 Cascade Feature Fusion

PyTorch code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Conv3x3BNReLU is a helper from the author's repository; a minimal
    # equivalent (3x3 conv -> BatchNorm -> ReLU) is reproduced here.
    def Conv3x3BNReLU(in_channels, out_channels, stride, dilation=1):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                      padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    class CascadeFeatureFusion(nn.Module):
        def __init__(self, low_channels, high_channels, out_channels, num_classes):
            super(CascadeFeatureFusion, self).__init__()
            # dilated conv on the upsampled low-resolution feature
            self.conv_low = Conv3x3BNReLU(low_channels, out_channels, 1, dilation=2)
            self.conv_high = Conv3x3BNReLU(high_channels, out_channels, 1, dilation=1)
            self.relu = nn.ReLU(inplace=True)
            # auxiliary classifier used by the cascade label guidance loss
            self.conv_low_cls = nn.Conv2d(out_channels, num_classes, 1, bias=False)

        def forward(self, x_low, x_high):
            # upsample the coarse feature to the finer branch's spatial size
            x_low = F.interpolate(x_low, size=x_high.size()[2:], mode='bilinear', align_corners=True)
            x_low = self.conv_low(x_low)
            x_high = self.conv_high(x_high)
            out = self.relu(x_low + x_high)
            x_low_cls = self.conv_low_cls(x_low)
            return out, x_low_cls
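The fusion's shape behaviour can be checked in isolation: the coarse feature is upsampled to the finer feature's size before the element-wise sum. A standalone sketch with illustrative shapes (the 3×3 convolutions and auxiliary classifier are omitted):

```python
import torch
import torch.nn.functional as F

x_low = torch.randn(1, 128, 16, 32)    # coarse branch feature
x_high = torch.randn(1, 128, 64, 128)  # finer branch feature, same channel count
x_up = F.interpolate(x_low, size=x_high.shape[2:], mode='bilinear', align_corners=True)
out = torch.relu(x_up + x_high)        # fused feature at the finer resolution
```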
2-2 Cascade Label Guidance
To improve learning, ICNet uses a cascade label guidance strategy: ground-truth labels at different resolutions guide the low-, medium-, and high-resolution branches during training. At inference time the low- and medium-resolution guidance heads are dropped, so the strategy improves accuracy without adding inference cost.
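In training this amounts to a weighted sum of cross-entropy terms, one per cascade output. A minimal sketch (the helper name and loss weights are illustrative, not taken from the paper):

```python
import torch
import torch.nn.functional as F

def cascade_guidance_loss(outputs, target, weights=(1.0, 0.4, 0.16)):
    """Weighted sum of per-branch cross-entropy losses over the cascade outputs."""
    loss = 0.0
    for logits, w in zip(outputs, weights):
        # resize ground-truth labels to each branch's resolution
        # (nearest-neighbour keeps class ids intact)
        t = F.interpolate(target[:, None].float(), size=logits.shape[2:], mode='nearest')
        loss = loss + w * F.cross_entropy(logits, t[:, 0].long())
    return loss

# example: two guidance heads at 1/2 and 1/4 of the label resolution
logits = [torch.randn(2, 5, 16, 16), torch.randn(2, 5, 8, 8)]
labels = torch.randint(0, 5, (2, 32, 32))
loss = cascade_guidance_loss(logits, labels)
```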
2-3 Network Architecture
The network is split into three branches:
| Branch | Processing | Cost |
|---|---|---|
| Low resolution | Takes the 1/16-scale output of the medium-resolution branch and downsamples it further to 1/32. After convolutions, several dilated convolutions enlarge the receptive field without shrinking the feature map, producing a feature map at 1/32 of the original size. | Many layers, but the low resolution keeps it fast; shares part of its weights with the medium-resolution branch. |
| Medium resolution | Takes the input at 1/2 of the original resolution and downsamples it by 8× through convolutions, yielding a feature map at 1/16 of the original size, which is then fused with the low-resolution branch output through a cascade feature fusion (CFF) unit. Note that the convolution parameters of the low- and medium-resolution branches are shared. | 17 convolutional layers; shares part of its weights with the low-resolution branch; the two branches together take 6 ms. |
| High resolution | Takes the full-resolution image and downsamples it by 8× through convolutions, yielding a feature map at 1/8 of the original size, which is then fused with the medium-resolution output through a CFF unit. | Only 3 convolutional layers; despite the high resolution, the small depth keeps the cost at 9 ms. |
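The output scales in the table can be verified with a little arithmetic (assuming a 1024 × 2048 input, the resolution used in the paper's benchmark):

```python
H, W = 1024, 2048  # full-resolution input

# low-resolution branch: 1/4-scale input through a stride-8 backbone,
# then one further 2x downsampling -> 1/32 of the original image
low = (H // 32, W // 32)
# medium-resolution branch: 1/2-scale input through a stride-8 backbone -> 1/16
mid = (H // 16, W // 16)
# high-resolution branch: three stride-2 convolutions -> 1/8
high = (H // 8, W // 8)
```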

PyTorch code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torchvision

    class Backbone(nn.Module):
        def __init__(self):
            super(Backbone, self).__init__()
            self.pretrained = torchvision.models.resnet50(pretrained=True)

        def forward(self, x):
            x = self.pretrained.conv1(x)
            x = self.pretrained.bn1(x)
            x = self.pretrained.relu(x)
            x = self.pretrained.maxpool(x)
            c1 = self.pretrained.layer1(x)
            c2 = self.pretrained.layer2(c1)
            c3 = self.pretrained.layer3(c2)
            c4 = self.pretrained.layer4(c3)
            return c1, c2, c3, c4

    class PyramidPoolingModule(nn.Module):
        def __init__(self, pyramids=[1, 2, 3, 6]):
            super(PyramidPoolingModule, self).__init__()
            self.pyramids = pyramids

        def forward(self, x):
            feat = x
            height, width = x.shape[2:]
            for bin_size in self.pyramids:
                # pool to bin_size x bin_size, upsample back, and sum
                feat_x = F.adaptive_avg_pool2d(x, output_size=bin_size)
                feat_x = F.interpolate(feat_x, size=(height, width), mode='bilinear', align_corners=True)
                feat = feat + feat_x
            return feat

    class ICNet(nn.Module):
        def __init__(self, num_classes):
            super(ICNet, self).__init__()
            # high-resolution branch: three stride-2 convs -> 1/8 of the input
            self.conv_sub1 = nn.Sequential(
                Conv3x3BNReLU(3, 32, 2),
                Conv3x3BNReLU(32, 32, 2),
                Conv3x3BNReLU(32, 64, 2)
            )
            self.backbone = Backbone()
            self.ppm = PyramidPoolingModule()

            self.cff_12 = CascadeFeatureFusion(128, 64, 128, num_classes)
            self.cff_24 = CascadeFeatureFusion(2048, 512, 128, num_classes)

            self.conv_cls = nn.Conv2d(128, num_classes, 1, bias=False)

        def forward(self, x):
            # sub 1: full-resolution branch
            x_sub1 = self.conv_sub1(x)
            # sub 2: half-resolution branch (ResNet layer2 output)
            x_sub2 = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
            _, x_sub2, _, _ = self.backbone(x_sub2)
            # sub 4: quarter-resolution branch (ResNet layer4 output)
            x_sub4 = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True)
            _, _, _, x_sub4 = self.backbone(x_sub4)

            # pyramid pooling on the coarsest feature
            x_sub4 = self.ppm(x_sub4)

            outs = list()
            x_cff_24, x_24_cls = self.cff_24(x_sub4, x_sub2)
            outs.append(x_24_cls)
            x_cff_12, x_12_cls = self.cff_12(x_cff_24, x_sub1)
            outs.append(x_12_cls)

            up_x2 = F.interpolate(x_cff_12, scale_factor=2, mode='bilinear', align_corners=True)
            up_x2 = self.conv_cls(up_x2)
            outs.append(up_x2)
            up_x8 = F.interpolate(up_x2, scale_factor=4, mode='bilinear', align_corners=True)
            outs.append(up_x8)
            # order outputs from fine to coarse: 1 -> 1/4 -> 1/8 -> 1/16
            outs.reverse()

            return outs
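As a quick sanity check, the pyramid pooling step can be exercised on its own. This is a functional re-statement of the `PyramidPoolingModule` above: because each pooled map is upsampled back and summed, the output keeps the input shape.

```python
import torch
import torch.nn.functional as F

def pyramid_pool(x, pyramids=(1, 2, 3, 6)):
    h, w = x.shape[2:]
    feat = x
    for bin_size in pyramids:
        # global context at several grid sizes, added back at full resolution
        pooled = F.adaptive_avg_pool2d(x, output_size=bin_size)
        feat = feat + F.interpolate(pooled, size=(h, w), mode='bilinear', align_corners=True)
    return feat

x = torch.randn(1, 8, 16, 16)
y = pyramid_pool(x)
```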
3 Experiments
