
mxnet fine-tune


When fine-tuning a network there are two main strategies: retrain the entire network end to end, or keep the backbone fixed and optimize only the layers on top of the deep features. Taking MobileFaceNet as an example: first load the pretrained MobileFaceNet model and mark its parameters as non-learnable, then train only the newly added fully connected layer.

A bit of searching turned up fixed_param_names, which achieves exactly this freezing effect. The code is as follows:

 import mxnet as mx

 # load the pretrained MobileFaceNet checkpoint (epoch 0)
 symbol, arg_params, aux_params = mx.model.load_checkpoint('model-y1-arcface', 0)

 # cut the graph at the embedding layer and collect its argument names
 all_layers = symbol.get_internals()
 net = all_layers['fc1_output']
 fixed_names = net.list_arguments()

 # passing those names as fixed_param_names freezes them during training
 devs = [mx.gpu(0)]
 model = mx.module.Module(symbol=symbol, context=devs, fixed_param_names=fixed_names)

In practice I ran into a problem: the model parameters still changed. After careful debugging, the cause turned out to be the Batch Normalization (BN) layers in MobileFaceNet, which use momentum=0.9. fixed_param_names only freezes arguments (weights and biases); BN's moving_mean and moving_var are auxiliary states that keep updating during the forward pass, so after training the "frozen" layers still drift to some degree. The suggestion, then, is to set the BN layers' momentum to 1, which stops the moving statistics from updating. The next question is how to actually make that change.

In the mxnet framework, BatchNorm's momentum defaults to 0.9, so it has to be overridden explicitly.
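One way to force momentum=1 on every BN layer is to edit the checkpoint's serialized graph and rebuild the symbol. The sketch below assumes this JSON-editing route; freeze_bn_momentum is a hypothetical helper name, not from the original code:

 import json
 import mxnet as mx

 # Sketch, assuming the JSON-editing route: set momentum=1.0 on every
 # BatchNorm node so moving_mean/moving_var stop being updated.
 def freeze_bn_momentum(prefix, epoch):
     symbol, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
     graph = json.loads(symbol.tojson())
     for node in graph['nodes']:
         if node['op'] == 'BatchNorm':
             # recent mxnet keeps operator attributes under 'attrs'
             node.setdefault('attrs', {})['momentum'] = '1.0'
     return mx.sym.load_json(json.dumps(graph)), arg_params, aux_params

Setting use_global_stats=True on the BN nodes in the same way should also work, since it makes BN use the stored statistics directly.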

And then something interesting happened: the parameters stopped changing!!! Unbelievable!!
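If you want to check this yourself, one option (my own sketch, not from the post) is to snapshot the pretrained parameters before training and diff them against model.get_params() afterwards; max_param_drift is a hypothetical helper:

 import numpy as np

 # Sketch: largest absolute change across all parameters that existed
 # before training ('before' and 'after' are name -> NDArray dicts).
 def max_param_drift(before, after):
     return max(float(np.abs(after[k].asnumpy() - v.asnumpy()).max())
                for k, v in before.items())

 # usage:
 # before = {k: v.copy() for k, v in arg_params.items()}
 # model.fit(...)
 # after, aux_after = model.get_params()
 # print(max_param_drift(before, after))         # frozen weights: expect 0.0
 # print(max_param_drift(aux_params, aux_after)) # BN stats: 0.0 only with momentum=1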

That's all! If this method actually helps you, don't forget to leave a like! And if it really doesn't work, allow me to just say:

Not my problem.

Hahaha. Finally, here is the full key code:

 import logging
 import mxnet as mx
 import numpy as np
 import os.path, time, sys


 # data iterators: build train/val iterators from .rec files
 def get_iterators(batch_size, rec_train, rec_val, lst_train, data_shape=(3, 112, 112)):
     train = mx.io.ImageRecordIter(
         path_imgrec=rec_train,
         path_imglist=lst_train,
         data_name='data',
         label_name='softmax_label',
         batch_size=batch_size,
         data_shape=data_shape,
         shuffle=True,
         rand_crop=True,
         rand_mirror=True,
         max_rotate_angle=0)
     val = mx.io.ImageRecordIter(
         path_imgrec=rec_val,
         data_name='data',
         label_name='softmax_label',
         batch_size=batch_size,
         data_shape=data_shape)
     return train, val


 # load the pretrained model and attach the new classification head
 def get_fine_tune_model(model_name):
     # load the pretrained MobileFaceNet checkpoint (epoch 0)
     symbol, arg_params, aux_params = mx.model.load_checkpoint('/home/xxx/anaconda2/envs/mobilefacenet/insightface_root/insightface/models/MobileFaceNet/model-y1-arcface', 0)

     # cut the graph at the 128-d embedding layer
     all_layers = symbol.get_internals()
     net = all_layers['fc1_output']

     # every argument of the backbone gets frozen (the list also contains
     # 'data', which Module never treats as a parameter, so it is harmless)
     fixed_names = net.list_arguments()

     # new fully connected layer: 10 classes on top of the 128-d embedding
     _weight_newfc1 = mx.symbol.Variable("newfc1_weight", shape=(10, 128), lr_mult=1.0, wd_mult=5)
     net = mx.symbol.FullyConnected(data=net, weight=_weight_newfc1, num_hidden=10, name='newfc1')
     net = mx.symbol.SoftmaxOutput(data=net, name='softmax')

     # keep only the pretrained weights; the new layer is initialized from scratch
     new_args = dict({k: arg_params[k] for k in arg_params if 'newfc1' not in k})

     return (net, new_args, aux_params, fixed_names)


 # model training
 def fit(symbol, arg_params, aux_params, iter_train, iter_val, num_epoch, batch_size, gpu_available, fixed_names):
     devs = [mx.gpu(i) for i in gpu_available]
     model = mx.module.Module(symbol=symbol, context=devs, fixed_param_names=fixed_names)

     # metric
     com_metric = mx.metric.CompositeEvalMetric()
     com_metric.add(mx.metric.Accuracy())

     # optimizer; the learning rate is set here, on the instance
     sgd = mx.optimizer.Optimizer.create_optimizer('sgd', learning_rate=0.01, momentum=0, wd=0.01)
     # alternative freezing approach (unused): zero lr multipliers for pretrained params
     #finetune_lr = dict({k: 0 for k in arg_params})
     #sgd.set_lr_mult(finetune_lr)

     # training; allow_missing=True lets the initializer fill in the new layer
     model.fit(iter_train, iter_val,
               num_epoch=num_epoch,
               arg_params=arg_params,
               aux_params=aux_params,
               allow_missing=True,
               batch_end_callback=mx.callback.Speedometer(batch_size, 10),
               #epoch_end_callback=mx.callback.do_checkpoint('/home/xxxx/anaconda2/envs/mobilefacenet/insightface_root/insightface/newmodels/chkmodel2', 0),
               kvstore='device',
               optimizer=sgd,
               initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
               eval_metric='acc')

     # save the fine-tuned model: (prefix, epoch, symbol, arg_params, aux_params)
     arg, aux = model.get_params()
     mx.model.save_checkpoint(prefix, 107, model.symbol, arg, aux)
     return model.score(iter_val, com_metric)


 #=======================================================================================================================
 # set up the logger: print messages to both the screen and a file
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s', filename='acc_record.log', filemode='w')
 console = logging.StreamHandler()
 console.setLevel(logging.INFO)
 console.setFormatter(logging.Formatter('%(asctime)-15s %(message)s'))
 logging.getLogger('').addHandler(console)

 # data and pretrained model
 prefix = '/home/****/anaconda2/envs/mobilefacenet/insightface_root/insightface/models'
 rec_train = '/home/****/anaconda2/envs/jiemian/dataset_rec/train_train.rec'
 model_name = 'ahah'
 rec_val = '/home/****/anaconda2/envs/jiemian/dataset_rec/train_val.rec'
 lst_train = rec_train[:-3] + 'lst'

 # parameters
 num_classes = 10
 batch_per_gpu = 3
 num_epoch = 30
 gpu_available = [0]
 num_gpus = len(gpu_available)
 batch_size = batch_per_gpu * num_gpus
 print(batch_size)
 #-----------------------------------------------------------------------------------------------------------------------

 (new_sym, new_args, aux_params, fixnames) = get_fine_tune_model(model_name)
 #mx.viz.plot_network(new_sym).view()                     # visualize the model architecture
 print('========================= 1 =============================')
 (iter_train, iter_val) = get_iterators(batch_size, rec_train, rec_val, lst_train)
 print('========================= 2 =============================')
 mod_score = fit(new_sym, new_args, aux_params, iter_train, iter_val, num_epoch, batch_size, gpu_available, fixnames)
 print(mod_score)

Gotta run!
