Advertisement

Proxyless NAS: Direct Neural Architecture Search on Target Task and Hardware

阅读量:

Paper:https://arxiv.org/abs/1812.00332

GitHub (PyTorch):https://github.com/mit-han-lab/proxylessnas

Introduction

  • NAS 能够针对目标任务、以及目标部署平台,按自动化方式、搜索出性能与效率折中最好的网络结构;
  • 如果目标任务的数据集规模比较大,通常会设计Proxy task (简化的中间任务) ,作为桥接目标任务的桥梁:
  1. 在Proxy task上的评估结果,可间接反映目标任务上的效果,用来指导NAS选择合适的网络结构;
  2. 但基于Proxy task的搜索结果,未必是目标任务上的最优结果,可能会导致次优化
  • 如果进一步考虑目标平台的资源预算 (latency, power consumption, memory等),需要在搜索过程中添加相应的资源约束
  • 如果仅堆叠重复的building block,会影响网络结构的多样性
  • 本文提出了Proxyless NAS:
  1. 能够结合部署平台的资源约束 ,按可微分方式 、直接在目标任务上执行NAS进程,节省了时间成本
  2. 搜索空间的构建不局限于重复的building block,因此允许支持相对更大的候选集
  3. 可微分搜索过程中,采用二值化 方式筛选子网络结构,节省了内存开销
  4. 搜索效果:
    1. 在ImageNet上,仅需200 GPU hours便可搜索出优于MobileNetV2的网络;
    2. 以8卡V100为例,搜索成本仅需25hours,相当于训练MobileNetV1/V2的时间开销;
  • 早期的NAS或AutoML,通常以Reinforcement learning或Evolution algorithm作为meta controller (例如RNN、EA等),在搜索空间内按离散方式选择或组建网络结构:
  1. 如果选择RL策略,需要在内层循环 训练、评估每次搜索的结果,而在外层循环 更新RNN controller,以指导下一次更为有效的搜索;
  2. 为了节省搜索成本,需要设计合理的Proxy task ,然而搜索开销依旧非常高;
  3. 对于较复杂的模型 (如Object detection),通常只搜索部分组件 (如FPN、或backbone);
  4. 典型代表包括: NASNet、MNASNet、NAS-FCOS、NAS-FPN、Evolved Transformer等;
  • 事先训练超网络 (super-network)或元网络 (meta-network) ,再通过参数共享方式采样、评估子网络结构:
  1. 这种方法可避免子网络的训练,因而节省了搜索成本;
  2. 然而网络结构相对单一,且训练成本、收敛难度转移到了超网络的训练;
  • DARTS 是典型的Gradient-based NAS方法:
  1. 为每个candidate cell分配结构参数,实现可微分搜索;
  2. 但是在网络参数的学习过程中,每个candidate cell都参与训练,因此内存开销会随着搜索空间的增加而显著增加;
  • 本文提出的Proxyless NAS也是Gradient-based NAS方法:
  1. GPU hours与GPU memory都相对更低;
  2. 并且,通过将部署平台的资源预算 (Latency等)作为次优化目标,Proxyless NAS具备hardware-aware的适应能力;
  3. 其他类似的兼顾搜索效率、与资源约束的NAS方法,还有FBNet、DenseNAS、NetAdapt 等;

Method

  • Basics: * Over-parameterized Super-network由N条edge构成:
    • 每条edge都是N个Candidate primitive operation paths的聚合:

    • 超网络又可以表示为:

    • One-shot方法 的聚合方式为:

    • DARTS方法 的聚合方式为:

    • One-shot与DARTS需要将N个paths都存放到内存中,因此内存消耗是单个path的N倍;

  • Binarized path of Proxyless NAS: * 与DARTS一样,为N条path定义可微分的结构参数:
    • 为节省内存开销,在训练超网络过程中,通过结构参数的二值化,只保留单个path:

    • Training binarized architecture parameters:

    • 第一阶段: 固定结构参数,在训练集上,根据path probability取样candidate paths,然后更新网络参数(代码以取样一条path为例):

