我对以下代码段的运行方式感到有些困惑。
import tensorflow as tf
x = tf.Variable(0)
init_op = tf.initialize_all_variables()
modify_op = x.assign(5)
with tf.Session() as sess:
sess.run(init_op)
print(sess.run(x))
x += 3
print(sess.run(x))
sess.run(init_op) # Trying to initialize x once again to 0
print(sess.run(x)) # Gives out 3, which leaves me confused.
print(sess.run(modify_op))
print(sess.run(x)) # Gives out 8, even more confusing
这是输出:
0
3
3
5
8
行x += 3
不属于默认图表吗?或其他事情正在发生?一些帮助将不胜感激,谢谢!
答案 0 :(得分:3)
您的x
变量正在被
x += 3
但不是你想象的那样。 tensorflow库代码覆盖了+
,因此您可以有效地交换内容x
以获得新的TF张量(旧的张量仍然在图中,只是x现在指向一个新的张量)。写出来像这样:
x = tf.Variable(0) + 3
更清楚的是发生了什么。另外,插入一些打印语句。 。
x = tf.Variable(0)
print(x)
# <tensorflow.python.ops.variables.Variable object at 0x1018f5d68>
x += 3
print(x)
# Tensor("add:0", shape=(), dtype=int32)
如果x
的内容对您很重要,那么如果您想稍后使用变量名跟踪/显示x
,请避免重新分配到x
。或者,如果您没有指向它的方便的Python变量,您可以随时命名张量并直接从图中获取。重要的是习惯于TF变量和Python变量之间的分离。
实际上看到你正在尝试分配和重新设置TF变量,需要使用TF赋值运算符:
import tensorflow as tf
x = tf.Variable( 0 )
with tf.Session() as session:
session.run( tf.initialize_all_variables() )
print( x.eval() )
session.run( x.assign( x + 3 ) )
print( x.eval() )
session.run( tf.initialize_all_variables() )
print( x.eval() )
输出:
0
3
0
正如你所料。