Advertisement

Efficient Multi-Scale Training训练代码解析

阅读量:

Efficient Multi-Scale Training训练代码解析

  • def parser() 使用argparse模块实现命令行解析[1]
  1. 导入argparse模块用于命令行参数解析
  2. 生成ArgumentParser对象用于接收命令行参数
  3. 调用add_argument方法指定程序所需接受的命令行参数

指定位置参数(必填，参数名不带--前缀):
parser.add_argument("echo", help="echo the string")
指定可选参数(参数名带--前缀，可省略):
parser.add_argument("--verbosity", help="increase output verbosity")

  4. 调用arg_parser.parse_args()解析命令行参数，返回包含各参数值的Namespace对象
复制代码
    def parser():
        """Parse the command-line options of the SNIPER training script.

        Options:
            --cfg          path to the YAML config file
                           (default: configs/faster/pvalite_b5.yml)
            --display      number of batches between two loss/speed log lines
                           (default: 100)
            --momentum     BN momentum (default: 0.995)
            --save_prefix  prefix used when snapshotting the network
                           (default: SNIPER)
            --set          every token after it is collected verbatim into
                           ``set_cfg_list`` (argparse.REMAINDER); default None

        Returns:
            argparse.Namespace with attributes ``cfg``, ``display``,
            ``momentum``, ``save_prefix`` and ``set_cfg_list``.

        Exits via SystemExit on ``--help`` or on invalid arguments
        (standard argparse behaviour).
        """
        arg_parser = argparse.ArgumentParser('SNIPER training module')
        arg_parser.add_argument('--cfg', dest='cfg', help='Path to the config file',
                                default='configs/faster/pvalite_b5.yml', type=str)
        # The value is handed to mx.callback.Speedometer as its `frequent`
        # argument, which is a batch count, so the help text says "batches"
        # (the original said "epochs", which was wrong).
        arg_parser.add_argument('--display', dest='display',
                                help='Number of batches between displaying loss info',
                                default=100, type=int)
        arg_parser.add_argument('--momentum', dest='momentum', help='BN momentum',
                                default=0.995, type=float)
        arg_parser.add_argument('--save_prefix', dest='save_prefix',
                                help='Prefix used for snapshotting the network',
                                default='SNIPER', type=str)
        # REMAINDER: everything after --set is kept as a raw list of strings.
        arg_parser.add_argument('--set', dest='set_cfg_list',
                                help='Set the configuration fields from command line',
                                default=None, nargs=argparse.REMAINDER)

        return arg_parser.parse_args()
复制代码
    def main():
    	# Opening lines of the training entry point; the remaining steps of
    	# main() are shown in the snippets that follow in this article.
    	args = parser()
    	# Load the YAML file named by --cfg; presumably this updates the
    	# module-level `config` object read below (config.gpus, ...) — confirm.
    	update_config(args.cfg)
  • mx.gpu 设置GPU计算上下文
复制代码
    # One MXNet GPU context per comma-separated device id in config.gpus.
    context = [mx.gpu(int(gpu)) for gpu in config.gpus.split(',')]

config.gpus：--cfg参数指定的yml文件中的gpus字段（以逗号分隔的GPU编号，如 0,1）
复制代码
    nGPUs = len(context) # number of GPUs in use
    # Total batch size across devices; TRAIN.BATCH_IMAGES is presumably the
    # number of images per GPU — confirm against the config docs.
    batch_size = nGPUs * config.TRAIN.BATCH_IMAGES # set the batch size
  • Create Roidb 创建数据集
复制代码
    # config.dataset.image_set may name several image sets joined by '+'.
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    # Load one roidb per image set. flip / load_mask come from config.TRAIN;
    # only_gt is the negation of config.TRAIN.USE_NEG_CHIPS.
    roidbs = [load_proposal_roidb(config.dataset.dataset, image_set, config.dataset.root_path, config.dataset.dataset_path, proposal=config.dataset.proposal, append_gt=True, flip=config.TRAIN.FLIP, result_path=config.output_path, proposal_path=config.proposal_path, load_mask=config.TRAIN.WITH_MASK, only_gt=not config.TRAIN.USE_NEG_CHIPS) for image_set in image_sets]
    
    # Merge the per-set roidbs into one, filter it, then compute the
    # bbox-regression targets; the returned means/stds are presumably the
    # target normalization statistics (they are reused by the checkpoint callback).
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    bbox_means, bbox_stds = add_bbox_regression_targets(roidb, config)
    
    # Training data iterator. pad_rois_to=400 matches n_proposals=400 used when
    # the network symbol is built; presumably the two must stay in sync.
    train_iter = MNIteratorE2E(roidb=roidb, config=config, batch_size=batch_size, nGPUs=nGPUs, threads=config.TRAIN.NUM_THREAD, pad_rois_to=400)
  • Create the Logger 创建日志
