如何在tensorflow中打印局部张量?

时间:2018-01-14 23:38:06

标签: python debugging variables tensorflow local

我想在我的程序中打印张量,以便在评估后查看其内部值。然而,问题是张量是在函数内声明的。为了更好地理解我的问题,这里有一些示例代码可以更好地解释我想要做的事情:

a = tf.Variable([[2,3,4], [5,6,7]])
b = tf.Variable([[1,2,2], [3,3,3]])

def divide(a,b):
    with tf.variable_scope('tfdiv', reuse=True):
        c = tf.divide(a,b, name='c')
    # Cannot print(c) here, as this will only yield tf info on c
    return c

d = divide(a,b)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(d)
    sess.run(tf.get_variable('tfdiv/c:0').eval(session=sess))

以前,我已经能够进行打印(c.eval(session = sess)),但由于c是函数内部的局部变量,所以不起作用。从上面的代码中可以看出,我试图使用tensorflow的变量范围来访问变量然后对其进行评估。不幸的是,这会导致错误消息:

ValueError: Shape of a new variable (tfdiv/c:0) must be fully defined, but 
instead was <unknown>.

我尝试使用reuse = True标志,但我仍然得到相同的错误。有关如何解决这个问题的任何想法?如果有一个print(c)等价物可以放入除法函数,那就好了,如上面的代码所示。

2 个答案:

答案 0 :(得分:1)

这将实现您想要的目标:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(d))

或者,您可以将最后一行替换为:

print(sess.run(tf.get_default_graph().get_tensor_by_name('tfdiv/c:0')))

答案 1 :(得分:1)

了解Python端代码和TensorFlow端代码之间的区别非常重要。在python中,您只设置图形:d = divide(a, b)创建类似于: enter image description here

您设置了一个节点(正方形),用于划分节点 ab中的数据。它不会马上分开它们!请注意,黑色表示 python 变量名称,灰色表示 TensorFlow 节点名称 1 ab也有一些默认名称,如果您没有指定它们。灰色&#34; c&#34;您使用name='c'指定的。局部变量c和全局变量d(Python)都引用相同的操作(节点)。

这就是为什么如果你说print(d)只打印有关该节点的信息。设置图表后,执行sess.run(d)会在TensorFlow端{/ 1>上运行d 中节点所需的所有节点。然后它检索结果并使其在python端上可用作numpy数组。

您可以使用tf.Print(input, data)在TF端打印张量 。请注意,这是一个操作(图中的一个节点),它对input张量没有任何作用,只是将其传递过来,同时还在data中打印所有内容。

在您的情况下,您可以在张量流方面使用Print,如下所示:

def divide(a,b):
    with tf.variable_scope('tfdiv', reuse=True):
        c = tf.divide(a,b, name='c')
        cp = tf.Print(c, [c], message='Value of c: ', name='P')

    return cp

这有效地在图中添加了另一个节点(在TF侧名为P):

Graph of tensors illustrating the above

现在每次评估时都会打印操作 c的值。请注意,每次评估其中一个依赖项时也会打印它,例如,如果您稍后执行e = d + 1,则在评估e时,它需要d,这表示打印节点(从函数divide返回)。

最后请注意,如果您在Jupyter笔记本中执行此操作,则打印件将显示在笔记本服务器的终端中。现在这个细节并不重要:)。

1 默认添加:0,以便您可以使用name_of_op:0检索任何张量。操作名称(tfdiv/c)和张量名称(tfdiv/c:0)之间的区别。