TensorFlow 中 Batch Normalization API 的一些坑

版权声明:本文为博主原创文章,如需转载,请注明出处, 谢谢。
https://blog.csdn.net/u014061630/article/details/85104491

Batch Normalization 作为深度学习中一个常用层,掌握其的使用非常重要,本博客将梳理下各种 Batch Normalization API 的一些坑。

如果你对 Batch Normalization 还不清楚,可以查看之前的博客 Inception v2/BN-Inception:Batch Normalization 论文笔记 来学习下 Batch Normalization。

在 TensorFlow 中,Batch Normlization 有以下几个实现(API):

  • tf.layers.BatchNormalization
  • tf.layers.batch_normalization
  • tf.keras.layers.BatchNormalization
  • tf.nn.batch_normalization

上述四个 API 按层次可分为两类:

  • 高阶 APItf.layers.BatchNormalizationtf.layers.batch_normalizationtf.keras.layers.BatchNormalization
  • 低阶 APItf.nn.batch_normalization

上述四个 API 按行为可以分为两类:

  • TensorFlow APItf.layers.BatchNormalizationtf.layers.batch_normalizationtf.nn.batch_normalization
  • Keras APItf.keras.layers.BatchNormalization

1. 头号大坑 ----- 没有调用 update_ops

Batch Normalization 中需要计算移动平均值,所以 BN 中有一些 update_ops,在训练中需要通过 tf.control_dependencies() 来添加对 update_ops 的调用:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(lrn_rate).minimize(cost)

在使用 update_ops 前,需要将 BN 层的 update_ops 添加到 tf.GraphKeys.UPDATE_OPS 这个 collection 中。

tf.layers.BatchNormalizationtf.layers.batch_normalization 会自动将 update_ops 添加到 tf.GraphKeys.UPDATE_OPS 这个 collection 中(注:training 参数为 True 时,才会添加,False 时不添加)。

你以为躲过头号坑就完了,哈哈哈,还有二号坑!

2. 二号坑(已修复) ----- 在 TensorFlow 训练 session 中使用 tf.keras.layers.BatchNormalization

tf.keras.layers.BatchNormalization 不会自动将 update_ops 添加到 tf.GraphKeys.UPDATE_OPS 这个 collection 中。所以在 TensorFlow 训练 session 中使用 tf.keras.layers.BatchNormalization 时,需要手动将 keras.BatchNormalization 层的 updates 添加到 tf.GraphKeys.UPDATE_OPS 这个 collection 中。
   \;

检测是否添加了 updates 的方法:

import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

Keras 是以面向对象的方式编写的,所以我们可以通过以下方式获取 BN 层的 updates,并将其添加到 tf.GraphKeys.UPDATE_OPS 这个 collection 中:

x = tf.placeholder("float",[None,32,32,3])
bn1 = tf.keras.layers.BatchNormalization()
y = bn1(x, training=True) # 调用后updates属性才会有内容。
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates)

那么如果我已经以下列形式编写完了代码,该怎么将 BN 的 updates 添加到 tf.GraphKeys.UPDATE_OPS 这个 collection 中呢?

y = tf.keras.layers.BatchNormalization()(x, training=True)

程序都这样编写了,难道我们就必须更改所有的 BN 层的编写方式吗?

答:不要改代码了,直接从 ops 中过滤出 updates_ops,然后添加到指定的 collectuon 中即可。代码如下:

ops = tf.get_default_graph().get_operations()
update_ops = [x for x in ops if ("AssignMovingAvg" in x.name and x.type=="AssignSubVariableOp")]
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_ops)

通过一些技术,我们已经将 BN 层的 updates 添加到了指定的 collection 中,然后按照 1. 的方式处理下即可。

经过上面的努力,以后终于可以在 TF-session 中顺畅地使用 tf.keras.layers API 了,真高兴!(不要忘记点赞哦)

3. 附录

其他关于 TF - BN 使用的一些资料:

  • tensorflow batch_normalization的正确使用姿势

你可能感兴趣的:(TensorFlow教程)