Batch Normalization 作为深度学习中一个常用层,掌握其的使用非常重要,本博客将梳理下各种 Batch Normalization API 的一些坑。
如果你对 Batch Normalization 还不清楚,可以查看之前的博客 Inception v2/BN-Inception:Batch Normalization 论文笔记 来学习下 Batch Normalization。
在 TensorFlow 中,Batch Normlization 有以下几个实现(API):
tf.layers.BatchNormalization
tf.layers.batch_normalization
tf.keras.layers.BatchNormalization
tf.nn.batch_normalization
上述四个 API 按层次可分为两类:
tf.layers.BatchNormalization
、tf.layers.batch_normalization
、tf.keras.layers.BatchNormalization
tf.nn.batch_normalization
上述四个 API 按行为可以分为两类:
tf.layers.BatchNormalization
、tf.layers.batch_normalization
、tf.nn.batch_normalization
tf.keras.layers.BatchNormalization
Batch Normalization 中需要计算移动平均值,所以 BN 中有一些 update_ops
,在训练中需要通过 tf.control_dependencies()
来添加对 update_ops
的调用:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = tf.train.AdamOptimizer(lrn_rate).minimize(cost)
在使用 update_ops
前,需要将 BN 层的 update_ops 添加到 tf.GraphKeys.UPDATE_OPS
这个 collection 中。
tf.layers.BatchNormalization
和 tf.layers.batch_normalization
会自动将 update_ops 添加到 tf.GraphKeys.UPDATE_OPS
这个 collection 中(注:training
参数为 True 时,才会添加,False 时不添加)。
你以为躲过头号坑就完了,哈哈哈,还有二号坑!
tf.keras.layers.BatchNormalization
不会自动将 update_ops 添加到 tf.GraphKeys.UPDATE_OPS
这个 collection 中。所以在 TensorFlow 训练 session 中使用 tf.keras.layers.BatchNormalization
时,需要手动将 keras.BatchNormalization 层的 updates
添加到 tf.GraphKeys.UPDATE_OPS
这个 collection 中。
   \;
检测是否添加了 updates 的方法:
import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
Keras 是以面向对象的方式编写的,所以我们可以通过以下方式获取 BN 层的 updates,并将其添加到 tf.GraphKeys.UPDATE_OPS
这个 collection 中:
x = tf.placeholder("float",[None,32,32,3])
bn1 = tf.keras.layers.BatchNormalization()
y = bn1(x, training=True) # 调用后updates属性才会有内容。
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates)
那么如果我已经以下列形式编写完了代码,该怎么将 BN 的 updates 添加到 tf.GraphKeys.UPDATE_OPS
这个 collection 中呢?
y = tf.keras.layers.BatchNormalization()(x, training=True)
程序都这样编写了,难道我们就必须更改所有的 BN 层的编写方式吗?
答:不要改代码了,直接从 ops 中过滤出 updates_ops,然后添加到指定的 collectuon 中即可。代码如下:
ops = tf.get_default_graph().get_operations()
update_ops = [x for x in ops if ("AssignMovingAvg" in x.name and x.type=="AssignSubVariableOp")]
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_ops)
通过一些技术,我们已经将 BN 层的 updates 添加到了指定的 collection 中,然后按照 1. 的方式处理下即可。
经过上面的努力,以后终于可以在 TF-session 中顺畅地使用 tf.keras.layers
API 了,真高兴!(不要忘记点赞哦)
3. 附录
其他关于 TF - BN 使用的一些资料: