Advertisement

深度学习论文: Rethinking “Batch” in BatchNorm及其PyTorch实现

阅读量:

深度学习论文: Re-examining "Batch" in BatchNorm及其实现的PyTorch版本。
Rethinking "Batch" in BatchNorm.
PDF: https://arxiv.org/pdf/2105.07576.pdf
PyTorch代码: https://github.com/shanglianlm0525/CvPytorch
PyTorch代码: https://github.com/shanglianlm0525/PyTorch-Networks

1 概述

批归(Batch Normalization)是现代卷神经网络的核心组件。其显著特点是针对"批次"而非单个样本进行处理,在实际应用中却存在诸多潜在问题可能通过微小途径影响模型性能。本文将系统评估视觉识别任务中的此类问题,并探讨解决这些问题的关键在于重新审视批归中"批次"概念的不同取向。通过对相关挑战及应对策略的阐述, 本综述旨在帮助研究者更高效地运用批归技术

2 A Review of BatchNorm

BatchNorm的计算过程如下:

在这里插入图片描述

其中训练过程中的 µ and σ^{2} (使用来自同一batch数据)的计算如下:

在这里插入图片描述

但是,推理时 µ and σ^{2} 来自全部训练集的统计。

在选择"批次"时存在多种可能性,在具体实施过程中需要考虑多个因素以确保数据处理的有效性。具体而言,则是探究 µ 和 σ² 的数据来源及计算方式之间的差异问题。这些差异可能导致结果的一致性受到威胁,并最终影响模型的一般化能力。

3 Whole Population as a Batch

指数加权平均(EWMA)是一种被用来快速计算总体统计量的方法。在深度学习领域中,这种方法现已成为主流的标准算法之一。

在这里插入图片描述

虽然在实际应用中被普遍采用, 但EMA指数平均法仍可能对总体统计数据产生次优估计的效果, 其具体原因包括以下几点: 首先, 数据分布的不均衡可能导致模型拟合效果受限; 其次, 样本选择标准的松散性可能影响结果准确性; 最后, 指数平滑参数的选择不够科学也会导致预测偏差

  • 当λ较大时, 统计量的收敛速度变慢. 每次更新迭代仅向指数平均法 (EMA) 提供一个较小的部分 (1−λ). 因此, 达成稳定估计值需要大量迭代. 随着模型不断更新, 情况会更加糟糕, 因为EMA主要受到过时的过去输入特征的影响.
  • 当较小时, EMA统计量主要由 recent 的小批量数据主导, 无法反映整个总体.
在这里插入图片描述

研究表明,在训练初期,指数平均法(EMA)难以精确反映小规模数据集的小批量统计和总体数据集的统计。基于此建议采用PreciseBN。

PreciseBN

旨在准确计算整个训练集的数据特征值分布情况, PreciseBN主要采用以下两项技术手段:
1.通过反复使用相同的模型架构来累积各批次的运行结果
2.通过聚合所有批次的数据特征生成总体特征分布

在训练过程中,在每个迭代周期中采用大小为B的小批量进行PreciseBN统计量计算。进而累计求得N/B次的结果。

在这里插入图片描述

相较于传统的EMA方法,PreciseBN凭借两大核心优势脱颖而出。
1.PreciseBN中的统计指标均通过同一套模型推导得出,与之相比,EMA则依赖于多个不同历史版本的数据进行综合考量。
2.在这一过程中,PreciseBN为每个样本点统一赋予权重,而EMA则赋予各个不同的样本点各自独立且各异的权重。

PreciseBN代码:

