TensorFlow:多次初始化变量

时间:2016-11-08 07:35:37

标签: python tensorflow

我对以下代码段的运行方式感到有些困惑。

import tensorflow as tf

x = tf.Variable(0)
init_op = tf.initialize_all_variables()
modify_op = x.assign(5)

with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(x))
    x += 3
    print(sess.run(x))
    sess.run(init_op) # Trying to initialize x once again to 0
    print(sess.run(x)) # Gives out 3, which leaves me confused.
    print(sess.run(modify_op))
    print(sess.run(x)) # Gives out 8, even more confusing

这是输出:
    0
    3
    3
    5
    8

x += 3不属于默认图表吗?或其他事情正在发生?一些帮助将不胜感激,谢谢!

1 个答案:

答案 0 :(得分:3)

您的x变量正在被

更改
x += 3

但不是你想象的那样。 tensorflow库代码覆盖了+,因此您可以有效地交换内容x以获得新的TF张量(旧的张量仍然在图中,只是x现在指向一个新的张量)。写出来像这样:

x = tf.Variable(0) + 3

更清楚的是发生了什么。另外,插入一些打印语句。 。

x = tf.Variable(0)
print(x)
# <tensorflow.python.ops.variables.Variable object at 0x1018f5d68>

x += 3
print(x)
# Tensor("add:0", shape=(), dtype=int32)

如果x的内容对您很重要,那么如果您想稍后使用变量名跟踪/显示x,请避免重新分配到x。或者,如果您没有指向它的方便的Python变量,您可以随时命名张量并直接从图中获取。重要的是习惯于TF变量和Python变量之间的分离。

实际上看到你正在尝试分配和重新设置TF变量,需要使用TF赋值运算符:

import tensorflow as tf

x = tf.Variable( 0 )

with tf.Session() as session:
    session.run( tf.initialize_all_variables() )
    print( x.eval() )

    session.run( x.assign( x + 3 ) )
    print( x.eval() )

    session.run( tf.initialize_all_variables() )
    print( x.eval() )

输出:

0
3
0

正如你所料。