我希望Tensorflow在f(...)
但是tf.control_dependencies
不能满足我的要求。
如何修复控件依赖项?
结果:
cache_ 0.0
x_ 2.0
AssertionError
测试:
import tensorflow as tf
import numpy as np
def f(a, cache):
assign_op = tf.assign(cache, a)
with tf.control_dependencies([assign_op]):
return a
def main():
dtype = np.float32
data = tf.range(5, dtype=dtype)
cache = tf.Variable(0, dtype=dtype)
x = f(data[2], cache)
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
x_ = sess.run(x)
cache_ = sess.run(cache)
print("cache_", cache_)
print("x_", x_)
assert np.allclose(cache_, x_)
main()
答案 0 :(得分:2)
问题在于return a
是Python代码。您没有在with
块中创建任何TensorFlow操作。您可以使用tf.identity
创建一个操作,以确保从a
读取assign_op
时将首先执行。这是更新的代码:
def f(a, cache):
assign_op = tf.assign(cache, a)
with tf.control_dependencies([assign_op]):
return tf.identity(a)