如何在Tensorflow中使用迭代器时正确设置is_training

时间:2018-02-08 13:23:00

标签: python tensorflow machine-learning deep-learning

我有一个is_training变量,我仍然必须在我的main.py文件中定义如下:

is_training = tf.placeholder(tf.bool, name='is_training')

然后我从另一个文件(x输入)调用我的inference方法:

test = net.inference(x, is_training)

最后,在我的会话中,我这样做:

sess.run(test, feed_dict={x: test_x, is_training: True})

但是,我想将is_training放在我的推理功能中。这有可能吗?

1 个答案:

答案 0 :(得分:0)

如果您仅在推理中需要is_training,我建议您tf.placeholder_with_default。这样,您只能在net.inference()方法中定义它,并且在会话中传递:

self.is_training = tf.placeholder_with_default(False, shape=(),name='is_training')

当您必须将其更改为True时,您可以执行以下操作:

sess.run(test, feed_dict={x: test_x, net.is_training: True})

请注意,tensorflow并不关心python变量或字段的范围。定义占位符后,它就会出现在图表中。