张量流中的tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

时间:2018-12-14 02:35:41

标签: python tensorflow deep-learning

张量流中tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS))的目的是什么?

更多内容:

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_fn, var_list=tf.trainable_variables())

2 个答案:

答案 0 :(得分:4)

方法tf.control_dependencies可确保在上下文管理器内部定义的操作之前运行用作上下文管理器输入的操作。

例如:

count = tf.get_variable("count", shape=(), initializer=tf.constant_initializer(1), trainable=False)
count_increment = tf.assign_add(count, 1)
c = tf.constant(2.)
with tf.control_dependencies([count_increment]):
    d = c + 3
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("eval count", count.eval())
    print("eval d", d.eval())
    print("eval count", count.eval())

此打印:

eval count 1
eval d 5.0 # Running d make count_increment operation being run
eval count 2 # count_increment operation has be run and now count hold 2.

因此,在您的情况下,每次您运行train_op操作时,它将首先运行tf.GraphKeys.UPDATE_OPS集合中定义的所有操作。

答案 1 :(得分:2)

例如,如果使用tf.layers.batch_normalization,则该层将创建一些操作,需要在每个训练步骤中运行(更新变量的移动平均值和方差)。

tf.GraphKeys.UPDATE_OPS是这些变量的集合,如果将其放在tf.control_dependencies块中,则会在运行训练操作之前执行这些操作。

https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization