深度学习论文: Rethinking BiSeNet For Real-time Semantic Segmentation及其PyTorch实现
Revisiting the BiSeNet architecture for real-time scene understanding and segmentation, this paper presents an optimized PyTorch implementation that significantly enhances computational efficiency while maintaining high accuracy in semantic segmentation tasks. The enhanced model leverages advanced feature fusion techniques to achieve superior performance in real-time applications, making it particularly suitable for applications requiring immediate scene analysis and object recognition. The accompanying PyTorch code provides a comprehensive implementation guide, ensuring ease of integration and customization for researchers and practitioners in the field of computer vision. Available at https://arxiv.org/pdf/2104.13188.pdf, this implementation offers a robust foundation for advancing real-time semantic segmentation systems. For those interested in exploring the technical details, the complete PyTorch code is accessible at https://github.com/shanglianlm0525/CvPytorch.
1 概述
BiSeNet已被广泛认可为一种广受欢迎的用于实时分割的two-stream网络。然而,其引入了一个额外的路径来编码空间信息的方式非常耗时,并且由于未针对特定任务进行优化设计,从预训练的任务(如图像分类)中借用的基础知识对图像分割的实际效果并不理想。
STDC通过逐步减少特征图的空间维度,并拼接(concatenate)所有中间特征图作为图像表征,以此构建STDC网络的基本模块。在解码器部分,我们提出了Detail Aggregation module,以单流(single-stream)方式将空间细节信息的学习整合到低层特征中。最后一步将低层特征与深层特征融合,用于分割预测。

该方法在Cityscapes数据集上使用1080Ti测得:在测试集上以250.4FPS的速度达到71.9%的mIoU,速度相较最新的同类方法提升了45.2%;在97.0FPS下达到76.8%的mIoU,同时还能在更高分辨率的输入图像上进行推理。
2 STDC module
STDC网络在第1至第5个阶段(Stage-1至Stage-5)中,每个阶段都将输入空间分辨率下采样2倍;在第6个阶段(Stage-6),依次经过一个ConvX层、一个全局平均池化层以及两个全连接层,最终输出预测结果logits。
为提升效率,前两个阶段(Stage-1和Stage-2)各仅采用一个卷积块以减少计算开销;在后续三个阶段(Stage-3至Stage-5)中,每个阶段的第一个STDC模块采用stride=2的卷积进行下采样,该阶段内其余模块则保持空间分辨率不变。

特点:
- 按照几何级数规律缩减滤波器规模的同时,大大降低了计算复杂度
- 通过将各个block的输出进行连接融合,从而维持了各尺度的空间感知能力和丰富的特征表示能力
3 STDC network
在编码器模块中,
STDC模块通过依次构造第3、4、5层,
分别输出下采样比率为1:8、1:16和1:32的空间特征图。
随后通过全局平均池化操作提取整体语义信息。
接着采用U-shaped连接策略将全局特征与第4层和第5层的局部特征进行融合。
在此基础上,
在BiSeNet之后引入注意力机制模块以进一步精炼相邻两个层级的特征融合特性。
解码器模块中,
引入Detail Guidance机制以促进低层次层的空间细节学习过程,
而非依赖额外辅助路径。

