澄清tf.Session()范围TensorFlow

时间:2018-10-16 09:42:28

标签: python tensorflow

我有一个python程序,其中已经定义了一个网络,并且像往常一样,我正在我拥有的函数中对其进行训练

with tf.Session() as sess:
...
for epoch in xrange(num_epochs):
...
for n in xrange(num_batches):
       _, c = sess.run([optimizer, loss], feed_dict={....

在损失函数中,我必须做很多工作才能得到损失,尤其是我必须在张量内取最大并用它来做事情。这是一个例子

values = tf.constant([0, 1, 2, 0, 2], dtype=tf.float32)
max_values = tf.reduce_max(values) # Tensor...
...
max_values行中的

,如果我使用调试,它说它是张量而不是值,因此,如果我以这种方式更改代码,将传递给函数的代码传递给上一段代码中的会话

values = tf.constant([0, 1, 2, 0, 2], dtype=tf.float32)
max_values = sess.run(tf.reduce_max(values)) # 2.0
...

有效。但是这个损失函数已经在会话范围内了,所以我的问题是为什么结果是张量而不是数字?有没有一种方法可以在不将会话传递给损失函数的情况下获取值?

2 个答案:

答案 0 :(得分:2)

根据documentation

  

TensorFlow使用tf.Session类来表示之间的连接   客户端程序--通常是Python程序,尽管类似   界面支持其他语言-和C ++运行时。

这意味着当您执行values = tf.constant([0, 1, 2, 0, 2], dtype=tf.float32)时,您只是在tensorflow图中插入一个节点!由于Python是低级C ++运行时的高级API,因此实际上您需要一个会话来评估此低级运行时中的Python代码。

这就是为什么每次需要计算或评估Tensorflow变量/方法/常量/等时,都需要在带有tf.Session().run(yournode)的Session中运行它的原因

我希望它能帮助

答案 1 :(得分:1)

使用Tensor.eval()函数将Tensor转换为其值。 在以下示例中,您可以获取max_values张量的值。

def loss():
  values = tf.constant([0, 1, 2, 0, 2], dtype=tf.float32)
  max_values = tf.reduce_max(values)
  print (max_values.eval())

with tf.Session() as sess:
  loss()

如果您在Session范围之外调用loss(),则会出现错误。

还可以使用“急切”执行模式。 https://www.tensorflow.org/tutorials/eager/eager_basics