以前在神经网络训练中,只是对输入层数据进行归一化处理,却没有在中间层进行归一化处理。要知道,虽然我们对输入数据进行了归一化处理,但是输入数据经过矩阵乘法以及非线性运算之后,其数据分布很可能被改变,而随着深度网络的多层运算之后,数据分布的变化将越来越大。如果我们能在网络的中间也进行归一化处理,是否对网络的训练起到改进作用呢?答案是肯定的。
这种在神经网络中间层也进行归一化处理,使训练效果更好的方法,就是批归一化Batch Normalization(BN)。
下面总结一下BN算法的优点:
下面给出 BN 算法在训练时的过程
输入:上一层输出结果 X = x 1 , x 2 , . . . , x m X = {x_1, x_2, ..., x_m} X=x1,x2,...,xm,学习参数 γ , β \gamma, \beta γ,β
算法流程:
μ β = 1 m ∑ i = 1 m ( x i ) \mu_{\beta} = \frac{1}{m} \sum_{i=1}^m(x_i) μβ=m1i=1∑m(xi)
其中,m是此次训练样本 batch 的大小。
σ β 2 = 1 m ∑ i = 1 m ( x i − μ β ) 2 \sigma_{\beta}^2 = \frac{1}{m} \sum_{i=1}^m (x_i - \mu_{\beta})^2 σβ2=m1i=1∑m(xi−μβ)2
x ^ i = x i + μ β σ β 2 + ϵ \hat x_i = \frac{x_i + \mu_{\beta}}{\sqrt{\sigma_{\beta}^2} + \epsilon} x^i=σβ2+ϵxi+μβ
其中 ϵ \epsilon ϵ 是为了避免分母为 0 而加进去的接近于 0 的很小值
y i = γ x ^ i + β y_i = \gamma \hat x_i + \beta yi=γx^i+β
其中, γ , β \gamma, \beta γ,β 为可学习参数。
注:上述是 BN 训练时的过程,但是当在投入使用时,往往只是输入一个样本,没有所谓的均值 μ β \mu_{\beta} μβ 和标准差 σ β 2 \sigma_{\beta}^2 σβ2。此时,均值 μ β \mu_{\beta} μβ 是计算所有 batch μ β \mu_{\beta} μβ 值的平均值得到,标准差 σ β 2 \sigma_{\beta}^2 σβ2 采用每个batch σ β 2 \sigma_{\beta}^2 σβ2 的无偏估计得到。
名称 | 特点 |
---|---|
批量归一化(Batch Normalization,以下简称 BN) | 可让各种网络并行训练。但是,批量维度进行归一化会带来一些问题——批量统计估算不准确导致批量变小时,BN 的误差会迅速增加。在训练大型网络和将特征转移到计算机视觉任务中(包括检测、分割和视频),内存消耗限制了只能使用小批量的 BN。 |
群组归一化 Group Normalization (简称 GN) | GN 将通道分成组,并在每组内计算归一化的均值和方差。GN 的计算与批量大小无关,并且其准确度在各种批量大小下都很稳定。 |
比较 | 在 ImageNet 上训练的 ResNet-50上,GN 使用批量大小为 2 时的错误率比 BN 的错误率低 10.6% ;当使用典型的批量时,GN 与 BN 相当,并且优于其他标归一化变体。而且,GN 可以自然地从预训练迁移到微调。在进行 COCO 中的目标检测和分割以及 Kinetics 中的视频分类比赛中,GN 可以胜过其竞争对手,表明 GN 可以在各种任务中有效地取代强大的 BN。 |
Weight Normalization 和 Batch Normalization 都属于参数重写(Reparameterization)的方法,只是采用的方式不同。
Weight Normalization 是对网络权值 W W W 进行 normalization,因此也称为 Weight Normalization;
Batch Normalization 是对网络某一层输入数据进行 normalization。
Weight Normalization相比Batch Normalization有以下三点优势:
在CNN中,BN应作用在非线性映射前。在神经网络训练时遇到收敛速度很慢,或梯度爆炸等无法训练的状况时可以尝试BN来解决。另外,在一般使用情况下也可以加入BN来加快训练速度,提高模型精度。
BN比较适用的场景是:每个mini-batch比较大,数据分布比较接近。在进行训练之前,要做好充分的shuffle,否则效果会差很多。另外,由于BN需要在运行过程中统计每个mini-batch的一阶统计量和二阶统计量,因此不适用于动态的网络结构和RNN网络。