如何获取变量的当前值?

时间:2015-11-12 19:11:52

标签: python tensorflow

假设我们有一个变量:

x = tf.Variable(...)

可以使用assign()方法在培训过程中更新此变量。

获取变量当前值的最佳方法是什么?

我知道我们可以使用它:

session.run(x)

但我担心这会触发一系列操作。

在Theano,你可以做到

y = theano.shared(...)
y_vals = y.get_value()

我在TensorFlow中寻找相同的东西。

4 个答案:

答案 0 :(得分:34)

获取变量值的唯一方法是在session中运行它。在FAQ it is written中:

  

Tensor对象是操作结果的符号句柄,   但实际上并没有保存操作输出的值。

所以TF相当于:

import tensorflow as tf

x = tf.Variable([1.0, 2.0])

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    v = sess.run(x)
    print(v)  # will show you your variable.

具有init = global_variables_initializer()的部分非常重要,应该完成以初始化变量。

另外,如果您在IPython工作,请查看InteractiveSession

答案 1 :(得分:20)

通常,session.run(x)将仅评估计算x所需的节点而不评估任何其他节点,因此如果要检查变量的值,它应该相对便宜。

了解更多背景信息https://stackoverflow.com/a/33610914/5543198

答案 2 :(得分:16)

tf.Print可以简化您的生活!

tf.Print将打印在您的代码被评估时代码中调用tf.Print行时打印的张量值。

例如:

import tensorflow as tf
x = tf.Variable([1.0, 2.0])
x = tf.Print(x,[x])
x = 2* x

tf.initialize_all_variables()

sess = tf.Session()
sess.run()
  

[1.0 2.0]

因为它在x行的时刻打印tf.Print的值。如果你做了

v = x.eval()
print(v)

你会得到:

  

[2.0 4.0]

因为它会给你x的最终值。

答案 3 :(得分:-1)

他们在tf.Variable()中取消了tensorflow 2.0.0

如果您想从tensor(ie "net")中提取值,可以使用它,

net.[tf.newaxis,:,:].numpy().