Detail gt generation:
- 通过采用不同 stride 的二维卷积操作(命名为 Laplacian 核),在 Fig4(e) 中生成不同尺度的 soft thin detail 特征图。
- 将三个特征图放大至原始尺寸,并通过一个 1x1 的卷积层来进行动态融合。
- 通过阈值 0.1 将预测结果转换为二值 gt 图。
PyTorch代码
stdcnet.py
class ConvX(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, the basic convolution unit of STDC."""

    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super(ConvX, self).__init__()
        # padding = kernel // 2 gives 'same'-style padding for odd kernels.
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel,
                              stride=stride, padding=kernel // 2, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv, batch norm and ReLU in sequence."""
        return self.relu(self.bn(self.conv(x)))
class AddBottleneck(nn.Module):
    """STDC module with residual (element-wise add) fusion.

    The input passes through `block_num` ConvX blocks whose channel widths
    shrink geometrically (out/2, out/4, ...); all intermediate outputs are
    concatenated (summing to `out_planes` channels) and added to a skip
    connection. With stride=2 both the main path and the skip path are
    spatially downsampled by 2.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super(AddBottleneck, self).__init__()
        # Fix: the original used `assert cond, print(...)`, which printed the
        # message on every construction and made the assert message None.
        assert block_num > 1, "block number should be larger than 1."
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            # Depthwise stride-2 conv downsamples the first block's output.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, kernel_size=3,
                          stride=2, padding=1, groups=out_planes // 2, bias=False),
                nn.BatchNorm2d(out_planes // 2),
            )
            # Skip path: depthwise stride-2 conv + 1x1 projection to out_planes.
            self.skip = nn.Sequential(
                nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=2,
                          padding=1, groups=in_planes, bias=False),
                nn.BatchNorm2d(in_planes),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            stride = 1
        for idx in range(block_num):
            if idx == 0:
                # 1x1 entry conv to half the output width.
                self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 2, stride=stride))
            elif idx == 1 and block_num > 2:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 4, stride=stride))
            elif idx < block_num - 1:
                # Halve the width at each intermediate block: out/2^idx -> out/2^(idx+1).
                self.conv_list.append(
                    ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx + 1))))
            else:
                # Last block keeps its width so concatenated channels sum to out_planes.
                self.conv_list.append(
                    ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx))))

    def forward(self, x):
        """Return cat(all block outputs) + skip(x); shapes match by construction."""
        out_list = []
        out = x
        for idx, conv in enumerate(self.conv_list):
            if idx == 0 and self.stride == 2:
                # Downsample right after the 1x1 entry conv.
                out = self.avd_layer(conv(out))
            else:
                out = conv(out)
            out_list.append(out)
        if self.stride == 2:
            x = self.skip(x)
        return torch.cat(out_list, dim=1) + x
class CatBottleneck(nn.Module):
    """STDC module with concatenation fusion.

    Same geometric channel layout as AddBottleneck, but the first block's
    output is itself part of the concatenation (downsampled by average
    pooling when stride=2) instead of an additive residual.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super(CatBottleneck, self).__init__()
        # Fix: the original used `assert cond, print(...)`, which printed the
        # message on every construction and made the assert message None.
        assert block_num > 1, "block number should be larger than 1."
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            # Depthwise stride-2 conv downsamples the first block's output
            # before it feeds the second block.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, kernel_size=3,
                          stride=2, padding=1, groups=out_planes // 2, bias=False),
                nn.BatchNorm2d(out_planes // 2),
            )
            # The concatenated copy of the first block's output is downsampled
            # by average pooling instead.
            self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            stride = 1
        for idx in range(block_num):
            if idx == 0:
                # 1x1 entry conv to half the output width.
                self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 2, stride=stride))
            elif idx == 1 and block_num > 2:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 4, stride=stride))
            elif idx < block_num - 1:
                # Halve the width at each intermediate block: out/2^idx -> out/2^(idx+1).
                self.conv_list.append(
                    ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx + 1))))
            else:
                # Last block keeps its width so concatenated channels sum to out_planes.
                self.conv_list.append(
                    ConvX(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx))))

    def forward(self, x):
        """Concatenate the outputs of all blocks (first block's output included)."""
        out_list = []
        out1 = self.conv_list[0](x)
        for idx, conv in enumerate(self.conv_list[1:]):
            if idx == 0:
                # Second block consumes the (possibly downsampled) first output.
                out = conv(self.avd_layer(out1)) if self.stride == 2 else conv(out1)
            else:
                out = conv(out)
            out_list.append(out)
        if self.stride == 2:
            out1 = self.skip(out1)
        out_list.insert(0, out1)
        return torch.cat(out_list, dim=1)
