我有以下代码段:
import tensorflow as tf
import numpy as np
# batch x time x events x dim
batch = 2
time = 3
events = 4
tensor = np.random.rand(batch, time, events)
tensor[0][0][2] = 0
tensor[0][0][3] = 0
tensor[0][1][3] = 0
tensor[0][2][1] = 0
tensor[0][2][2] = 0
tensor[0][2][3] = 0
tensor[1][0][3] = 0
def cum_sum(prev, cur):
non_zeros = tf.equal(cur, 0.)
tf.Print(non_zeros, [non_zeros], "message ")
tf.Print(cur, [cur])
return cur
elems = tf.constant([1,2,3],dtype=tf.int64)
#alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))
cum_sum_ = tf.scan(cum_sum, tensor)
s = tf.Session()
s.run(cum_sum_)
我在传递给tf.Print
的函数中有两个tf.scan
语句,但是运行累积总和时,没有得到任何打印语句。我在做错什么吗?
答案 0 :(得分:2)
tf.Print不能那样工作。打印节点必须位于图中才能突出。我强烈建议您查看this教程,以了解如何使用它。
如果您有任何问题,请随时提问。
答案 1 :(得分:0)
尽管,@ Ignacio Peletier的答案是完全有帮助的,它取决于外部站点。我发现没有人在这里提到这一点很奇怪。无论如何,为了遵守规则,我还在这里提供了答案(无需访问外部链接),并提供了一些额外的信息:
要让tf.Print
实际打印某些内容,它应该属于图形。为此,您只需重复使用tf.Print
中返回的Tensor,并将其传递给下一个操作。 将其传递到下一个操作对于实际显示消息至关重要。
因此,使用您的示例可以将其重写:
def cum_sum(prev, cur):
non_zeros = tf.equal(cur, 0.)
non_zeros = tf.Print(non_zeros, [non_zeros], "message ")
cur = tf.Print(cur, [cur])
return cur
,它将打印cur
,但不会打印non_zeros
,因为此节点是悬空的。另外,我不确定是否可以重写您的代码,以便显示non_zeros
,因为在您定义它们之后,它实际上并没有在您的代码中使用(因此,tensorflow在非急切模式下只会忽略它)。 / p>
通常对于(非悬挂)节点,我们在代码中的某个地方将其命名为result
:
result = tf.some_op_here()
# the following is actually displaying the contents of result
result = tf.Print(result, [result], 'some arbitrary message here to be displayed before the actual message')
tf.another_op_here_using_result(result)
将打印result
的(可能是串联的)内容。为了控制显示的信息量,您还可以使用参数summarize=x
,其中x
是显示的参数数量。