pytorch(二)--batch normalization的理解

来源

《Batch Normalization: Accelerating Deep Network Training by  Reducing Internal Covariate Shift》这是2015年的文献,提出的就是Batch Normalization层的概念。

为什么需要它?

神经网络学习的过程本质就是数据分布的学习过程。从最开始的感知机到全连接实现或非运算等都是如此。在CNN网络中,卷积层输出的数据的分布依赖于卷积层的输入和kernel的参数,每一次后向传播都会更新参数的值,这样每次更新后,卷积层输出的数据的分布就在不断变化,以至于后一层需要拟合这种变化,所以训练速度很慢。我们把网络中间层在训练过程中,数据分布的改变称之为:“Internal  Covariate Shift”。Paper所提出的算法,就是要解决在训练过程中,中间层数据分布发生改变的情况,于是就有了Batch  Normalization,这个牛逼算法的诞生。

初识Batch Normalization

既然有上述问题,我们可以在网络的每一层输入的时候,又插入了一个归一化层,也就是先做一个归一化处理,然后再进入网络的下一层。不过文献归一化层,可不像我们想象的那么简单,它是一个可学习、有参数的网络层。

Batch Normalization层实现

这里可能会有这样一个问题,A层的输出数据,需要经过归一化处理,然后传递给B层,通过A层获得的特征数据的分布就会被搞坏,那么如何处理这样的问题呢?在Batch Normalization里面就引入了可学习的参数,公式如下:

                                                              

这样当

                                                            

时就会还原分布。完整公式如下:

                                                              pytorch(二)--batch normalization的理解_第1张图片

其中的epsilon是防止方差趋近于0或等于0,一般取值较小就行。那么在测试时,没有机会能够算出平均值和方差怎么办呢?我们可以使用所有训练数据在每一层输出后的均值和无偏估计作为这一层后面的batch normalization层的均值和方差。公式如下:

根据文献说,BN可以应用于一个神经网络的任何神经元上。文献主要是把BN变换,置于网络激活函数层的前面。在没有采用BN的时候,激活函数层是这样的:

z=g(Wu+b)

也就是我们希望一个激活函数,比如s型函数s(x)的自变量x是经过BN处理后的结果。因此前向传导的计算公式就应该是:

z=g(BN(Wu+b))

其实因为偏置参数b经过BN层后其实是没有用的,最后也会被均值归一化,当然BN层后面还有个β参数作为偏置项,所以b这个参数就可以不用了。因此最后把BN层+激活函数层就变成了:

z=g(BN(Wu))
 

在PyTorch中BatchNorm2d详解

有了上述理论基础,我们来看看在卷积神经网络中Batch Normalization的应用。

函数的参数有:

num_features: 来自期望输入的特征数(即channel数),该期望输入的大小为'batch_size x num_features x height x width'

eps: 为保证数值稳定性(分母不能趋近或取0),给分母加上的值。默认为1e-5。

momentum: 动态均值和动态方差所使用的动量。默认为0.1。(这个具体是什么还没有试验过,留个坑。)

affine: 一个布尔值,当设为true,给该层添加可学习的仿射变换参数

在CNN中batch normalization的实现是基于当前batch中所有样本的同一个channel进行的。看下面图就非常清晰了

pytorch(二)--batch normalization的理解_第2张图片

batch_size=4,channel=3,h=2,w=2。因为在batch的一个channel中所有元素的均值和方差一样,根据上述理论,beta和gama在一个channel中共用,即一个batch normalization层中的参数个数=2*channel数。

 

文章参考自:

https://blog.csdn.net/yimingsilence/article/details/80261383

https://www.cnblogs.com/yongjieShi/p/9332655.html

加入了自己的一些理解。

你可能感兴趣的:(PyTorch)