复制代码
              1. self.log_prob = None

        
              2. # reset binary gates
        
              3. self.AP_path_wb.data.zero_()
        
              4. # binarize according to probs
        
              5. probs = self.probs_over_ops
        
              6.  
        
              7. sample = torch.multinomial(probs.data, 1)[0].item()
        
              8. self.active_index = [sample]
        
              9. self.inactive_index = [_i for _i in range(0, sample)] + \
        
              10.                        [_i for _i in range(sample + 1, self.n_choices)]
        
              11. # set binary gate
        
              12. self.AP_path_wb.data[sample] = 1.0
    
  • 第二阶段: 固定网络参数,在验证集上,更新结构参数:

首先确定反传到binary gate的梯度

复制代码
                    1. # foward of edge

            
                    2. def run_function(candidate_ops, active_id):
            
                    3. 	def forward(_x):
            
                    4. 		return candidate_ops[active_id](_x)
            
                    5. 	return forward
            
                    6.  
            
                    7. # backward of edge
            
                    8. def backward_function(candidate_ops, active_id, binary_gates):
            
                    9. 	def backward(_x, _output, grad_output):
            
                    10.     	binary_grads = torch.zeros_like(binary_gates.data)
            
                    11.         with torch.no_grad():
            
                    12.         	for k in range(len(candidate_ops)):
            
                    13.             	if k != active_id:
            
                    14.                 	out_k = candidate_ops[k](_x.data)
            
                    15.                 else:
            
                    16.                 	out_k = _output.data
            
                    17.                 grad_k = torch.sum(out_k * grad_output)
            
                    18.                 binary_grads[k] = grad_k
            
                    19.         return binary_grads
            
                    20. 	return backward
            
                    21.  
            
                    22. output = ArchGradientFunction.apply(
            
                    23. 	x, self.AP_path_wb, run_function(self.candidate_ops, self.active_index[0]),
            
                    24.     backward_function(self.candidate_ops, self.active_index[0], self.AP_path_wb)
            
                    25. )
        
    ```
* 其次在按下述反传公式确定**结构参数的梯度** (代码以取样一条path为例): 
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-05-31/27e0Cfvkg9wIY5iqVmsEr43ahOzT.png)
复制代码
                1. probs = self.probs_over_ops.data

        
                2. for i in range(self.n_choices):
        
                3.     for j in range(self.n_choices):
        
                4. 		self.AP_path_alpha.grad.data[i] += \
        
                5.         	binary_grads[j] * probs[j] * (delta_ij(i, j) - probs[i])
    
```
  • 文章也提到取样两条paths的效果:为了确保总体的softmax概率之和为零,两条path的结构参数更新之后,需要rescale,使得其中一条path的alpha会增强,另一条的alpha会衰减;
    • Handling non-differentiable hardware metrics * 与MNASNet、Netadapt、DenseNas的做法类似,通过Latency Look-up Table 获取每个op的资源预算(Latency等),然后累加、或按path probability加权求和,获得正则约束项:



      • Latency estimator 示例如下,表示MBConv 的查找表 (Look-up Table):
复制代码
              1. def predict(self, ltype: str, _input, output, expand=None, kernel=None, stride=None, idskip=None, ):

        
              2.         """
        
              3.         :param ltype:
        
              4.             Layer type must be one of the followings
        
              5.                 1. `Conv`: The initial 3x3 conv with stride 2.
        
              6.                 2. `Conv_1`: The upsample 1x1 conv that increases num_filters by 4 times.
        
              7.                 3. `Logits`: All operations after `Conv_1`.
        
              8.                 4. `expanded_conv`: MobileInvertedResidual
        
              9.         :param _input: input shape (h, w, #channels)
        
              10.         :param output: output shape (h, w, #channels)
        
              11.         :param expand: expansion ratio
        
              12.         :param kernel: kernel size
        
              13.         :param stride:
        
              14.         :param idskip: indicate whether has the residual connection
        
              15.         """
        
              16.         infos = [ltype, 'input:%s' % self.repr_shape(_input), 
        
              17.                  'output:%s' % self.repr_shape(output), ]
        
              18.  
        
              19.         if ltype in ('expanded_conv',):
        
              20.             assert None not in (expand, kernel, stride, idskip)
        
              21.             infos += ['expand:%d' % expand, 'kernel:%d' % kernel, 
        
              22.                       'stride:%d' % stride, 'idskip:%d' % idskip]
        
              23.         key = '-'.join(infos)
        
              24.         return self.lut[key]['mean']
    

Experiments




全部评论 (0)

还没有任何评论哟~