如是我闻: 在深度学习模型中,Batch Normalization(简称 BN)是一种常用的技术,能有效加速训练并提高模型的稳定性。BN 通过对 mini-batch 数据进行归一化,使每层的输入数据分布保持稳定,从而缓解梯度消失/爆炸问题,同时减少对权重初始化的敏感性。
在本篇文章中,我们将从 数学推导、代码实现 和 经验总结 三个方面,详细探讨 Batch Normalization 的工作原理,并分析为什么 BN 应该放在线性层(Linear)和激活函数(ReLU)之间,而不是之后。
假设某层的输入数据为 X X X,它的形状是 (batch_size, num_features),其中每一列代表某个特征在 mini-batch 内的分布,BN 的计算过程如下:
对 batch 内所有样本的某个特征维度 j j j 计算均值:
μ j = 1 m ∑ i = 1 m x i , j \mu_j = \frac{1}{m} \sum_{i=1}^{m} x_{i,j} μj=m1i=1∑mxi,j
其中:
σ j 2 = 1 m ∑ i = 1 m ( x i , j − μ j ) 2 \sigma_j^2 = \frac{1}{m} \sum_{i=1}^{m} (x_{i,j} - \mu_j)^2 σj2=m1i=1∑m(xi,j−μj)2
为了防止除零,加上一个很小的数 ϵ \epsilon ϵ:
σ j 2 ^ = 1 m ∑ i = 1 m ( x i , j − μ j ) 2 + ϵ \hat{\sigma_j^2} = \frac{1}{m} \sum_{i=1}^{m} (x_{i,j} - \mu_j)^2 + \epsilon σj2^=m1i=1∑m(xi,j−μj)2+ϵ
x ^ i , j = x i , j − μ j σ j 2 + ϵ \hat{x}_{i,j} = \frac{x_{i,j} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}} x^i,j=σj2+ϵxi,j−μj
这样, x ^ \hat{x} x^ 具有均值 0、方差 1 的特性。
BN 不是单纯地归一化,而是 允许网络学习合适的缩放和平移参数,以恢复模型的表达能力:
y i , j = γ j x ^ i , j + η j y_{i,j} = \gamma_j \hat{x}_{i,j} +\eta_j yi,j=γjx^i,j+ηj
其中:
y i , j = γ j x i , j − μ j σ j 2 + ϵ + η j y_{i,j} = \gamma_j \frac{x_{i,j} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}} +\eta_j yi,j=γjσj2+ϵxi,j−μj+ηj
这些步骤的作用是:
我们用 PyTorch 搭建一个简单的 MLP(多层感知机),并在隐藏层之间加入 Batch Normalization:
import torch
import torch.nn as nn
class SimpleMLP(nn.Module):
def __init__(self):
super(SimpleMLP, self).__init__()
self.fc1 = nn.Linear(2, 4) # 输入层 -> 隐藏层 1
self.bn1 = nn.BatchNorm1d(4) # BN 归一化
self.fc2 = nn.Linear(4, 3) # 隐藏层 1 -> 隐藏层 2
self.bn2 = nn.BatchNorm1d(3) # BN 归一化
self.fc3 = nn.Linear(3, 2) # 隐藏层 2 -> 输出层
def forward(self, x):
x = self.fc1(x)
x = self.bn1(x) # 先归一化
x = torch.relu(x) # 再激活
x = self.fc2(x)
x = self.bn2(x) # 先归一化
x = torch.relu(x) # 再激活
x = self.fc3(x) # 最后输出层不使用 BN
return x
# 创建模型
model = SimpleMLP()
# 生成一个 batch 的数据,假设 batch_size=5,每个样本有 2 个特征
x = torch.randn(5, 2)
# 运行前向传播
output = model(x)
print(output)
关键点:
nn.BatchNorm1d(隐藏层大小)
直接加在 Linear
层之后。BatchNorm1d
作用于 mini-batch 的每个特征维度,而不是单个样本。很多人会问:BN 是否应该放在激活函数之后?
答案是:不应该,BN 应该放在 Linear 层之后,ReLU 之前,主要有以下几个原因:
Linear -> BN
之后,数据已经被标准化成 均值 0,方差 1。放在 ReLU 之前(正确) | 放在 ReLU 之后(错误) | |
---|---|---|
数据分布 | 均值 0,方差 1 | 均值会偏移,方差变小 |
梯度流 | 让梯度更稳定 | 负值部分梯度消失 |
收敛速度 | 收敛更快 | 训练可能不稳定 |
因此,在 MLP 或 CNN 里,Batch Normalization 一般放在线性层(Linear/Conv)之后,ReLU 之前。
Linear
层后,ReLU
之前。以上