Batch Normalization(批标准化):它的功能是使得输入的X数据符合同一分布,从而使得训练更加简单、快速。一般来讲,Batch Normalization会放在卷积层后面,即卷积 + BN层 + 激活函数。
图解为什么需要batch Normalization
神经网络输入层:
隐藏层:
Bn层在进行前三步后,会引入γ和β变量,对输入进来的数据进行缩放和平移。
γ和β变量是网络参数,是可学习的。
引入γ和β变量进行缩放平移可以使得神经网络有自适应的能力,在标准化效果好时,尽量不抵消标准化的作用,而在标准化效果不好时,尽量去抵消一部分标准化的效果,相当于让神经网络学会要不要标准化,如何折中选择。
def batch_norm(is_training, x, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9):
if not is_training:
x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
else:
mean = x.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
var = ((x - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
x_hat = (x - mean) / torch.sqrt(var + eps)
moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
moving_var = momentum * moving_var + (1.0 - momentum) * var
Y = gamma * x_hat + beta
return Y, moving_mean, moving_var
class BatchNorm2d(nn.Module):
def __init__(self, num_features):
super(BatchNorm2d, self).__init__()
shape = (1, num_features, 1, 1)
self.gamma = nn.Parameter(torch.ones(shape))
self.beta = nn.Parameter(torch.zeros(shape))
self.register_buffer('moving_mean', torch.zeros(shape))
self.register_buffer('moving_var', torch.ones(shape))
def forward(self, x):
if self.moving_mean.device != x.device:
self.moving_mean = self.moving_mean.to(x.device)
self.moving_var = self.moving_var.to(x.device)
y, self.moving_mean, self.moving_var = batch_norm(self.training,
x, self.gamma, self.beta, self.moving_mean,
self.moving_var, eps=1e-5, momentum=0.9)
return y
https://blog.csdn.net/weixin_44791964/article/details/114998793?spm=1001.2014.3001.5501