我正在使用Tensorflow的scatter_update在一个简单的测试中设置数组的值。我希望以下代码将数组中的所有值都设置为1.0。
pt = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
numOne = tf.constant(1)
xi = tf.Variable(0)
out = tf.Variable(pt)
tf.global_variables_initializer().run()
xi_ = xi + numOne
out_ = tf.scatter_update(out, [xi], [tf.cast(numOne, tf.float32)])
step = tf.group(
xi.assign(xi_),
out.assign(out_)
)
for i in range(15): step.run()
print(out.eval())
相反,我得到了不一致的结果,如:
[ 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.]
我缺少某种类型的锁定机制吗?
答案 0 :(得分:0)
tf.group
内的操作以随机顺序执行。如果您需要特定订单,可以使用tf.control_dependencies
或分成两个.run
来电,即
step1 = xi.assign(xi_)
step2 = out.assign(out_)
for i in range(15):
print(out.eval())
sess.run(step1)
sess.run(step2)
答案 1 :(得分:0)
另外,使用control_dependencies会是这样的:
xi_ = xi + numOne
with tf.control_dependencies[xi_]:
out_ = tf.scatter_update(out, [xi], [tf.cast(numOne, tf.float32)])
with tf.control_dependencies[out_]:
step = tf.group(
xi.assign(xi_),
out.assign(out_)
)
for i in range(15): step.run()
print(out.eval())