Tensorflow低级API,批量归一化问题

时间:2019-03-07 13:24:50

标签: tensorflow keras batch-normalization tf.keras

tf.layers.batch_normalization文档说它将在将来的版本中删除,应该由tf.keras.layers.BatchNormalization代替,但是我找不到使用tensorflow低级api替换功能的方法。

import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

输出:

[<tf.Operation 'batch_normalization/AssignMovingAvg' type=AssignSub>,
<tf.Operation 'batch_normalization/AssignMovingAvg_1' type=AssignSub>]

如果我们改为按照文档中的建议使用keras

bn = tf.keras.layers.BatchNormalization(axis=-1)(tf.constant([0.0]), training=True)

我们得到一个空的输出:

[]

由于UPDATE_OPS为空,因此在使用keras进行训练期间,该模型无法更新批次规范化moving_avg_mean和moving_avg_variance(导致更大的测试错误)。任何建议如何解决这个问题,不胜感激!

上面的示例摘自一本关于如何使用tf.layers.batch_normalization

的文章。

0 个答案:

没有答案