本文探讨 Batch Normalization 在测试或推断时使用的算法及其原理.
配套代码, 请参考文章 :
Python和PyTorch对比实现批标准化 Batch Normalization 函数在测试或推理过程中的算法.
系列文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981
Batch Normalization 在训练时使用每批数据的均值和方差进行数据的规范化, 在测试或推断的时候使用全体数据的特征.
我们不可能在训练前再次遍历全体数据的特征, 耗时太大. 也不可能记录所有批次的中间结果, 内存消耗太大.
目前主流的深度学习框架 TensorFlow 和 PyTorch 等采用的是参数估计加滑动平均法, 并引入的一个超参数来解决这个问题.
接下来我们详细探讨这个方法的理论基础.
已知一个 k k k 维数组 x x x 的均值为 μ \mu μ, 标准差为 σ \sigma σ, 排列成列数量不相等的矩阵 X X X , 共 m m m 行, 每行 n i n_i ni 个元素, 下标 i i i 表示第 i i i 行.
其形式类似如下 :
x 11 x_{11} x11 | x 12 x_{12} x12 | x 13 x_{13} x13 | x 14 x_{14} x14 | ( n 1 = 4 ) (n_1 = 4) (n1=4) | |
---|---|---|---|---|---|
x 21 x_{21} x21 | x 22 x_{22} x22 | x 23 x_{23} x23 | ( n 2 = 3 ) (n_2 = 3) (n2=3) | ||
x 31 x_{31} x31 | x 32 x_{32} x32 | x 33 x_{33} x33 | x 34 x_{34} x34 | x 35 x_{35} x35 | ( n 3 = 5 ) (n_3 = 5) (n3=5) |
x 41 x_{41} x41 | x 42 x_{42} x42 | ( n 4 = 2 ) (n_4 = 2) (n4=2) | |||
x 51 x_{51} x51 | x 52 x_{52} x52 | x 53 x_{53} x53 | x 54 x_{54} x54 | ( n 5 = 3 ) (n_5 = 3) (n5=3) |
设第 i i i 组的均值为 x ˉ i \bar x_i xˉi, 标准差为 s i s_i si. 求 x ˉ \bar x xˉ 与 μ \mu μ 的关系, 求 s s s 与 σ \sigma σ 的关系.
由题意, 知矩阵的元素的总数量 k k k , 均值 μ \mu μ, 方差 σ 2 \sigma^2 σ2 :
k = ∑ i = 1 m n i    μ = 1 k ∑ i = 1 m ∑ j = 1 n i x i j    σ 2 = 1 k ∑ i = 1 m ∑ j = 1 n i ( x i j − μ ) 2 k = \sum_{i=1}^{m}n_i \;\\ \mu =\frac{1}{k}\sum_{i=1}^{m}\sum_{j=1}^{n_i}x_{ij}\\ \;\\ \sigma^2 = \frac{1}{k}\sum_{i=1}^{m}\sum_{j=1}^{n_i}(x_{ij}-\mu)^2 k=i=1∑mniμ=k1i=1∑mj=1∑nixijσ2=k1i=1∑mj=1∑ni(xij−μ)2
因为 :
∑ i = 1 m ∑ j = 1 n i x i j = ∑ j = 1 n n i x ˉ i \sum_{i=1}^{m}\sum_{j=1}^{n_i}x_{ij}=\sum_{j=1}^{n}n_i\bar x_i\\ i=1∑mj=1∑nixij=j=1∑nnixˉi
μ = 1 k ∑ j = 1 n n i x ˉ i \mu=\frac{1}{k}\sum_{j=1}^{n}n_i\bar x_i μ=k1j=1∑nnixˉi
令 n = 2 n = 2 n=2 :
μ = n 1 x ˉ 1 + n 2 x ˉ 2 ( n 1 + n 2 ) \mu=\frac{n_1\bar x_1 +n_2\bar x_2}{(n_1+n_2)} μ=(n1+n2)n1xˉ1+n2xˉ2
改成迭代式 :
μ i + 1 = k i μ i + n i + 1 x ˉ i + 1 ( k i + n i + 1 ) \mu_{i+1}=\frac{k_i\mu_i +n_{i+1}\bar x_{i+1}}{(k_i+n_{i+1})} μi+1=(ki+ni+1)kiμi+ni+1xˉi+1
式中 μ i \mu_i μi 表示前 i i i 行 (包括 i i i) 所有元素的均值, k i k_i ki 表示前 k i k_i ki 行 (包括 i i i) 所有元素的总数量.
在编程实战中, 使用迭代式, 可以分步计算总体的均值, 节省宝贵的计算机内存.
总方差 :
σ 2 = 1 k ∑ i = 1 m ∑ j = 1 n i ( x i j − μ ) 2 = 1 k ∑ i = 1 m ∑ j = 1 n i ( x i j 2 − 2 x i j μ − μ 2 )    = 1 k ( ∑ i = 1 m ∑ j = 1 n i x i j 2 − 2 ∑ i = 1 m ∑ j = 1 n i x i j μ + ∑ i = 1 m ∑ j = 1 n i μ 2 )    = 1 k ( ∑ i = 1 m ∑ j = 1 n i x i j 2 − 2 k μ 2 + k μ 2 ) \sigma^2 = \frac{1}{k}\sum_{i=1}^{m}\sum_{j=1}^{n_i}(x_{ij}-\mu)^2=\frac{1}{k}\sum_{i=1}^{m}\sum_{j=1}^{n_i}(x_{ij}^2-2x_{ij}\mu-\mu^2)\\ \;\\ =\frac{1}{k}(\sum_{i=1}^{m}\sum_{j=1}^{n_i}x_{ij}^2-2\sum_{i=1}^{m}\sum_{j=1}^{n_i}x_{ij}\mu+\sum_{i=1}^{m}\sum_{j=1}^{n_i}\mu^2)\\ \;\\ =\frac{1}{k}(\sum_{i=1}^{m}\sum_{j=1}^{n_i}x_{ij}^2-2k\mu^2+k\mu^2)\\ σ2=k1i=1∑mj=1∑ni(xij−μ)2=k1i=1∑mj=1∑ni(xij2−2xijμ−μ2)=k1(i=1∑mj=1∑nixij2−2i=1∑mj=1∑nixijμ+i=1∑mj=1∑niμ2)=k1(i=1∑mj=1∑nixij2−2kμ2+kμ2)
得 :
∑ i = 1 m ∑ j = 1 n i x i j 2 = k σ 2 + k μ 2 = k ( σ 2 + μ 2 ) \sum_{i=1}^{m}\sum_{j=1}^{n_i}x_{ij}^2=k\sigma^2+k\mu^2=k(\sigma^2+\mu^2) i=1∑mj=1∑nixij2=kσ2+kμ2=k(σ2+μ2)
同样的, 对于第 i i i 行样本方差, 有 :
∑ j = 1 n i x i j 2 = n i s i 2 + k i n ˉ i 2 = n i ( s i 2 + x ˉ i 2 ) \sum_{j=1}^{n_i}x_{ij}^2=n_is_i^2+k_i\bar n_i^2=n_i(s_i^2+\bar x_i^2) j=1∑nixij2=nisi2+kinˉi2=ni(si2+xˉi2)
得 :
∑ i = 1 m n i ( s i 2 + x ˉ i 2 ) = k ( σ 2 + μ 2 )    σ 2 = 1 k ∑ i = 1 m n i ( s i 2 + x ˉ i 2 ) − μ 2 \sum_{i=1}^{m}n_i(s_i^2+\bar x_i^2)=k(\sigma^2+\mu^2)\\ \;\\ \sigma^2=\frac{1}{k}\sum_{i=1}^{m}n_i(s_i^2+\bar x_i^2)-\mu^2 i=1∑mni(si2+xˉi2)=k(σ2+μ2)σ2=k1i=1∑mni(si2+xˉi2)−μ2
令 n = 2 n = 2 n=2 :
σ 2 = n 1 ( s 1 2 + x ˉ 1 2 ) + n 2 ( s 2 2 + x ˉ 2 2 ) n 1 + n 2 − μ 2 \sigma^2=\frac{n_1(s_1^2+\bar x_1^2)+n_2(s_2^2+\bar x_2^2)}{n_1+n_2}-\mu^2 σ2=n1+n2n1(s12+xˉ12)+n2(s22+xˉ22)−μ2
改成迭代式 :
σ i + 1 = k i ( σ i 2 + μ i 2 ) + n i + 1 ( s i + 1 2 + x ˉ i + 1 2 ) ( k i + n i + 1 ) − μ i + 1 2 \sigma_{i+1}=\frac{k_i(\sigma_i^2+\mu_i^2) +n_{i+1}(s_{i+1}^2+\bar x_{i+1}^2)}{(k_i+n_{i+1})}-\mu_{i+1}^2 σi+1=(ki+ni+1)ki(σi2+μi2)+ni+1(si+12+xˉi+12)−μi+12
式中 μ i \mu_i μi 表示前 i i i 行 (包括 i i i) 所有元素的均值, σ i \sigma_i σi 表示前 i i i 行 (包括 i i i) 所有元素的标准差, k i k_i ki 表示前 k i k_i ki 行 (包括 i i i) 所有元素的总数量.
使用迭代公式计算总体方差会明显的加大计算量, 并不是理想的算法. 仅适用于需要精确解, 同时对内存消耗非常敏感的环境中.
在概率论和数理统计学中, 期望 (expectation) 的定义是 :
E ( X ) = ∑ i = 1 m ∑ j = 1 n i p i j x i j E(X)=\sum_{i=1}^{m}\sum_{j=1}^{n_i}p_{ij}x_{ij} E(X)=i=1∑mj=1∑nipijxij
p i j p_{ij} pij 是 x i j x_{ij} xij 出现的概率. 在本例中, p i j = 1 / k p_{ij}=1/k pij=1/k :
E ( X ) = 1 k ∑ i = 1 m ∑ j = 1 n i x i j = μ E(X)=\frac{1}{k}\sum_{i=1}^{m}\sum_{j=1}^{n_i}x_{ij}=\mu E(X)=k1i=1∑mj=1∑nixij=μ
对 x ˉ ,    s 2 \bar x,\;s^2 xˉ,s2 使用期望公式 :
E ( x ˉ ) = 1 m ∑ i = 1 m x ˉ i = 1 m ∑ i = 1 m ∑ j = 1 n i x i j n i = μ    E ( s 2 ) = 1 m ∑ i = 1 m s i 2 E(\bar x)=\frac{1}{m}\sum_{i=1}^{m}\bar x_i=\frac{1}{m}\sum_{i=1}^{m}\frac{\sum_{j=1}^{n_i}x_{ij}}{n_i}=\mu\\ \;\\ E(s^2)=\frac{1}{m}\sum_{i=1}^{m}s_i^2 E(xˉ)=m1i=1∑mxˉi=m1i=1∑mni∑j=1nixij=μE(s2)=m1i=1∑msi2
因为 :
s i 2 = ∑ j = 1 n i ( x i j − x ˉ i ) 2 n i = ∑ j = 1 n i ( x i j − μ + μ − x ˉ i ) 2 n i    = ∑ j = 1 n i ( x i j − μ ) 2 + 2 ( x i j − μ ) ( μ − x ˉ i ) + ( μ − x ˉ i ) 2 n i    = 1 n i ∑ j = 1 n i ( x i j − μ ) 2 − 2 n i ∑ j = 1 n i ( x i j − μ ) ( x ˉ i − μ ) + 1 n i ∑ j = 1 n i ( μ − x ˉ i ) 2    = 1 n i ∑ j = 1 n i ( x i j − μ ) 2 − 2 ( x ˉ i − μ ) 2 + ( μ − x ˉ i ) 2    = 1 n i ∑ j = 1 n i ( x i j − μ ) 2 − ( x ˉ i − μ ) 2 s_i^2=\sum_{j=1}^{n_i}\frac{(x_{ij}-\bar x_i)^2}{n_i}=\sum_{j=1}^{n_i}\frac{(x_{ij}-\mu+\mu-\bar x_i)^2}{n_i}\\ \;\\ =\sum_{j=1}^{n_i}\frac{(x_{ij}-\mu)^2+2(x_{ij}-\mu)(\mu-\bar x_i)+(\mu-\bar x_i)^2}{n_i}\\ \;\\ =\frac{1}{n_i}\sum_{j=1}^{n_i}(x_{ij}-\mu)^2 -\frac{2}{n_i}\sum_{j=1}^{n_i}(x_{ij}-\mu)(\bar x_i-\mu) +\frac{1}{n_i}\sum_{j=1}^{n_i}(\mu-\bar x_i)^2\\ \;\\ =\frac{1}{n_i}\sum_{j=1}^{n_i}(x_{ij}-\mu)^2 -2(\bar x_i-\mu)^2+(\mu-\bar x_i)^2\\ \;\\ =\frac{1}{n_i}\sum_{j=1}^{n_i}(x_{ij}-\mu)^2 -(\bar x_i-\mu)^2\\ si2=j=1∑nini(xij−xˉi)2=j=1∑nini(xij−μ+μ−xˉi)2=j=1∑nini(xij−μ)2+2(xij−μ)(μ−xˉi)+(μ−xˉi)2=ni1j=1∑ni(xij−μ)2−ni2j=1∑ni(xij−μ)(xˉi−μ)+ni1j=1∑ni(μ−xˉi)2=ni1j=1∑ni(xij−μ)2−2(xˉi−μ)2+(μ−xˉi)2=ni1j=1∑ni(xij−μ)2−(xˉi−μ)2
所以 :
E ( s 2 ) = 1 m ∑ i = 1 m s i 2 = 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ( x i j − μ ) 2 − 1 m ∑ i = 1 m ( x ˉ i − μ ) 2 E(s^2)=\frac{1}{m}\sum_{i=1}^{m}s_i^2=\frac{1}{m}\sum_{i=1}^{m}\frac{1}{n_i}\sum_{j=1}^{n_i}(x_{ij}-\mu)^2-\frac{1}{m}\sum_{i=1}^{m}(\bar x_i-\mu)^2\\ E(s2)=m1i=1∑msi2=m1i=1∑mni1j=1∑ni(xij−μ)2−m1i=1∑m(xˉi−μ)2
当 n i = n ,    k = m n n_i = n,\;k=mn ni=n,k=mn 时 :
1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ( x i j − μ ) 2 = 1 m n ∑ i = 1 m ∑ j = 1 n i ( x i j − μ ) 2 = σ 2 \frac{1}{m}\sum_{i=1}^{m}\frac{1}{n_i}\sum_{j=1}^{n_i}(x_{ij}-\mu)^2=\frac{1}{mn}\sum_{i=1}^{m}\sum_{j=1}^{n_i}(x_{ij}-\mu)^2=\sigma^2 m1i=1∑mni1j=1∑ni(xij−μ)2=mn1i=1∑mj=1∑ni(xij−μ)2=σ2
E ( s 2 ) = σ 2 − 1 m ∑ i = 1 m ( x ˉ i − μ ) 2 E(s^2)=\sigma^2-\frac{1}{m}\sum_{i=1}^{m}(\bar x_i-\mu)^2 E(s2)=σ2−m1i=1∑m(xˉi−μ)2
根据统计学中的公式 :
V a r ( x ˉ ) = 1 m ∑ i = 1 m ( x ˉ i − μ ) 2 = σ n Var(\bar x)=\frac{1}{m}\sum_{i=1}^{m}(\bar x_i-\mu)^2=\frac{\sigma}{n} Var(xˉ)=m1i=1∑m(xˉi−μ)2=nσ
这个公式的证明请看后文. 汇总可得 :
μ ^ = E ( x ˉ )    σ ^ 2 = n n − 1 E ( s 2 ) \hat\mu=E(\bar x)\\ \;\\ \hat\sigma^2=\frac{n}{n-1}E(s^2) μ^=E(xˉ)σ^2=n−1nE(s2)
其中, n n n 为每一组样本中元素的个数, μ ^ ,    σ ^ \hat\mu,\;\hat\sigma μ^,σ^ 上面的尖号表示估计.
由此, 我们可以用样本均值的期望来估计总体均值, 使用样本方差的期望乘以修正系数 n / ( n − 1 ) n/(n-1) n/(n−1) 来估计总体方差.
在上例中, 我们得出了使用样本均值和方差估计整体均值和方差的办法. 但这里还有一个问题, 我们需要记录每一个批次的样本特征及其权重, 才能在需要的时候进行计算期望计算.
随着训练批次的增加, 对内存的消耗会变得无法忍受.
在目前流行的深度学习框架中, 比如 TensorFlow 和 PyTorch, 使用了滑动平均 (moving average) 的方法来替代期望求解.
滑动平均法的定义如下 :
x ^ i + 1 = ( 1 − η ) ∗ x ^ i + η ∗ x i + 1 \hat x_{i+1} =(1-\eta) * \hat x_i +\eta * x_{i+1} x^i+1=(1−η)∗x^i+η∗xi+1
x ^ i \hat x_i x^i 是历史总体特征的估计值, x i + 1 x_{i+1} xi+1 是当前样本特征 (若是样本方差则需要按的系数修正).
其中, η \eta η 为动量 (momentum), 是一个经验值为 0.01~ 0.1 的超参数.
注意, μ ^ \hat\mu μ^ 的初始值为 0, σ ^ 2 \hat\sigma^2 σ^2 的初始值为 1 .
下面的证明摘抄自《深入浅出统计学》(美) 格里菲斯(Griffiths, D.) 著
V a r ( x ˉ ) = V a r ( x 1 + x 2 + x 3 + ⋯ + x n n )    = V a r ( 1 n x 1 + 1 n x 2 + 1 n x 3 + ⋯ + 1 n x n ) Var(\bar x)=Var(\frac{x_1+x_2+x_3+\cdots+x_n}{n})\\ \;\\ =Var(\frac{1}{n}x_1+\frac{1}{n}x_2+\frac{1}{n}x_3+\cdots+\frac{1}{n}x_n) Var(xˉ)=Var(nx1+x2+x3+⋯+xn)=Var(n1x1+n1x2+n1x3+⋯+n1xn)
当 x i x_i xi 为相互独立的样本时 :
V a r ( 1 n x 1 + 1 n x 2 + 1 n x 3 + ⋯ + 1 n x n )    = V a r ( 1 n x 1 ) + V a r ( 1 n x 2 ) + ⋯ + V a r ( 1 n x n )    = 1 n 2 ( σ 2 + σ 2 + ⋯ + σ 2 )    = σ 2 n Var(\frac{1}{n}x_1+\frac{1}{n}x_2+\frac{1}{n}x_3+\cdots+\frac{1}{n}x_n)\\ \;\\ =Var(\frac{1}{n}x_1)+Var(\frac{1}{n}x_2)+\cdots+Var(\frac{1}{n}x_n)\\ \;\\ =\frac{1}{n^2}(\sigma^2+\sigma^2+\cdots+\sigma^2)\;\\ =\frac{\sigma^2}{n} Var(n1x1+n1x2+n1x3+⋯+n1xn)=Var(n1x1)+Var(n1x2)+⋯+Var(n1xn)=n21(σ2+σ2+⋯+σ2)=nσ2
需要强调的是, 在证明的过程中, 非常重要的条件是样本数量足够大且满足大数定律, 同时所抽取的样本之间相互独立, 缺一不可. 我们再次检查上文推导过程中的关键一步 :
V a r ( x ˉ ) = 1 m ∑ i = 1 m ( x ˉ i − μ ) 2 = σ n Var(\bar x)=\frac{1}{m}\sum_{i=1}^{m}(\bar x_i-\mu)^2=\frac{\sigma}{n} Var(xˉ)=m1i=1∑m(xˉi−μ)2=nσ
实际生产环境中, 我们往往不能得到满足大数定律的足够多的样本, 也不能保证样本独立, Batch Normalization 的计算方法在理论上会有潜在的误差. 每批数量越少, 误差越大.
使用滑动平均法, 可以在一定程度上缓和样本异常值所带来的误差, 而超参数动量 (momentum) 则给与了某种人为主动修正的可能性.
全文完.