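TensorFlow: does tf.group ignore dependencies?

Asked: 2017-05-29 14:20:23

Tags: python tensorflow

Following up on an earlier question, it appears that tf.group does indeed ignore dependencies. Here is a simple standalone example (I ran it on Python 2.7 with TensorFlow 1.1):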

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

xs = [tf.constant(x) for x in range(10)]
xs = [tf.Print(x, [x]) for x in xs]
dependency = None
dxs = []

for x in xs:
    if dependency is None:
        dependency = x
    else:
        dependency = control_flow_ops.with_dependencies([dependency], x)

    dxs.append(dependency)

print_all_op = tf.group(*dxs)

with tf.Session() as session:
    session.run(print_all_op)

Expected output:

2017-05-29 15:11:53.961221: I tensorflow/core/kernels/logging_ops.cc:79] [0]
2017-05-29 15:11:53.961236: I tensorflow/core/kernels/logging_ops.cc:79] [1]
2017-05-29 15:11:53.961255: I tensorflow/core/kernels/logging_ops.cc:79] [2]
2017-05-29 15:11:53.961237: I tensorflow/core/kernels/logging_ops.cc:79] [3]
2017-05-29 15:11:53.961262: I tensorflow/core/kernels/logging_ops.cc:79] [4]
2017-05-29 15:11:53.961263: I tensorflow/core/kernels/logging_ops.cc:79] [5]
2017-05-29 15:11:53.961268: I tensorflow/core/kernels/logging_ops.cc:79] [6]
2017-05-29 15:11:53.961272: I tensorflow/core/kernels/logging_ops.cc:79] [7]
2017-05-29 15:11:53.961274: I tensorflow/core/kernels/logging_ops.cc:79] [8]
2017-05-29 15:11:53.961221: I tensorflow/core/kernels/logging_ops.cc:79] [9]

Actual output (different every time the code is run):

2017-05-29 15:16:26.279655: I tensorflow/core/kernels/logging_ops.cc:79] [0]
2017-05-29 15:16:26.279655: I tensorflow/core/kernels/logging_ops.cc:79] [9]
2017-05-29 15:16:26.279697: I tensorflow/core/kernels/logging_ops.cc:79] [3]
2017-05-29 15:16:26.279660: I tensorflow/core/kernels/logging_ops.cc:79] [1]
2017-05-29 15:16:26.279711: I tensorflow/core/kernels/logging_ops.cc:79] [8]
2017-05-29 15:16:26.279713: I tensorflow/core/kernels/logging_ops.cc:79] [4]
2017-05-29 15:16:26.279723: I tensorflow/core/kernels/logging_ops.cc:79] [5]
2017-05-29 15:16:26.279663: I tensorflow/core/kernels/logging_ops.cc:79] [2]
2017-05-29 15:16:26.279724: I tensorflow/core/kernels/logging_ops.cc:79] [7]
2017-05-29 15:16:26.279728: I tensorflow/core/kernels/logging_ops.cc:79] [6]

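Nothing in the tf.group documentation indicates why the dependencies would be ignored.

Is there an alternative to tf.group that does respect dependencies?

Switching to tf.control_dependencies instead of tensorflow.python.ops.control_flow_ops.with_dependencies does not help: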

import tensorflow as tf

xs = [tf.constant(x) for x in range(10)]
xs = [tf.Print(x, [x]) for x in xs]
dependency = None
dxs = []

for x in xs:
    if dependency is None:
        dependency = x
    else:
        with tf.control_dependencies([dependency]):
            dependency = x

    dxs.append(dependency)

print_all_op = tf.group(*dxs)

with tf.Session() as session:
    session.run(print_all_op)

2 Answers:

Answer 0 (score: 3)

Using tf.control_dependencies correctly solves the problem:

import tensorflow as tf

xs = [tf.constant(x) for x in range(10)]
dependency = None
dxs = []

for x in xs:
    if dependency is None:
        dependency = tf.Print(x, [x])
    else:
        with tf.control_dependencies([dependency]):
            dependency = tf.Print(x, [x])

    dxs.append(dependency)

print_all_op = tf.group(*dxs)

with tf.Session() as session:
    session.run(print_all_op)

Note that the Print op has to be created inside the tf.control_dependencies context manager.

It is still unclear to me why the control_flow_ops.with_dependencies version fails.
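To make the note above concrete, here is a toy model in plain Python of why creation inside the block matters. It is only a sketch, not TensorFlow's implementation: ToyOp, make_op and this control_dependencies are made-up names, and the sketch assumes (as the behavior above suggests) that control inputs are attached only to ops created while the block is active.

```python
import contextlib


class ToyOp(object):
    """Toy stand-in for a graph op; NOT TensorFlow's real Operation class."""

    def __init__(self, name, control_inputs):
        self.name = name
        self.control_inputs = list(control_inputs)


_active_deps = []  # stack of control inputs currently in scope


@contextlib.contextmanager
def control_dependencies(ops):
    """Attach `ops` as control inputs to every op *created* in the block."""
    _active_deps.append(list(ops))
    try:
        yield
    finally:
        _active_deps.pop()


def make_op(name):
    """Create a new op, picking up whatever control inputs are in scope."""
    inherited = [op for group in _active_deps for op in group]
    return ToyOp(name, inherited)


first = make_op("print_0")
existing = make_op("print_1")  # created outside any block

with control_dependencies([first]):
    rebound = existing            # only rebinds a name: no op is created
    created = make_op("print_2")  # new op: inherits the control input

rebound_deps = [op.name for op in rebound.control_inputs]
created_deps = [op.name for op in created.control_inputs]
print(rebound_deps)  # []
print(created_deps)  # ['print_0']
```

In this model, `dependency = x` inside the block corresponds to the `rebound = existing` line: nothing new is created, so nothing picks up the dependency.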

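Answer 1 (score: 2)

I think the problem is that the initial code creates dependencies between the dummy identity ops implicitly created by control_flow_ops.with_dependencies, not between the actual tf.Print ops. TensorFlow seems to guarantee only that the ops in the dependency list have already been executed; the order of the other ops that precede them is not fixed.

In the example above, the dependencies are created on the dummy identity ops created by control_flow_ops.with_dependencies:
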
    dependency = control_flow_ops.with_dependencies([dependency], x)

which should be equivalent to:

    with tf.control_dependencies([dependency]):
        dependency = tf.identity(x)

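So the dependencies here are created between the tf.identity ops, not between the tf.Print ops. The tf.Print ops can run in any order; the strict ordering applies only to the tf.identity ops. I don't think the desired behavior can be achieved with control_flow_ops.with_dependencies. Instead, one has to use with tf.control_dependencies (as the OP already suggested):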

import tensorflow as tf

xs = [tf.constant(x) for x in range(10)]
dependency = None
dxs = []

for x in xs:
    if dependency is None:
        dependency = tf.Print(x, [x])
    else:
        with tf.control_dependencies([dependency]):
            dependency = tf.Print(x, [x])

    dxs.append(dependency)
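The reasoning in this answer can be checked on a toy dependency graph in plain Python. This is a rough sketch, not how TensorFlow's executor works: the node names and the helper functions are made up for illustration, constants are left out, and data inputs and control inputs are both treated simply as "must finish first" edges.

```python
def must_run_before(graph, earlier, later):
    """Return True if `earlier` is a transitive prerequisite of `later`."""
    stack = list(graph[later])
    seen = set()
    while stack:
        node = stack.pop()
        if node == earlier:
            return True
        if node not in seen:
            seen.add(node)
            stack.extend(graph[node])
    return False


def with_dependencies_graph(n):
    """Question's graph: Print ops exist first, identity wrappers are chained."""
    graph = {"print_%d" % i: [] for i in range(n)}
    previous = "print_0"
    for i in range(1, n):
        wrapper = "identity_%d" % i
        # identity_i reads print_i (data) and waits for the previous link (control).
        graph[wrapper] = ["print_%d" % i, previous]
        previous = wrapper
    return graph


def control_dependencies_graph(n):
    """Answer's graph: each Print op is created under the previous Print op."""
    graph = {"print_0": []}
    for i in range(1, n):
        graph["print_%d" % i] = ["print_%d" % (i - 1)]
    return graph


broken = with_dependencies_graph(4)
fixed = control_dependencies_graph(4)

wrappers_ordered = must_run_before(broken, "identity_1", "identity_2")
prints_ordered_broken = must_run_before(broken, "print_1", "print_2")
prints_ordered_fixed = must_run_before(fixed, "print_1", "print_2")
print(wrappers_ordered, prints_ordered_broken, prints_ordered_fixed)
# True False True
```

In the first graph only the identity wrappers are chained, so nothing forces print_1 to run before print_2. In the second graph the Print ops themselves form the chain.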