control_dependencies与预期不同

时间:2019-06-25 23:06:02

标签: python tensorflow

我希望Tensorflow在f(...)

中执行以下操作
  1. 获取数据[索引]
  2. 缓存值
  3. 返回数据[索引]

但是tf.control_dependencies不能满足我的要求。

如何修复控件依赖项?

结果:

cache_ 0.0
x_ 2.0
AssertionError

测试:

import tensorflow as tf
import numpy as np


def f(a, cache):
    assign_op = tf.assign(cache, a)
    with tf.control_dependencies([assign_op]):
        return a


def main():
    dtype = np.float32
    data = tf.range(5, dtype=dtype)
    cache = tf.Variable(0, dtype=dtype)
    x = f(data[2], cache)
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        x_ = sess.run(x)
        cache_ = sess.run(cache)
    print("cache_", cache_)
    print("x_", x_)
    assert np.allclose(cache_, x_)


main()

1 个答案:

答案 0 :(得分:2)

问题在于return a是Python代码。您没有在with块中创建任何TensorFlow操作。您可以使用tf.identity创建一个操作,以确保从a读取assign_op时将首先执行。这是更新的代码:

def f(a, cache):
    assign_op = tf.assign(cache, a)
    with tf.control_dependencies([assign_op]):
        return tf.identity(a)