How is tensorflow.control_dependencies actually applied?

Asked: 2018-09-07 21:49:50

Tags: python-2.7 tensorflow deep-learning

        self.solver = 'adam'
        if self.solver == 'adam':
            optimizer = tf.train.AdamOptimizer(self.learning_rate_init)
        if self.solver == 'sgd_nestrov':
            optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate_init,
                                                   momentum=self.momentum, use_nesterov=True)
        gradients, variables = zip(*optimizer.compute_gradients(self.loss))
        clipped_gradients, self.global_norm = tf.clip_by_global_norm(gradients, self.max_grad_norm)
        update_ops_ = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        optimizer_op = optimizer.apply_gradients(zip(clipped_gradients, variables))
        control_ops = tf.group([self.ema_op] + update_ops_)
        with tf.control_dependencies([optimizer_op]):
            self.optimizer = control_ops

I call self.optimizer in the session.
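For context, this is roughly how I run the training step (a minimal sketch; sess is my tf.Session, and self.x, self.y, batch_x, batch_y are hypothetical names for my model's placeholders and input batch):

    # hypothetical training step; self.optimizer is the grouped training op
    _, global_norm = sess.run([self.optimizer, self.global_norm],
                              feed_dict={self.x: batch_x, self.y: batch_y})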

The code above does not update the gradients. However, if I change the control-dependencies part to the code below, it works fine, except that it misses the final exponential moving average update (self.ema_op), which is not what I want:

        self.solver = 'adam'
        if self.solver == 'adam':
            optimizer = tf.train.AdamOptimizer(self.learning_rate_init)
        if self.solver == 'sgd_nestrov':
            optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate_init,
                                                   momentum=self.momentum, use_nesterov=True)
        gradients, variables = zip(*optimizer.compute_gradients(self.loss))
        clipped_gradients, self.global_norm = tf.clip_by_global_norm(gradients, self.max_grad_norm)
        update_ops_ = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        optimizer_op = optimizer.apply_gradients(zip(clipped_gradients, variables))
        control_ops = tf.group([self.ema_op] + update_ops_)
#         with tf.control_dependencies(optimizer_op):
#             self.optimizer = control_ops
        with tf.control_dependencies([self.ema_op] + update_ops_):
            self.optimizer = optimizer.apply_gradients(zip(clipped_gradients, variables))

Please tell me what I am missing.

1 answer:

Answer 0 (score: 1)

You need to define the tensorflow ops under the with statement, not just assign a python variable. Writing self.optimizer = control_ops has no effect, because you are not creating any tensorflow ops there: control_ops was already created before the block, so the control dependency never attaches to it.
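To illustrate the difference, here is a small self-contained sketch (the variables a and b are made up for this example). The dependency only attaches to ops created inside the block, so rebinding a python name does nothing:

    import tensorflow as tf

    a = tf.Variable(0, name='a')
    b = tf.Variable(0, name='b')
    a_op = tf.assign(a, 1)

    # No effect: assign_b was created outside the block, so running it
    # does NOT force a_op to run first; the with block only rebinds a name.
    assign_b = tf.assign(b, 2)
    with tf.control_dependencies([a_op]):
        alias = assign_b

    # Works: this assign op is *created* inside the block, so running it
    # forces a_op to run first.
    with tf.control_dependencies([a_op]):
        assign_b_after_a = tf.assign(b, 2)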

Without fully understanding your problem, I think you need something like this:

with tf.control_dependencies([optimizer_op]):  # note: expects a list of ops
  control_ops = tf.group([self.ema_op] + update_ops_)

self.optimizer = control_ops

The with statement enters a block; within it, any new ops you create in tensorflow will depend on optimizer_op.
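Put together with your original code, the whole thing might look like the sketch below (same names as in your snippets, untested; I use the varargs form of tf.group, which works across TF 1.x versions):

    gradients, variables = zip(*optimizer.compute_gradients(self.loss))
    clipped_gradients, self.global_norm = tf.clip_by_global_norm(gradients, self.max_grad_norm)
    update_ops_ = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    optimizer_op = optimizer.apply_gradients(zip(clipped_gradients, variables))

    # The group op is *created* inside the dependency block, so running
    # self.optimizer runs optimizer_op as well as the EMA and UPDATE_OPS ops.
    with tf.control_dependencies([optimizer_op]):
        self.optimizer = tf.group(self.ema_op, *update_ops_)

One caveat: this only makes the grouped no-op wait for all of its inputs; it does not order self.ema_op after optimizer_op itself. If the EMA must read the freshly updated variables, self.ema_op would itself have to be created under a control dependency on optimizer_op.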