Proxyless NAS: Direct Neural Architecture Search on Target Task and Hardware
发布时间
阅读量:
阅读量
Paper:https://arxiv.org/abs/1812.00332
GitHub (PyTorch):https://github.com/mit-han-lab/proxylessnas
Introduction
- NAS 能够针对目标任务、以及目标部署平台,按自动化方式、搜索出性能与效率折中最好的网络结构;
- 如果目标任务的数据集规模比较大,通常会设计Proxy task (简化的中间任务) ,作为桥接目标任务的桥梁:
- 在Proxy task上的评估结果,可间接反映目标任务上的效果,用来指导NAS选择合适的网络结构;
- 但基于Proxy task的搜索结果,未必是目标任务上的最优结果,可能会导致次优化 ;
- 如果进一步考虑目标平台的资源预算 (latency, power consumption, memory等),需要在搜索过程中添加相应的资源约束 ;
- 如果仅堆叠重复的building block,会影响网络结构的多样性 ;
- 本文提出了Proxyless NAS:
- 能够结合部署平台的资源约束 ,按可微分方式 、直接在目标任务上执行NAS进程,节省了时间成本 ;
- 搜索空间的构建不局限于重复的building block,因此允许支持相对更大的候选集 ;
- 可微分搜索过程中,采用二值化 方式筛选子网络结构,节省了内存开销 ;
- 搜索效果:
1. 在ImageNet上,仅需200 GPU hours便可搜索出优于MobileNetV2的网络;
2. 以8卡V100为例,搜索成本仅需25hours,相当于训练MobileNetV1/V2的时间开销;

Related Work
- 早期的NAS或AutoML,通常以Reinforcement learning或Evolution algorithm作为meta controller (例如RNN、EA等),在搜索空间内按离散方式选择或组建网络结构:
- 如果选择RL策略,需要在内层循环 训练、评估每次搜索的结果,而在外层循环 更新RNN controller,以指导下一次更为有效的搜索;
- 为了节省搜索成本,需要设计合理的Proxy task ,然而搜索开销依旧非常高;
- 对于较复杂的模型 (如Object detection),通常只搜索部分组件 (如FPN、或backbone);
- 典型代表包括: NASNet、MNASNet、NAS-FCOS、NAS-FPN、Evolved Transformer等;
- 事先训练超网络 (super-network)或元网络 (meta-network) ,再通过参数共享方式采样、评估子网络结构:
- 这种方法可避免子网络的训练,因而节省了搜索成本;
- 然而网络结构相对单一,且训练成本、收敛难度转移到了超网络的训练;
- DARTS 是典型的Gradient-based NAS方法:
- 为每个candidate cell分配结构参数,实现可微分搜索;
- 但是在网络参数的学习过程中,每个candidate cell都参与训练,因此内存开销会随着搜索空间的增加而显著增加;
- 本文提出的Proxyless NAS也是Gradient-based NAS方法:
- GPU hours与GPU memory都相对更低;
- 并且,通过将部署平台的资源预算 (Latency等)作为次优化目标,Proxyless NAS具备hardware-aware的适应能力;
- 其他类似的兼顾搜索效率、与资源约束的NAS方法,还有FBNet、DenseNAS、NetAdapt 等;
Method
- Basics: * Over-parameterized Super-network由N条edge构成:
-
每条edge都是N个Candidate primitive operation paths的聚合:

-
超网络又可以表示为:

-
One-shot方法 的聚合方式为:

-
DARTS方法 的聚合方式为:

-
One-shot与DARTS需要将N个paths都存放到内存中,因此内存消耗是单个path的N倍;
-
- Binarized path of Proxyless NAS: * 与DARTS一样,为N条path定义可微分的结构参数:
-
为节省内存开销,在训练超网络过程中,通过结构参数的二值化,只保留单个path:

-
Training binarized architecture parameters:

-
第一阶段: 固定结构参数,在训练集上,根据path probability取样candidate paths,然后更新网络参数(代码以取样一条path为例):
-
1. self.log_prob = None
2. # reset binary gates
3. self.AP_path_wb.data.zero_()
4. # binarize according to probs
5. probs = self.probs_over_ops
6.
7. sample = torch.multinomial(probs.data, 1)[0].item()
8. self.active_index = [sample]
9. self.inactive_index = [_i for _i in range(0, sample)] + \
10. [_i for _i in range(sample + 1, self.n_choices)]
11. # set binary gate
12. self.AP_path_wb.data[sample] = 1.0
- 第二阶段: 固定网络参数,在验证集上,更新结构参数:
首先确定反传到binary gate的梯度 :
1. # foward of edge
2. def run_function(candidate_ops, active_id):
3. def forward(_x):
4. return candidate_ops[active_id](_x)
5. return forward
6.
7. # backward of edge
8. def backward_function(candidate_ops, active_id, binary_gates):
9. def backward(_x, _output, grad_output):
10. binary_grads = torch.zeros_like(binary_gates.data)
11. with torch.no_grad():
12. for k in range(len(candidate_ops)):
13. if k != active_id:
14. out_k = candidate_ops[k](_x.data)
15. else:
16. out_k = _output.data
17. grad_k = torch.sum(out_k * grad_output)
18. binary_grads[k] = grad_k
19. return binary_grads
20. return backward
21.
22. output = ArchGradientFunction.apply(
23. x, self.AP_path_wb, run_function(self.candidate_ops, self.active_index[0]),
24. backward_function(self.candidate_ops, self.active_index[0], self.AP_path_wb)
25. )
```
* 其次在按下述反传公式确定**结构参数的梯度** (代码以取样一条path为例):

1. probs = self.probs_over_ops.data
2. for i in range(self.n_choices):
3. for j in range(self.n_choices):
4. self.AP_path_alpha.grad.data[i] += \
5. binary_grads[j] * probs[j] * (delta_ij(i, j) - probs[i])
```
- 文章也提到取样两条paths的效果:为了确保总体的softmax概率之和为零,两条path的结构参数更新之后,需要rescale,使得其中一条path的alpha会增强,另一条的alpha会衰减;
-
Handling non-differentiable hardware metrics * 与MNASNet、Netadapt、DenseNas的做法类似,通过Latency Look-up Table 获取每个op的资源预算(Latency等),然后累加、或按path probability加权求和,获得正则约束项:




- Latency estimator 示例如下,表示MBConv 的查找表 (Look-up Table):
-
1. def predict(self, ltype: str, _input, output, expand=None, kernel=None, stride=None, idskip=None, ):
2. """
3. :param ltype:
4. Layer type must be one of the followings
5. 1. `Conv`: The initial 3x3 conv with stride 2.
6. 2. `Conv_1`: The upsample 1x1 conv that increases num_filters by 4 times.
7. 3. `Logits`: All operations after `Conv_1`.
8. 4. `expanded_conv`: MobileInvertedResidual
9. :param _input: input shape (h, w, #channels)
10. :param output: output shape (h, w, #channels)
11. :param expand: expansion ratio
12. :param kernel: kernel size
13. :param stride:
14. :param idskip: indicate whether has the residual connection
15. """
16. infos = [ltype, 'input:%s' % self.repr_shape(_input),
17. 'output:%s' % self.repr_shape(output), ]
18.
19. if ltype in ('expanded_conv',):
20. assert None not in (expand, kernel, stride, idskip)
21. infos += ['expand:%d' % expand, 'kernel:%d' % kernel,
22. 'stride:%d' % stride, 'idskip:%d' % idskip]
23. key = '-'.join(infos)
24. return self.lut[key]['mean']
Experiments




全部评论 (0)
还没有任何评论哟~