class STDCNet(nn.Module):
    """STDC backbone ('stdc1' / 'stdc2') from "Rethinking BiSeNet" (CVPR 2021).

    Builds five stages: Stage-1/2 are single ConvX blocks, Stage-3..5 are
    stacks of STDC modules, each stage downsampling by 2. `forward` returns
    the feature maps of `out_stages` (1-based stage indices) — or, when
    `classifier=True`, ImageNet-style classification logits.
    """

    def __init__(self, subtype='stdc1', out_stages=(3, 4, 5), output_stride=32,
                 classifier=False, backbone_path=None, pretrained=False):
        super(STDCNet, self).__init__()
        self.subtype = subtype
        # Fix: the original signature used a mutable default argument
        # ([3, 4, 5]) shared across instances; copy into a fresh list.
        self.out_stages = list(out_stages)
        self.output_stride = output_stride  # 8, 16, 32
        self.classifier = classifier
        self.backbone_path = backbone_path
        self.pretrained = pretrained
        base = 64
        block_num = 4
        # Channels of [input, stage1, stage2, stage3, stage4, stage5].
        self.out_channels = [3, 32, 64, 256, 512, 1024]
        if self.subtype == 'stdc1':
            layers = [2, 2, 2]
            features = self._make_layers(base, layers, block_num, CatBottleneck)  # AddBottleneck
            self.layer1 = nn.Sequential(features[:1])  # x2
            self.layer2 = nn.Sequential(features[1:2])  # x4
            self.layer3 = nn.Sequential(features[2:4])  # x8
            self.layer4 = nn.Sequential(features[4:6])  # x16
            self.layer5 = nn.Sequential(features[6:])  # x32
        elif self.subtype == 'stdc2':
            layers = [4, 5, 3]
            features = self._make_layers(base, layers, block_num, CatBottleneck)  # AddBottleneck
            self.layer1 = nn.Sequential(features[:1])  # x2
            self.layer2 = nn.Sequential(features[1:2])  # x4
            self.layer3 = nn.Sequential(features[2:6])  # x8
            self.layer4 = nn.Sequential(features[6:11])  # x16
            self.layer5 = nn.Sequential(features[11:])  # x32
        else:
            raise NotImplementedError
        if self.classifier:
            # Classification head: ConvX -> GAP -> FC -> ReLU -> dropout -> FC.
            self.conv_last = ConvX(base * 16, max(1024, base * 16), 1, 1)
            self.gap = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(max(1024, base * 16), max(1024, base * 16), bias=False)
            # NOTE(review): self.bn is created but never used in forward();
            # kept to stay state-dict compatible with existing checkpoints.
            self.bn = nn.BatchNorm1d(max(1024, base * 16))
            self.relu = nn.ReLU(inplace=True)
            self.dropout = nn.Dropout(p=0.2)
            self.linear = nn.Linear(max(1024, base * 16), 1000, bias=False)
        # Reduce the channel list to the requested output stages.
        self.out_channels = [self.out_channels[ost] for ost in self.out_stages]
        # NOTE(review): load_pretrained_weights / init_weights are not defined
        # in this class — presumably provided by a subclass or mixin; confirm.
        if self.pretrained:
            self.load_pretrained_weights()
        else:
            self.init_weights()

    def _make_layers(self, base, layers, block_num, block):
        """Build the full stage list: two stem ConvX blocks + STDC modules.

        The first module of every STDC stage downsamples (stride=2); the
        rest keep the resolution.
        """
        features = []
        features += [ConvX(3, base // 2, 3, 2)]
        features += [ConvX(base // 2, base, 3, 2)]
        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base * 4, block_num, 2))
                elif j == 0:
                    features.append(block(base * int(math.pow(2, i + 1)), base * int(math.pow(2, i + 2)), block_num, 2))
                else:
                    features.append(block(base * int(math.pow(2, i + 2)), base * int(math.pow(2, i + 2)), block_num, 1))
        return nn.Sequential(*features)

    def forward(self, x):
        """Run all five stages; collect `out_stages` features or classify."""
        output = []
        for i in range(1, 6):
            layer = getattr(self, 'layer{}'.format(i))
            x = layer(x)
            if i in self.out_stages:
                output.append(x)
        if self.classifier:
            # .pow(2) squares the last feature map before global pooling.
            x = self.conv_last(x).pow(2)
            x = self.gap(x).flatten(1)
            x = self.fc(x)
            x = self.relu(x)
            x = self.dropout(x)
            x = self.linear(x)
            return x
        return tuple(output) if len(self.out_stages) > 1 else output[0]
代码解读
stdc_head.py
class AttentionRefinementModule(nn.Module):
    """Attention Refinement Module (ARM): channel-attention gating.

    A 3x3 ConvModule produces features; a global-average-pooled
    1x1-conv + BN + sigmoid branch produces per-channel gates that
    rescale those features.
    """

    def __init__(self, in_channel, out_channel, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvModule(in_channel, out_channel, 3, 1, 1,
                               norm_cfg=dict(type='BN'), activation='ReLU')
        self.conv_atten = nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_channel)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        """Return conv(x) rescaled by its own global channel attention."""
        feat = self.conv(x)
        # Global average pool to 1x1, then 1x1 conv -> BN -> sigmoid gates.
        pooled = F.avg_pool2d(feat, feat.size()[2:])
        gate = self.sigmoid_atten(self.bn_atten(self.conv_atten(pooled)))
        return torch.mul(feat, gate)

    def init_weight(self):
        """Kaiming-init the direct Conv2d children (a=1)."""
        for child in self.children():
            if isinstance(child, nn.Conv2d):
                nn.init.kaiming_normal_(child.weight, a=1)
                if child.bias is not None:
                    nn.init.constant_(child.bias, 0)
class FeatureFusionModule(nn.Module):
    """Feature Fusion Module (FFM) from BiSeNet.

    Concatenates the two input feature maps, projects them with a 3x3
    ConvModule, and adds a channel-attention-weighted copy of the result
    back onto itself (feat + feat * attention).
    """

    def __init__(self, in_channel, out_channel, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvModule(in_channel, out_channel, 3, 1, 1,
                                  norm_cfg=dict(type='BN'), activation='ReLU')
        # Squeeze-and-excite style bottleneck: C -> C/4 -> C.
        self.conv1 = nn.Conv2d(out_channel, out_channel // 4, kernel_size=1,
                               stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(out_channel // 4, out_channel, kernel_size=1,
                               stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        """Fuse spatial-path `fsp` and context-path `fcp` features."""
        feat = self.convblk(torch.cat([fsp, fcp], dim=1))
        # Per-channel gates from globally pooled features.
        gate = F.avg_pool2d(feat, feat.size()[2:])
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))
        return torch.mul(feat, gate) + feat

    def init_weight(self):
        """Kaiming-init the direct Conv2d children (a=1)."""
        for child in self.children():
            if isinstance(child, nn.Conv2d):
                nn.init.kaiming_normal_(child.weight, a=1)
                if child.bias is not None:
                    nn.init.constant_(child.bias, 0)
class StdcHead(nn.Module):
    """STDC segmentation head (BiSeNet-style context path + FFM).

    Takes the 1/8, 1/16 and 1/32 backbone features, refines the two deep
    levels with ARMs plus a global-context branch, fuses with the 1/8
    features via FFM, and emits four outputs: the main 1/8 logits, two
    auxiliary logits (1/16, 1/32) and the 1-channel detail map used for
    Detail Guidance supervision.
    """

    def __init__(self, in_channels, num_classes, mid_channel=128):
        super(StdcHead, self).__init__()
        # in_channels is expected to be (256, 512, 1024) for STDC backbones.
        self.conv_avg = ConvModule(in_channels[2], 128, 1, 1, 0, norm_cfg=dict(type='BN'), activation='ReLU')
        self.arm32 = AttentionRefinementModule(in_channels[2], 128)
        self.conv_head32 = ConvModule(128, 128, 3, 1, 1, norm_cfg=dict(type='BN'), activation='ReLU')
        self.arm16 = AttentionRefinementModule(in_channels[1], 128)
        self.conv_head16 = ConvModule(128, 128, 3, 1, 1, norm_cfg=dict(type='BN'), activation='ReLU')
        self.ffm = FeatureFusionModule(in_channels[0] + mid_channel, 256)
        # Main head on the fused 1/8 features.
        self.conv_out8 = nn.Sequential(
            ConvModule(in_channels[0], 256, 3, 1, 1, norm_cfg=dict(type='BN'), activation='ReLU'),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False)
        )
        # Auxiliary heads on the upsampled 1/16 and 1/32 context features.
        self.conv_out16 = nn.Sequential(
            ConvModule(mid_channel, 64, 3, 1, 1, norm_cfg=dict(type='BN'), activation='ReLU'),
            nn.Conv2d(64, num_classes, kernel_size=1, bias=False)
        )
        self.conv_out32 = nn.Sequential(
            ConvModule(mid_channel, 64, 3, 1, 1, norm_cfg=dict(type='BN'), activation='ReLU'),
            nn.Conv2d(64, num_classes, kernel_size=1, bias=False)
        )
        # Single-channel detail (edge) prediction from the raw 1/8 features.
        self.conv_out_sp8 = nn.Sequential(
            ConvModule(in_channels[0], 64, 3, 1, 1, norm_cfg=dict(type='BN'), activation='ReLU'),
            nn.Conv2d(64, 1, kernel_size=1, bias=False)
        )
        self._init_weight()

    def forward(self, x):
        """x = (feat8, feat16, feat32); returns (out8, out16, out32, detail8)."""
        feat8, feat16, feat32 = x
        # Global context branch: GAP -> 1x1 ConvModule -> broadcast upsample.
        feat32_avg = F.adaptive_avg_pool2d(feat32, 1)
        feat32_avg = self.conv_avg(feat32_avg)
        feat32_avg_up = F.interpolate(feat32_avg, feat32.size()[2:], mode='nearest')
        # 1/32 level: ARM + global context, upsample to 1/16.
        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + feat32_avg_up
        feat32_up = F.interpolate(feat32_sum, feat16.size()[2:], mode='nearest')
        feat32_up = self.conv_head32(feat32_up)
        # 1/16 level: ARM + refined 1/32 features, upsample to 1/8.
        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, feat8.size()[2:], mode='nearest')
        feat16_up = self.conv_head16(feat16_up)
        # Fuse spatial (feat8) and context (feat16_up) paths.
        feat8_fuse = self.ffm(feat8, feat16_up)
        # Output heads.
        feat_out8 = self.conv_out8(feat8_fuse)
        feat_out16 = self.conv_out16(feat16_up)
        feat_out32 = self.conv_out32(feat32_up)
        feat_out_sp8 = self.conv_out_sp8(feat8)
        return feat_out8, feat_out16, feat_out32, feat_out_sp8

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # Fix: the original also matched nn.Linear here and set its
                # weights to the constant 1, which is a broken init for a
                # Linear layer. No Linear exists in this module, so limiting
                # the (1, 0) init to BatchNorm2d is behavior-identical.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
代码解读
