从earlier question开始,似乎tf.group
确实忽略了依赖关系。这是一个简单的独立示例(我在使用TensorFlow 1.1的Python 2.7上运行它):
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
xs = [tf.constant(x) for x in range(10)]
xs = [tf.Print(x, [x]) for x in xs]
dependency = None
dxs = []
for x in xs:
if dependency is None:
dependency = x
else:
dependency = control_flow_ops.with_dependencies([dependency], x)
dxs.append(dependency)
print_all_op = tf.group(*dxs)
with tf.Session() as session:
session.run(print_all_op)
预期产出:
2017-05-29 15:11:53.961221: I tensorflow/core/kernels/logging_ops.cc:79] [0]
2017-05-29 15:11:53.961236: I tensorflow/core/kernels/logging_ops.cc:79] [1]
2017-05-29 15:11:53.961255: I tensorflow/core/kernels/logging_ops.cc:79] [2]
2017-05-29 15:11:53.961237: I tensorflow/core/kernels/logging_ops.cc:79] [3]
2017-05-29 15:11:53.961262: I tensorflow/core/kernels/logging_ops.cc:79] [4]
2017-05-29 15:11:53.961263: I tensorflow/core/kernels/logging_ops.cc:79] [5]
2017-05-29 15:11:53.961268: I tensorflow/core/kernels/logging_ops.cc:79] [6]
2017-05-29 15:11:53.961272: I tensorflow/core/kernels/logging_ops.cc:79] [7]
2017-05-29 15:11:53.961274: I tensorflow/core/kernels/logging_ops.cc:79] [8]
2017-05-29 15:11:53.961221: I tensorflow/core/kernels/logging_ops.cc:79] [9]
实际输出(每次运行代码时都不同):
2017-05-29 15:16:26.279655: I tensorflow/core/kernels/logging_ops.cc:79] [0]
2017-05-29 15:16:26.279655: I tensorflow/core/kernels/logging_ops.cc:79] [9]
2017-05-29 15:16:26.279697: I tensorflow/core/kernels/logging_ops.cc:79] [3]
2017-05-29 15:16:26.279660: I tensorflow/core/kernels/logging_ops.cc:79] [1]
2017-05-29 15:16:26.279711: I tensorflow/core/kernels/logging_ops.cc:79] [8]
2017-05-29 15:16:26.279713: I tensorflow/core/kernels/logging_ops.cc:79] [4]
2017-05-29 15:16:26.279723: I tensorflow/core/kernels/logging_ops.cc:79] [5]
2017-05-29 15:16:26.279663: I tensorflow/core/kernels/logging_ops.cc:79] [2]
2017-05-29 15:16:26.279724: I tensorflow/core/kernels/logging_ops.cc:79] [7]
2017-05-29 15:16:26.279728: I tensorflow/core/kernels/logging_ops.cc:79] [6]
tf.group
文档中没有任何内容表明为什么会忽略依赖项。
是否有tf.group
的替代品确实考虑了依赖关系?
切换为使用tf.control_dependencies
代替tensorflow.python.ops.control_flow_ops.with_dependencies
无效:
import tensorflow as tf
xs = [tf.constant(x) for x in range(10)]
xs = [tf.Print(x, [x]) for x in xs]
dependency = None
dxs = []
for x in xs:
if dependency is None:
dependency = x
else:
with tf.control_dependencies([dependency]):
dependency = x
dxs.append(dependency)
print_all_op = tf.group(*dxs)
with tf.Session() as session:
session.run(print_all_op)
答案 0 :(得分:3)
正确使用tf.control_dependencies
可以解决此问题:
import tensorflow as tf
xs = [tf.constant(x) for x in range(10)]
dependency = None
dxs = []
for x in xs:
if dependency is None:
dependency = tf.Print(x, [x])
else:
with tf.control_dependencies([dependency]):
dependency = tf.Print(x, [x])
dxs.append(dependency)
print_all_op = tf.group(*dxs)
with tf.Session() as session:
session.run(print_all_op)
请注意,需要在Print
上下文管理器中创建tf.control_dependencies
操作。
我仍然不清楚control_flow_ops.with_dependencies
版本失败的原因。
答案 1 :(得分:2)
我认为问题在于初始代码会在control_flow_ops.with_dependencies
隐式创建的虚拟身份操作之间创建依赖关系,而不是实际的tf.Print
操作。 Tensorflow似乎只确保依赖项列表中的操作已经执行但其他前面操作的顺序不固定。在上面的示例中,依赖关系是在control_flow_ops.with_dependencies
:
dependency = control_flow_ops.with_dependencies([dependency], x)
应该相当于:
with tf.control_dependencies([dependency]):
dependency = tf.identity(x)
因此,这里的依赖关系是在tf.identity
操作而不是tf.Print
操作之间创建的。 tf.Print
操作可以按任何顺序运行,严格排序仅适用于tf.identity
操作。我不认为可以通过control_flow_ops.with_dependencies
实现所需的行为。相反,必须使用with tf.control_dependencies
代替(正如op已经建议的那样):
xs = [tf.constant(x) for x in range(10)]
dependency = None
dxs = []
for x in xs:
if dependency is None:
dependency = tf.Print(x, [x])
else:
with tf.control_dependencies([dependency]):
dependency = tf.Print(x, [x])
dxs.append(dependency)