如何使用tf.Print打印部分张量?

时间:2017-10-29 13:39:06

标签: tensorflow

在TensorFlow中,当张量具有较大的尺寸时,仅为了调试目的而打印张量的某些部分是有用的,例如2维矩阵的对角线。我只知道如何打印整个张量如下:

sess = tf.InteractiveSession()
a = tf.constant(1.0, shape=[1000, 1000])
a = tf.Print(a, [a], "print entire a\n", summarize=1000000)
b = a + 1.
ret = sess.run(b)

上面的代码将打印整个'a'张量。但我不知道如何打印'a'的一部分。例如,如果我只想在不执行sess.run(a)的情况下打印[0,0],则以下代码将不起作用:

sess = tf.InteractiveSession()
a = tf.constant(1.0, shape=[1000, 1000])
a[0,0] = tf.Print(a[0,0], [a[0,0]], "print part of a\n", summarize=1000000)
b = a + 1.
ret = sess.run(b)    

1 个答案:

答案 0 :(得分:2)

来自文档:

Print(
    input_,
    data,
    message=None,
    first_n=None,
    summarize=None,
    name=None
)
     

打印张量列表。这是具有副作用的身份认证   在评估时打印data

a应该保持不变,在data参数中,您应该放置需要打印的张量。

import tensorflow as tf

sess = tf.InteractiveSession()
a = tf.constant(1.0, shape=[1000, 1000])
a = tf.Print(a, [a[0, 0]], "Print part of a\n", summarize=100000)
b = a + 1.
ret = sess.run(b)