我是tensorflow的新手,我试图了解它的行为;我正在尝试在会话范围之外定义所有操作,以优化计算时间。 在以下代码中:
import tensorflow as tf
import numpy as np
Z_tensor = tf.Variable(np.float32( np.zeros((1, 10)) ), name="Z_tensor")
Z_np = np.zeros((1,10))
update_Z = tf.assign(Z_tensor, Z_np)
Z_np[0][2:4] = 4
with tf.Session() as sess:
sess.run(Z_tensor.initializer)
print(Z_tensor.eval())
print(update_Z.eval(session=sess))
我得到了输出:
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
相反,我希望将其作为输出:
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 4. 4. 0. 0. 0. 0. 0. 0.]]
似乎Z_np
数组在赋值操作中没有更新,我也不明白为什么。
不是吗
update_Z = tf.assign(Z_tensor, Z_np)
与Z_np
链接吗?
答案 0 :(得分:3)
使用tf.assign时,它期望张量作为第二个参数。由于您提供了一个Numpy数组,因此它会自动将其提升为CONSTANT张量并将其放置在图中。因此,您对Numpy数组所做的任何更改都不会对TensorFlow图产生任何影响。为了获得所需的功能,您应该使用占位符:
Z_placeholder = tf.placeholder(tf.float32, Z_np.shape)
with tf.Session() as sess:
sess.run(Z_tensor.initializer)
print(Z_tensor.eval(feed_dict={Z_placeholder: Z_np}, session=sess))
Z_np[0][2:4] = 4
print(Z_tensor.eval(feed_dict={Z_placeholder: Z_np}, session=sess))