How to correctly train with tf.keras.layers.BatchNormalization: Is there still a tf.GraphKeys.UPDATE_OPS dependency?

Posted: 2019-04-17 01:37:22

Tags: python tensorflow deep-learning batch-normalization

My goal is to correctly train batch normalization layers in TensorFlow (version 1.13.1 for Python, in graph mode) using the recommended tf.keras.layers.BatchNormalization class (https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization).

An older recommended approach was to use tf.layers.batch_normalization. Its documentation (https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization) indicates that it is deprecated in favor of tf.keras.layers.BatchNormalization.

When using the older function, the documentation states that we must explicitly add a dependency on the moving mean and variance update operations, which would otherwise be dangling nodes that no training operation depends on:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # includes the batch norm updates
with tf.control_dependencies(update_ops):
    # optimizer is any tf.train optimizer, e.g. tf.train.AdamOptimizer()
    train_op = optimizer.minimize(loss)
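To illustrate why the dependency matters, here is a toy model of graph-mode execution in plain Python. It is not TensorFlow's API: the `Op` class, the `run` function, and the update values are hypothetical. The idea it mimics is that running a fetched op executes only that op and the ops it transitively depends on, so an update op that nothing depends on never runs unless it is attached through a control dependency.

```python
class Op:
    """A hypothetical graph node: an optional action plus the ops it depends on."""

    def __init__(self, name, action=None, inputs=(), control_inputs=()):
        self.name = name
        self.action = action
        # In this toy, data inputs and control dependencies behave the same:
        # both must run before this op runs.
        self.deps = list(inputs) + list(control_inputs)


def run(fetch):
    """Run `fetch` and every op it transitively depends on, each at most once.

    Ops that `fetch` does not depend on are skipped, mimicking how a
    graph-mode session prunes the graph to what the fetch needs.
    Returns the names of the executed ops in execution order.
    """
    executed, seen = [], set()

    def visit(op):
        if op.name in seen:
            return
        seen.add(op.name)
        for dep in op.deps:
            visit(dep)
        if op.action is not None:
            op.action()
        executed.append(op.name)

    visit(fetch)
    return executed


# Hypothetical moving-mean update with momentum 0.99 and a batch mean of 5.0.
state = {"moving_mean": 0.0}


def update_moving_mean(momentum=0.99, batch_mean=5.0):
    state["moving_mean"] = momentum * state["moving_mean"] + (1 - momentum) * batch_mean


loss = Op("loss")
update_op = Op("update_moving_mean", action=update_moving_mean)

# Without a control dependency the update op is a dangling node: it never runs.
train_op = Op("minimize", inputs=[loss])
print(run(train_op))                   # ['loss', 'minimize']
print(state["moving_mean"])            # 0.0

# With a control dependency (the role tf.control_dependencies plays above),
# fetching the train op also runs the update.
train_op = Op("minimize", inputs=[loss], control_inputs=[update_op])
print(run(train_op))                   # ['loss', 'update_moving_mean', 'minimize']
print(round(state["moving_mean"], 6))  # 0.05
```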

My question: Is this explicit dependency on UPDATE_OPS still needed when training batch norm layers in TF 1.13 with tf.keras.layers.BatchNormalization? I don't see it mentioned in the documentation. However, I would be much more comfortable if someone knew for sure (and even better, could point to official documentation or code) that these operation dependencies are taken care of implicitly.

2 answers:

Answer 0 (score: 0)

The answer is no, it is not required. This is mentioned in the current documentation at https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization.

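On the other hand, because tf.keras.layers.BatchNormalization seems problematic to me, I am currently using tf.layers.BatchNormalization with an explicit dependency on UPDATE_OPS (TF version 1.10). My model failed validation when I used tf.keras.layers.BatchNormalization. Perhaps this has been fixed in a recent update.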

Answer 1 (score: 0)

According to the official doc:

  "In particular, tf.control_dependencies(tf.GraphKeys.UPDATE_OPS) should not be used (consult the tf.keras.layers.batch_normalization documentation)."

Instead, you should collect the update ops from tf.keras.layers.BatchNormalization() yourself, as follows. See the discussion.

import tensorflow as tf  # TF 1.x, graph mode
...
batch_normalizer = tf.keras.layers.BatchNormalization()
# raw_tensor and is_training are defined in the elided code above
normalized_tensor = batch_normalizer(raw_tensor, training=is_training)
total_loss = ...  # Get loss tensor
optimizer = tf.train.AdamOptimizer()
minimization_op = optimizer.minimize(total_loss, global_step=tf.train.get_global_step())
# Get the "regular" update ops from the collection
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# Get the BatchNorm moving mean/variance update ops from the layer itself
update_ops.extend(batch_normalizer.updates)
# Group both sets of operations to form a single train_op
train_op = tf.group([minimization_op, update_ops])
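For intuition about what the ops in `batch_normalizer.updates` contribute, here is a plain-Python sketch of the moving-average updates they perform. It is not TensorFlow code: `BatchNormStats`, `train`, and the data are hypothetical. It assumes the Keras defaults of `momentum=0.99`, `epsilon=1e-3`, a moving mean initialized to 0 and a moving variance initialized to 1, and it ignores details such as exactly how TensorFlow computes the batch variance.

```python
import statistics


class BatchNormStats:
    """Hypothetical stand-in for a batch norm layer's moving statistics.

    gamma and beta are left at their defaults (1 and 0) and omitted.
    """

    def __init__(self, momentum=0.99, epsilon=1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        self.moving_mean = 0.0      # default initializer: zeros
        self.moving_variance = 1.0  # default initializer: ones

    def update(self, batch):
        """What the update ops do: exponential moving averages of batch stats."""
        mean = statistics.mean(batch)
        variance = statistics.pvariance(batch)  # biased variance (a simplification)
        m = self.momentum
        self.moving_mean = m * self.moving_mean + (1 - m) * mean
        self.moving_variance = m * self.moving_variance + (1 - m) * variance

    def normalize_for_inference(self, x):
        """Inference mode normalizes with the moving statistics, not the batch's."""
        return (x - self.moving_mean) / (self.moving_variance + self.epsilon) ** 0.5


def train(stats, batches, run_update_ops):
    for batch in batches:
        # ... the minimization op would run here ...
        if run_update_ops:  # i.e. the update ops were grouped into train_op
            stats.update(batch)
    return stats


batches = [[9.0, 10.0, 11.0]] * 2000  # hypothetical data: mean 10, variance 2/3

with_updates = train(BatchNormStats(), batches, run_update_ops=True)
without_updates = train(BatchNormStats(), batches, run_update_ops=False)

print(round(with_updates.moving_mean, 3))                       # 10.0
print(without_updates.moving_mean)                              # 0.0
print(round(with_updates.normalize_for_inference(10.0), 3))     # 0.0
print(round(without_updates.normalize_for_inference(10.0), 3))  # 9.995
```

The comparison shows why the grouping matters: forgetting the update ops raises no error, since training-mode outputs use batch statistics, but inference-mode outputs are then computed from the initial moving statistics.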