复制代码
    import torch
    import torch.nn as nn
    
    class PreciseBN(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(PreciseBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
    
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
    
    def forward(self, x):
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
    
        x = (x - mean) / (torch.sqrt(var + self.eps))
        return x
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

4 Batch in Training and Testing

BN在训练与测试阶段呈现出运行模式差异:在训练过程中,BN通过 mini-batch 的数据统计量进行参数更新;而在测试阶段,则基于全体样本集合(population)计算相应的统计参数。本研究深入分析了这种行为不一致性对模型性能的影响,并提出了一种有效策略以解决这一问题。

4-1 Effect of Normalization Batch Size

normalization batch size 直接作用于 training noise 和 train-test inconsistency;采用更大的批量后, 其mini-batch的统计特性趋近于总体数据的统计特征, 从而有效降低training noise以及train-test inconsistency

在这里插入图片描述

Training noise: 当normalization batch size极小时, 每个样本都会受到同一批次中其他样本的显著影响, 从而导致训练效果欠佳, 并使优化过程变得极为困难.

Generalization gap: 当归一化批次大小减小时,在mini-batch中验证集与训练集之间的泛化误差呈现出单调递减的趋势。这种现象可能归因于training noise和train-test inconsistency这两者的正则化效应有所减弱。

Train-test inconsistency:

4-2 Use Mini-batch in Inference

在这里插入图片描述

研究者通过Mask R-CNN展开实验研究,在对比实验中发现基于小批量的方法表现优于基于种群的方法,并通过实验证明了在推理过程中采用小批量方法能够有效地缓解训练与测试阶段之间的不一致性问题。

4-3 Use Population Batch in Training

为了在训练阶段中利用population统计量,在研究中采用了 FrozenBN 技术。该方法依赖于 population 统计量。具体而言,在第 80 个 epoch 时作者选择了该模型作为基准,并在此基础上将所有 BN 转换为 FrozenBN 后进行了 20 个 epoch 的训练过程。

在这里插入图片描述

FrozenBN能够有效缓解训练与测试之间的不一致性问题。尽管在小规模的归一化批次情况下仍能取得较好的性能表现。然而,在归一化批次大小逐渐增大的情况下,则发现作者所提出的方法较之常规BN的效果略显不足。

5 Batch from Different Domains

BN的训练过程可划分为两个相互独立的阶段:第一阶段主要通过SGD提取特征表示;第二阶段则利用这些特征表示生成对应的总体统计量。这两个阶段分别命名为SGD训练和总体统计量提取。

由于BN新增了一个population统计环节的原因源于其结构上的优化[1] ,因此会导致训练与测试之间的 domain shift[2] 。而当数据源自多个 domain 时[3] ,SGD training[4] 、 population statistics training[5] 和 testing[6] 三个步骤所造成的 domain gap 均会直接影响模型的泛化能力。

5-1 Domain to Compute Population Statistics

在这里插入图片描述

研究显示:在遇到明显的domain shift的情况下

5-2 BatchNorm in Multi-Domain Training

在这里插入图片描述

研究表明,在保持一致的前提下非常关键的是SGD训练、人口统计训练以及testing的一致性。此外,在保持一致的前提下非常关键的是通过采用领域特定的方法能够达到最佳效果。相较于其他方法而言,GN的表现更为出色。

6 Information Leakage within a Batch

在实际应用中,Batch Normalization(BN)方法也面临着信息泄漏的问题。其本质是基于小批量数据估计并归一化的统计参数,在这种机制下,在模型对单个样本进行独立预测时能够间接利用该批次内其他样本的信息来辅助当前样本的归一化过程。

6-1 Exploit Patterns in Mini-batches

在这里插入图片描述

通过实验研究发现,在采用随机采样下的mini-batch统计量时(...),验证误差会出现上升现象。而在采用总体统计量的情况下(...),则会随着训练轮次的增多而逐渐出现升高的趋势。这些观察结果从而实证证明了BN信息泄露问题的存在性

在这里插入图片描述

为了解决信息泄露问题,过去常用的方法是采用同步归一化技术(SyncBN),以减弱mini-batch内部样本之间的相关性。另外一种解决方案是将RoI features分配到各个GPU,并对该区域分配的样本进行归一化处理,在进入头部层前完成此操作。这样做不仅降低了min-batch间的相关性。

6-2 Cheating in Contrastive Learning

在对比学习与度量学习中,在小批量数据下的对比分析是其训练目标之一;然而,在这一过程中Batch Normalization(BN)也存在潜在风险,其机制可能导致了信息泄露;此前已有诸多研究就针对对比学习与度量学习中的信息泄露问题展开了系统性研究。

假设有n个样本,则通过对比学习会产生2n个样本,并以无序配对的方式成n \times (2)的形式存在。从这些配对中选取一对进行对比学习时,则由于每个mini-batch的统计量是共享的,在这样的情况下,在n轮对比学习中会导致每个样本的信息被多次引入到不同的配对中去。从而导致n轮信息泄露。

全部评论 (0)

还没有任何评论哟~