如何正确将tf.Print()节点附加到张量流图?

时间:2017-05-19 16:27:40

标签: python tensorflow

据我所知,我应该可以通过这样的方式在我的图表中添加一个打印操作符:

a = nn_ops.softmax(s)

a = tf.Print(a, [tf.shape(a)], message="This is shape a: ")

并且在执行图表时,应该打印a的形状。但是,这个语句对我没有产生输出(我正在运行seq2seq tensorflow教程,这个softmax属于注意函数,所以它肯定是执行的)。

获取输出,而不是我做这样的事情:

ph = tf.placeholder(tf.float32, [3,4,5,6])

ts = tf.shape(ph)

tp = tf.Print(ts, [ts], message="PRINT=")

sess = tf.Session()

sess.run(tp)

但是,在我的实例中,sess.run()在seq2seq_model.py中调用,如果我尝试在注意函数中执行sess.run(a),则tensorflow会抱怨:

You must feed a value for placeholder tensor 'encoder0' with dtype int32

但我在代码中此时无法访问输入Feed。我怎样才能解决这个问题?

1 个答案:

答案 0 :(得分:1)

如果您只是想知道张量形状,通常可以在不运行图形的情况下推断出它。那你就不需要tf.Print

例如,在第二个代码片段中,您可以使用:

ph = tf.placeholder(tf.float32, [3,4,5,6])
print(ph.get_shape())

如果您想在输入尺寸上看到依赖的形状(使用tf.shape),或者您希望在依赖上看到一个值依赖如果没有提供输入数据,则无法输入。

例如,如果您训练的模型中xy分别是您的样本和标签,则无法在不提供成本的情况下计算成本。

如果您有以下代码:

predictions = ...
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_output, y_train))
cost = tf.Print(cost, [tf.shape(cost)], message='cost:')

尝试在不提供占位符值的情况下对其进行评估将不起作用:

sess.run(cost)
# error, no placeholder provided

然而,这将按预期工作:

sess.run(cost, {x: x_train, y: y_train})

关于你的第一个代码片段。为了工作,需要执行tf.Print节点才能打印消息。我怀疑在你的情况下,在进一步的计算过程中不使用打印节点。

例如,以下代码不会产生输出:

import tensorflow as tf

a = tf.Variable([1., 2., 3])
b = tf.Variable([1., 2., 3])

c = tf.add(a, b)

# we create a print node, but it is never used
a = tf.Print(a, [tf.shape(a)], message='a.shape: ')

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(c))

但是,如果您反转行以便在计算期间使用打印节点,您将看到输出:

a = tf.Print(a, [tf.shape(a)], message='a.shape: ')
# now c depends on the tf.Print node
c = tf.add(a, b)