Efficient Multi-Scale Training训练代码解析
发布时间
阅读量:
阅读量
Efficient Multi-Scale Training训练代码解析
- def parser() 使用argparse模块实现命令行解析[1]
- 导入argparse模块用于命令行参数解析
- 生成ArgumentParser对象用于接收命令行参数
- 调用add_argument方法指定程序所需接受的命令行参数
指定位置参数(必填,参数名前不带`--`,按位置传入):
parser.add_argument("echo", help="echo the string")
指定可选参数(参数名以`--`开头,可以省略,未提供时取default值,默认为None):
parser.add_argument("--verbosity", help="increase output verbosity")
- 调用arg_parser.parse_args()解析命令行参数,返回一个argparse.Namespace对象,之后可通过args.cfg、args.display等属性访问各参数
def parser():
arg_parser = argparse.ArgumentParser('SNIPER training module')
arg_parser.add_argument('--cfg', dest='cfg', help='Path to the config file',
default='configs/faster/pvalite_b5.yml',type=str)
arg_parser.add_argument('--display', dest='display', help='Number of epochs between displaying loss info',
default=100, type=int)
arg_parser.add_argument('--momentum', dest='momentum', help='BN momentum', default=0.995, type=float)
arg_parser.add_argument('--save_prefix', dest='save_prefix', help='Prefix used for snapshotting the network',
default='SNIPER', type=str)
arg_parser.add_argument('--set', dest='set_cfg_list', help='Set the configuration fields from command line',
default=None, nargs=argparse.REMAINDER)
return arg_parser.parse_args()
def main():
    """Train a SNIPER (Efficient Multi-Scale Training) detector with MXNet.

    Pipeline: parse command-line options, load the YAML config, build the
    roidb and the training iterator, build the network symbol, load the
    pretrained weights, set up metrics and callbacks, then run ``mod.fit``.

    NOTE(review): ``args.set_cfg_list`` (the ``--set`` option defined in
    ``parser()``) is parsed but never applied in this function, so config
    overrides given on the command line are silently ignored. Hook it up to
    the project's config-override helper.
    """
    args = parser()
    update_config(args.cfg)

    # mx.gpu: create one GPU context per device id. config.gpus is the
    # comma-separated `gpus` field of the YAML file passed via --cfg.
    context = [mx.gpu(int(gpu)) for gpu in config.gpus.split(',')]

    nGPUs = len(context)  # number of GPUs in use
    # Total images per training iteration: BATCH_IMAGES for each GPU.
    batch_size = nGPUs * config.TRAIN.BATCH_IMAGES

    # Create the roidb (training dataset). config.dataset.image_set may name
    # several image sets joined by '+'; each one is loaded, then all are merged.
    image_sets = config.dataset.image_set.split('+')
    roidbs = [load_proposal_roidb(config.dataset.dataset, image_set,
                                  config.dataset.root_path,
                                  config.dataset.dataset_path,
                                  proposal=config.dataset.proposal,
                                  append_gt=True,
                                  flip=config.TRAIN.FLIP,
                                  result_path=config.output_path,
                                  proposal_path=config.proposal_path,
                                  load_mask=config.TRAIN.WITH_MASK,
                                  only_gt=not config.TRAIN.USE_NEG_CHIPS)
              for image_set in image_sets]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    # bbox_means / bbox_stds are handed to the checkpoint callback below.
    bbox_means, bbox_stds = add_bbox_regression_targets(roidb, config)
    train_iter = MNIteratorE2E(roidb=roidb, config=config,
                               batch_size=batch_size, nGPUs=nGPUs,
                               threads=config.TRAIN.NUM_THREAD,
                               pad_rois_to=400)

    # Create the logger; output_path is where snapshots are written below.
    logger, output_path = create_logger(config.output_path, args.cfg,
                                        config.dataset.image_set)

    # Build the network symbol and get the list of fixed parameters
    # (given to the Module as fixed_param_names, so they are not updated).
    # NOTE(review): eval() executes text taken from the config file
    # (config.symbol); only run with trusted config files.
    sym_inst = eval('{}.{}'.format(config.symbol, config.symbol))(
        n_proposals=400, momentum=args.momentum)
    sym = sym_inst.get_symbol_rcnn(config)
    fixed_param_names = get_fixed_param_names(config.network.FIXED_PARAMS,
                                              sym)

    # Create the module. First print the name of every data input.
    for k in train_iter.provide_data_single:
        print(k[0])
    mod = mx.mod.Module(symbol=sym,
                        context=context,
                        data_names=[k[0] for k in train_iter.provide_data_single],
                        label_names=[k[0] for k in train_iter.provide_label_single],
                        fixed_param_names=fixed_param_names)

    # Infer shapes from the iterator's provide_data_single /
    # provide_label_single descriptors, then initialise the weights using the
    # params loaded from config.network.pretrained.
    shape_dict = dict(train_iter.provide_data_single +
                      train_iter.provide_label_single)
    sym_inst.infer_shape(shape_dict)
    arg_params, aux_params = load_param(config.network.pretrained,
                                        config.network.pretrained_epoch,
                                        convert=True)
    sym_inst.init_weight_rcnn(config, arg_params, aux_params)

    # Create the metrics: the RPN* and RCNN* metrics, combined into one
    # CompositeEvalMetric that mod.fit reports.
    eval_metric = metric.RPNAccMetric()
    cls_metric = metric.RPNLogLossMetric()
    bbox_metric = metric.RPNL1LossMetric()
    rceval_metric = metric.RCNNAccMetric(config)
    rccls_metric = metric.RCNNLogLossMetric(config)
    rcbbox_metric = metric.RCNNL1LossCRCNNMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    eval_metrics.add(eval_metric)
    eval_metrics.add(cls_metric)
    eval_metrics.add(bbox_metric)
    eval_metrics.add(rceval_metric)
    eval_metrics.add(rccls_metric)
    eval_metrics.add(rcbbox_metric)
    # len(train_iter) is taken here, before PrefetchingIter wraps the iterator.
    optimizer_params = get_optim_params(config, len(train_iter), batch_size)

    # Checkpoint: snapshots are written with prefix <output_path>/<save_prefix>.
    prefix = os.path.join(output_path, args.save_prefix)
    # Log speed and metrics every args.display mini-batches.
    batch_end_callback = mx.callback.Speedometer(batch_size, args.display)
    # After every epoch: save the module (including optimizer states), and run
    # the symbol-specific callback that receives the bbox parameter names and
    # the bbox regression means/stds.
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod, prefix, period=1,
                                      save_optimizer_states=True),
        eval('{}.checkpoint_callback'.format(config.symbol))(
            sym_inst.get_bbox_param_names(), prefix, bbox_means, bbox_stds)]
    train_iter = PrefetchingIter(train_iter)
    mod.fit(train_iter, optimizer='sgd', optimizer_params=optimizer_params,
            eval_metric=eval_metrics, num_epoch=config.TRAIN.end_epoch,
            kvstore=config.default.kvstore,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            arg_params=arg_params, aux_params=aux_params)
相关博文:
[1] Python命令行参数解析argparse核心功能解析
全部评论 (0)
还没有任何评论哟~
