我在slim.batch_norm
中设置了占位符is_training params,如下所示:
is_traing_ph = tf.placeholder(tf.bool)
output = slim.batch_norm(
input,
activation_fn=activation_fn,
is_training=is_training_ph,
updates_collections=None,
scale=scale,
scope=scope)
像这样喂它:
sess.run(train_op, feed_dict={is_training_ph:False}
当我用is_training_ph提供True时,程序没问题,但是当我用is_traing_ph提供False时,程序会抛出OOM错误。
而且,当我不使用这样的占位符时:
output = slim.batch_norm(
input,
activation_fn=activation_fn,
is_training=True,
updates_collections=None,
scale=scale,
scope=scope)
这不是问题。
这是我的完整测试代码和日志跟踪: https://gist.github.com/xxxzhi/8fc8f840a8ec07fdbae7c2fc2c77b3da
有谁知道原因?这是slim.batch_norm
的错误吗?
GPU的内存是12G。 CUDA 8,tensorflow1.2,tensorflow1.3
提前致谢。