
SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers


Contents

Abstract

SegFormer

Encoder

Decoder

Experiments

Summary


Abstract

SegFormer is a semantic segmentation model that evolved from the Vision Transformer (ViT). It combines a hierarchical Transformer encoder, overlapping patch merging, and a lightweight all-MLP decoder; together these reduce the high computational complexity and memory consumption that ViT suffers from on segmentation tasks while avoiding the loss of fine detail. Experiments on ADE20K, Cityscapes, and COCO-Stuff show excellent results, and the model's simple, efficient architecture and strong generalization make it a compelling new option for Transformer-based dense prediction.

SegFormer

Paper link: [2105.15203] SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers

GitHub - NVlabs/SegFormer: The official PyTorch implementation of SegFormer

The SegFormer network architecture is shown below:

SegFormer consists mainly of a Transformer encoder and a lightweight MLP decoder.

Network highlights:

1. It combines Transformers with a lightweight MLP decoder;

2. Its novel, hierarchically structured Transformer encoder produces multi-scale features. The encoder needs no positional encoding, which sidesteps the interpolation of positional codes, a step that otherwise degrades performance when the test resolution differs from the training resolution;

3. It deliberately avoids complex decoders. The proposed MLP decoder aggregates information from different layers, combining local and global attention to render powerful representations;

4. The design is simple and lightweight, which is the key to efficient segmentation with Transformers;

5. The family scales from SegFormer-B0 up to SegFormer-B5, reaching significantly better performance and efficiency than previous counterparts.

Encoder

Overlap Patch Embedding: a 2D convolution splits the image into overlapping patches and embeds them into a predefined dimension. Four such stages are stacked to build a hierarchical feature representation, so the encoder yields both high-resolution coarse features and low-resolution fine-grained features and can capture spatial relationships at multiple scales. The four stages are configured as follows:

# block1: partition the input image and downsample: 512, 512, 3 => 128, 128, 32 => 16384, 32
self.patch_embed1 = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=in_chans, embed_dim=embed_dims[0])

# block2: partition and downsample: 128, 128, 32 => 64, 64, 64 => 4096, 64
self.patch_embed2 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])

# block3: partition and downsample: 64, 64, 64 => 32, 32, 160 => 1024, 160
self.patch_embed3 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])

# block4: partition and downsample: 32, 32, 160 => 16, 16, 256 => 256, 256
self.patch_embed4 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
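
For reference, here is a minimal sketch of what an OverlapPatchEmbed module can look like, in the spirit of the official implementation (weight initialization omitted). The key point is that the kernel is larger than the stride, so adjacent patches overlap:

import torch.nn as nn

class OverlapPatchEmbed(nn.Module):
    """Overlapping patch embedding: a strided conv whose kernel is larger
    than its stride, so neighbouring patches share pixels."""

    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=64):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=stride, padding=patch_size // 2)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)                  # B, C, H/stride, W/stride
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # B, H*W, C  (token sequence)
        x = self.norm(x)
        return x, H, W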

Efficient Self-Attention: the core of the encoder, and also its main computational bottleneck. SegFormer tames the cost with a spatial-reduction ratio (sr_ratio) that shrinks the key/value sequence, cutting the complexity of self-attention from O(N²) to O(N²/R):

self.attn = Attention(
    dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
    attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio
)
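
A condensed sketch of how such an Attention module can implement the sequence reduction (modeled on the official implementation's structure; dropout, qk_scale, and weight initialization are omitted for brevity):

import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, sr_ratio=1):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # strided conv shrinks the K/V sequence length by sr_ratio**2
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        if self.sr_ratio > 1:
            # reduce the spatial resolution before computing K and V
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
        else:
            x_ = x
        kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # B, heads, N, N/R**2
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)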

Mix-FFN: the authors argue that positional encoding is not actually needed to provide local spatial information for semantic segmentation. They instead propose Mix-FFN, which exploits the fact that zero padding leaks location information: a 3×3 convolution is used directly inside the conventional FFN. Mix-FFN thereby avoids the fixed positional encoding of standard Transformers, injects local positional information dynamically, improves the model's adaptability to inputs whose resolution differs from training, and keeps the overall architecture simpler and more efficient.

self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
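
A condensed sketch of the Mlp (Mix-FFN) block and the depth-wise convolution it relies on, again modeled on the official implementation's structure (initialization omitted):

import torch.nn as nn

class DWConv(nn.Module):
    """3x3 depth-wise conv applied to the token sequence reshaped to a 2D map."""

    def __init__(self, dim):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)                    # zero padding leaks position
        return x.flatten(2).transpose(1, 2)

class Mlp(nn.Module):
    """Mix-FFN: fc -> 3x3 depth-wise conv -> GELU -> fc."""

    def __init__(self, in_features, hidden_features, act_layer=nn.GELU, drop=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        x = self.drop(self.act(self.dwconv(self.fc1(x), H, W)))
        x = self.drop(self.fc2(x))
        return x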

Overlapped Patch Merging: because the patch windows overlap between stages (a 7×7 kernel with stride 4 first, then 3×3 kernels with stride 2, as in the code above), local continuity around patch borders is preserved while the resolution shrinks.

Decoder

Channel unification: the decoder takes the multi-level features from the hierarchical Transformer encoder, at 1/4, 1/8, 1/16, and 1/32 of the input resolution, and runs each level through its own MLP so that all levels share the same channel dimension;

Upsampling and concatenation: the MLP outputs are upsampled to a common resolution with bilinear interpolation and concatenated;

Feature fusion: the concatenated features are fused by an MLP layer, integrating spatial and semantic information across scales;

Mask prediction: another MLP layer maps the fused features to the class space, yielding the segmentation mask at 1/4 of the input image's resolution.

The mask is then upsampled back to the input resolution; since upsampling preserves the spatial correspondence with the original image, the per-pixel class predictions can be used directly as the semantic segmentation result. A sketch of this decode head follows.
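
A minimal sketch of this all-MLP decode head. The class name AllMLPDecoder and the default embed_dims/num_classes values are illustrative, not the official API; the official head fuses with a 1×1 conv module, which is mirrored here:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AllMLPDecoder(nn.Module):
    """Unify channels -> upsample -> concat -> fuse -> classify."""

    def __init__(self, embed_dims=(32, 64, 160, 256), dim=256, num_classes=19):
        super().__init__()
        # one linear layer per encoder stage to unify channel dimensions
        self.linears = nn.ModuleList(nn.Linear(c, dim) for c in embed_dims)
        self.fuse = nn.Sequential(
            nn.Conv2d(4 * dim, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True))
        self.classifier = nn.Conv2d(dim, num_classes, kernel_size=1)

    def forward(self, feats):
        # feats: four maps at 1/4, 1/8, 1/16, 1/32 of the input resolution
        target = feats[0].shape[2:]
        outs = []
        for f, linear in zip(feats, self.linears):
            B, C, H, W = f.shape
            f = linear(f.flatten(2).transpose(1, 2))    # B, H*W, dim
            f = f.transpose(1, 2).reshape(B, -1, H, W)  # B, dim, H, W
            outs.append(F.interpolate(f, size=target, mode='bilinear',
                                      align_corners=False))
        x = self.fuse(torch.cat(outs, dim=1))           # B, dim, H/4, W/4
        return self.classifier(x)                       # B, classes, H/4, W/4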

Experiments

Model training code:

import argparse
import copy
import os
import os.path as osp
import time

import mmcv
import torch
from mmcv.runner import init_dist
from mmcv.utils import Config, DictAction, get_git_hash

from mmseg import __version__
from mmseg.apis import set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger


def parse_args():
    parser = argparse.ArgumentParser(description='Train a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--load-from', help='the checkpoint file to load weights from')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options', nargs='+', action=DictAction, help='custom options')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.options is not None:
        cfg.merge_from_dict(args.options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.load_from is not None:
        cfg.load_from = args.load_from
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, deterministic: '
                    f'{args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_segmentor(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))

    logger.info(model)

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmseg version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
            PALETTE=datasets[0].PALETTE)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_segmentor(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
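
Training is then a matter of pointing the script at a config file. A minimal single-GPU launch might look like the following, where the config path is a placeholder to adjust to your local layout:

python train.py path/to/segformer_config.py --work-dir work_dirs/segformer --seed 0 --deterministic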

Model test code:

import argparse
import os

import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmcv.utils import DictAction

from mmseg.apis import multi_gpu_test, single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor


def parse_args():
    parser = argparse.ArgumentParser(
        description='mmseg test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--aug-test', action='store_true', help='Use Flip and Multi scale aug')
    parser.add_argument(
        '--out', default='work_dirs/res.pkl',
        help='output result file in pickle format')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without perform evaluation. It is'
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        default='mIoU',
        help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
        ' for generic datasets, and "cityscapes" for Cityscapes')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu_collect is not specified')
    parser.add_argument(
        '--options', nargs='+', action=DictAction, help='custom options')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if 'None' in args.eval:
        args.eval = None
    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = mmcv.Config.fromfile(args.config)
    if args.options is not None:
        cfg.merge_from_dict(args.options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    if args.aug_test:
        if cfg.data.test.type == 'CityscapesDataset':
            # hard code index
            cfg.data.test.pipeline[1].img_ratios = [
                0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0
            ]
            cfg.data.test.pipeline[1].flip = True
        elif cfg.data.test.type == 'ADE20KDataset':
            # hard code index
            cfg.data.test.pipeline[1].img_ratios = [
                0.75, 0.875, 1.0, 1.125, 1.25
            ]
            cfg.data.test.pipeline[1].flip = True
        else:
            # hard code index
            cfg.data.test.pipeline[1].img_ratios = [
                0.5, 0.75, 1.0, 1.25, 1.5, 1.75
            ]
            cfg.data.test.pipeline[1].flip = True

    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model.CLASSES = checkpoint['meta']['CLASSES']
    model.PALETTE = checkpoint['meta']['PALETTE']

    efficient_test = True  # False
    if args.eval_options is not None:
        efficient_test = args.eval_options.get('efficient_test', False)

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                                  efficient_test)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect, efficient_test)

    rank, _ = get_dist_info()
    if rank == 0:
        if args.out:
            print(f'\nwriting results to {args.out}')
            mmcv.dump(outputs, args.out)
        kwargs = {} if args.eval_options is None else args.eval_options
        if args.format_only:
            dataset.format_results(outputs, **kwargs)
        if args.eval:
            dataset.evaluate(outputs, args.eval, **kwargs)


if __name__ == '__main__':
    main()
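
Evaluation follows the same pattern; the config and checkpoint paths below are placeholders for your own files:

python test.py path/to/segformer_config.py work_dirs/segformer/latest.pth --eval mIoU --show-dir work_dirs/vis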

Input image:

Output image:

Summary

SegFormer is an efficient Transformer-based semantic segmentation model. Its hierarchical encoder and lightweight all-MLP decoder, together with the complete removal of positional encoding, strike a good balance between accuracy and computational cost. Its results on multiple benchmark datasets show clear advantages, strong generalization, and broad applicability, and future work can be expected to push the approach, and Transformer-based dense prediction more generally, further.