复制代码
    # Create the logger; output_path is later joined with --save_prefix to
    # form the checkpoint prefix.
    logger, output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
  • 构建网络symbol并获取固定（训练中不更新）的参数列表
复制代码
    # Instantiate the network class named by config.symbol (the module and the
    # class share that name, e.g. "<name>.<name>").
    # NOTE(review): eval() on a config string is acceptable only because the
    # config file is trusted; importlib + getattr would be safer.
    sym_inst = eval('{}.{}'.format(config.symbol, config.symbol))(n_proposals=400, momentum=args.momentum)
    sym = sym_inst.get_symbol_rcnn(config)
    
    # Names of parameters matched by config.network.FIXED_PARAMS; passed to the
    # Module as fixed_param_names so they are held fixed during training.
    fixed_param_names = get_fixed_param_names(config.network.FIXED_PARAMS, sym)
  • Create the module 创建模型
复制代码
    # Debug output: print the name of every data input.
    # NOTE(review): Python 2 print statement; under Python 3 write print(k[0]).
    for k in train_iter.provide_data_single:
        print k[0]
    # Build the Module; data/label names are taken from the iterator's
    # per-device (name, shape) descriptors.
    mod = mx.mod.Module(symbol=sym,
                        context=context,
                        data_names=[k[0] for k in train_iter.provide_data_single],
                        label_names=[k[0] for k in train_iter.provide_label_single],
                        fixed_param_names=fixed_param_names)
    
    # Map each input name to its shape so the symbol can infer parameter shapes.
    shape_dict = dict(train_iter.provide_data_single + train_iter.provide_label_single)
    sym_inst.infer_shape(shape_dict)
    # Load the pretrained weights (convert=True), then let the symbol initialize
    # its RCNN weights — presumably those missing from the pretrained model.
    arg_params, aux_params = load_param(config.network.pretrained, config.network.pretrained_epoch, convert=True)
    sym_inst.init_weight_rcnn(config, arg_params, aux_params)
  • Create the metrics 创建指标
复制代码
    # RPN metrics: accuracy, log loss and L1 (bbox) loss.
    eval_metric = metric.RPNAccMetric()
    cls_metric = metric.RPNLogLossMetric()
    bbox_metric = metric.RPNL1LossMetric()
    # RCNN-head metrics: accuracy, log loss and L1 (bbox) loss.
    rceval_metric = metric.RCNNAccMetric(config)
    rccls_metric  = metric.RCNNLogLossMetric(config)
    rcbbox_metric = metric.RCNNL1LossCRCNNMetric(config)
    # Combine all six metrics into one composite passed to mod.fit().
    eval_metrics = mx.metric.CompositeEvalMetric()
    
    eval_metrics.add(eval_metric)
    eval_metrics.add(cls_metric)
    eval_metrics.add(bbox_metric)
    eval_metrics.add(rceval_metric)
    eval_metrics.add(rccls_metric)
    eval_metrics.add(rcbbox_metric)
    
    # SGD parameters built from config, len(train_iter) (presumably the number
    # of batches per epoch) and the total batch size.
    	optimizer_params = get_optim_params(config, len(train_iter), batch_size)
  • Checkpoint
复制代码
    # Checkpoint prefix: the logger's output directory joined with --save_prefix.
    prefix = os.path.join(output_path, args.save_prefix)
    # Log training speed and metrics every args.display batches.
    batch_end_callback = mx.callback.Speedometer(batch_size, args.display)
    # At the end of every epoch (period=1): save the module checkpoint together
    # with optimizer states, and run the symbol module's checkpoint_callback with
    # the bbox param names and bbox means/stds (presumably to store de-normalized
    # bbox-regression weights — confirm).
    epoch_end_callback = [mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True),
                          eval('{}.checkpoint_callback'.format(config.symbol))(sym_inst.get_bbox_param_names(), prefix, bbox_means, bbox_stds)]
    
    # Prefetch batches ahead of consumption.
    train_iter = PrefetchingIter(train_iter)
    # Train with SGD until config.TRAIN.end_epoch, starting from the loaded params.
    mod.fit(train_iter, optimizer='sgd', optimizer_params=optimizer_params,
            eval_metric=eval_metrics, num_epoch=config.TRAIN.end_epoch, kvstore=config.default.kvstore,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback, arg_params=arg_params, aux_params=aux_params)

相关博文[1]Python命令行参数解析argparse核心功能解析

全部评论 (0)

还没有任何评论哟~