As data passes through successive layers of a network, the distribution of each layer's outputs drifts. This phenomenon is called Internal Covariate Shift, and it makes learning harder for the layers that follow.
However, if every layer's output were simply normalized to a standard normal distribution (mean 0, variance 1), the network would be unable to learn the features of its inputs, because the normalization would wash them all out: the information encoded in the parameters learned by the lower layers would be discarded, reducing the network's representational power.
Compute the mean of the whole batch:
$$\mu_b=\frac{1}{m}\sum_{i=1}^m x_i$$
Compute the variance of the whole batch:
$$\sigma^2_b = \frac{1}{m}\sum_{i=1}^m(x_i-\mu_b)^2$$
Normalize the data to a standard normal distribution:
$$\hat{x}_i = \frac{x_i-\mu_b}{\sqrt{\sigma^2_b+\varepsilon}}$$
Here $\varepsilon$ guards against division by zero when the variance is 0.
Introducing learnable scale and shift parameters restores the representational capacity lost by plain normalization and gives the final normalized output:
$$y_i = \gamma\hat{x}_i+\beta$$
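Putting the four steps together, here is a minimal NumPy sketch of the batch-norm forward pass; the function name `batch_norm_forward` and the default `eps` value are illustrative assumptions, not taken from the paper:

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """x: (m, d) batch of m samples with d features."""
    mu = x.mean(axis=0)                    # per-feature batch mean mu_b
    var = x.var(axis=0)                    # per-feature batch variance (biased, 1/m)
    x_hat = (x - mu) / np.sqrt(var + eps)  # normalize to roughly N(0, 1)
    return gamma * x_hat + beta            # scale and shift: y = gamma * x_hat + beta

# Usage: a batch of 4 samples with 3 features each.
x = np.random.randn(4, 3) * 10 + 5
y = batch_norm_forward(x, gamma=np.ones(3), beta=np.zeros(3))
print(y.mean(axis=0), y.var(axis=0))  # approximately 0 and 1 per feature
```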
Question: Is the estimate of the population variance that arises in this way using the sample mean always smaller than what we would get if we used the population mean?
Answer: Yes, except when the sample mean happens to coincide with the population mean.
We seek the sum of squared distances from the population mean, but end up computing the sum of squared differences from the sample mean, which, as shown below, is precisely the value that minimizes that sum. So unless the sample happens to have the same mean as the population, this estimate always underestimates the sum of squared differences from the population mean.
The proof is as follows:
Let the sample mean be $\mu_s$ and the population mean be $\mu$. Then
$$\mu_s = \frac{1}{m}\sum_{i=1}^m x_i \tag{1}$$
Let $a_i = x_i - \mu_s$ and $b = \mu_s - \mu$. Then
$$(x_i - \mu)^2 = |(x_i - \mu_s) + (\mu_s - \mu)|^2 = |a_i+b|^2 = a^2_i+2a_ib+b^2 \tag{2}$$
Applying equation (2) to every element of the sample and summing gives:
$$\begin{aligned} \sum_{i=1}^m(x_i - \mu)^2 &= \sum_{i=1}^m a_i^2 + \sum_{i=1}^m b^2 + 2\sum_{i=1}^m a_ib\\ &= \sum_{i=1}^m a_i^2 + \sum_{i=1}^m b^2 + 2b\sum_{i=1}^m a_i\\ &= \sum_{i=1}^m a_i^2 + \sum_{i=1}^m b^2\\ &= \sum_{i=1}^m(x_i-\mu_s)^2 + \sum_{i=1}^m b^2\\ &\ge \sum_{i=1}^m(x_i-\mu_s)^2 \tag{3} \end{aligned}$$
The simplification of equation (3) uses the following identity:
$$\sum_{i=1}^m a_i = \sum_{i=1}^m(x_i-\mu_s) = \sum_{i=1}^m x_i - m\mu_s = 0$$
Equation (3) shows that the actual population variance is always greater than or equal to the estimate obtained using the sample mean, with equality only when the sample mean happens to equal the population mean; only then is the estimate of the population variance accurate.
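A quick numerical check of inequality (3); the population parameters and sample size here are arbitrary assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
population = rng.normal(loc=2.0, scale=3.0, size=100_000)
mu = population.mean()                     # population mean

sample = rng.choice(population, size=32)   # a batch-sized sample
mu_s = sample.mean()                       # sample mean

ss_pop = np.sum((sample - mu) ** 2)        # squared deviations from population mean
ss_sample = np.sum((sample - mu_s) ** 2)   # squared deviations from sample mean
print(ss_sample <= ss_pop)                 # True: equality only if mu_s == mu
```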
The sample mean is given by equation (1).
The sample variance:
$$\sigma_s^2 = \frac{1}{m}\sum_{i=1}^m(x_i-\mu_s)^2 = \frac{\sum_{i=1}^m x_i^2}{m} - \frac{\left(\sum_{i=1}^m x_i\right)^2}{m^2}$$
This is just the familiar variance formula $D(x) = E(x^2) - (E(x))^2$.
The corrected estimate of the population variance replaces $m$ in the formula with $m-1$, giving:
$$\sigma^2 = \frac{1}{m-1}\sum_{i=1}^m(x_i-\mu_s)^2 = \frac{\sum_{i=1}^m x_i^2}{m-1} - \frac{\left(\sum_{i=1}^m x_i\right)^2}{m(m-1)} = \frac{m}{m-1}\sigma^2_s$$
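A small simulation, again with assumed parameters, showing that dividing by $m$ systematically underestimates the population variance while the $\frac{m}{m-1}$ correction removes the bias; NumPy's `ddof` argument selects the divisor:

```python
import numpy as np

rng = np.random.default_rng(1)
m = 8
true_var = 9.0  # population is N(0, 3^2)

samples = rng.normal(0.0, 3.0, size=(100_000, m))
biased = samples.var(axis=1, ddof=0)    # sigma_s^2, divides by m
unbiased = samples.var(axis=1, ddof=1)  # divides by m - 1 (Bessel's correction)

print(biased.mean())    # about (m-1)/m * 9 = 7.875, systematically low
print(unbiased.mean())  # about 9.0
```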
Further reading:
Why is the denominator of the sample variance n-1? — 张英锋's answer on Zhihu
https://www.zhihu.com/question/20099757/answer/658048814
- BN keeps the distribution of each layer's inputs relatively stable, which speeds up learning; to some extent it pulls data lying near the gradient-saturated regions of the activation back toward the central linear region, increasing the gradients (see the sketch after this list).
- BN simplifies hyperparameter tuning and makes the network more stable; it suppresses the amplification of small parameter perturbations as they propagate through deeper layers, making the network more robust to the choice of parameter values.
- BN allows the network to use saturating activations such as sigmoid and tanh by alleviating the vanishing-gradient problem.
- BN has a mild regularization effect.
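As a rough illustration of the saturation point above (the shifted input distribution is an arbitrary assumption):

```python
import numpy as np

def sigmoid_grad(x):
    s = 1.0 / (1.0 + np.exp(-x))
    return s * (1.0 - s)  # derivative of sigmoid

rng = np.random.default_rng(2)
z = rng.normal(loc=4.0, scale=2.0, size=1000)  # shifted activations, mostly saturated
z_norm = (z - z.mean()) / z.std()              # normalized to mean 0, variance 1

print(sigmoid_grad(z).mean())       # small: inputs sit in the saturated tail
print(sigmoid_grad(z_norm).mean())  # larger: inputs pulled back to the linear region
```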
Forward propagation:
$$z^l = w^la^{l-1}+b^l$$
When Batch Norm is applied, we need the mean and variance of $z^l$. During the mean subtraction the bias $b^l$ is cancelled out, so in practice the layer can omit $b^l$ altogether, and the forward pass becomes:
$$z^l = w^la^{l-1}$$
followed by Batch Norm.
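A quick sanity check of this cancellation in NumPy; shapes and values are arbitrary assumptions, and the demo uses batch-major `a @ W` rather than the $w^la^{l-1}$ convention above. Frameworks typically express the same idea by constructing the pre-BN layer with its bias disabled:

```python
import numpy as np

rng = np.random.default_rng(3)
a = rng.normal(size=(16, 5))   # activations from the previous layer
W = rng.normal(size=(5, 4))
b = rng.normal(size=(4,))

def normalize(z, eps=1e-5):
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)

z_with_bias = a @ W + b        # bias shifts the mean but not the deviations
z_no_bias = a @ W
print(np.allclose(normalize(z_with_bias), normalize(z_no_bias)))  # True
```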
Ioffe, S. and Szegedy, C., 2015. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. arXiv preprint arXiv:1502.03167.
https://en.wikipedia.org/wiki/Bessel's_correction#Proof_of_correctness_-_Alternate_3