排除TensorFlow中由优化程序更新的特定张量

时间:2017-02-13 23:51:03

标签: graph tensorflow gradient-descent

我有两个图,我想它们是独立训练的,这意味着我有两个不同的优化器,但同时其中一个使用另一个图的张量值。因此,我需要能够在训练其中一个图表时停止更新特定的张量。我已经为我的张量分配了两个不同的名称范围,并使用此代码来控制不同优化器的张量更新:

mentor_training_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "mentor")
train_op_mentor = mnist.training(loss_mentor, FLAGS.learning_rate, mentor_training_vars)
mentee_training_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "mentee")
train_op_mentee = mnist.training(loss_mentee, FLAGS.learning_rate, mentee_training_vars)

在mnist对象的训练方法中,如下所示使用vars变量:

def training(loss, learning_rate, var_list):

  # Add a scalar summary for the snapshot loss.
  tf.summary.scalar('loss', loss)
  # Create the gradient descent optimizer with the given learning rate.
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  # Create a variable to track the global step.
  global_step = tf.Variable(0, name='global_step', trainable=False)
  # Use the optimizer to apply the gradients that minimize the loss
  # (and also increment the global step counter) as a single training step.
  train_op = optimizer.minimize(loss, global_step=global_step, var_list=var_list)
  return train_op

我正在使用优化程序类 var_list 属性来控制优化程序正在更新的变量。

现在我很困惑我是否已经做了我应该做的事情,即使有任何方法要检查是否有任何优化器只会更新部分图表?

如果有人能帮我解决这个问题,我将不胜感激。

谢谢!

1 个答案:

答案 0 :(得分:1)

我遇到了类似的问题并使用了与您相同的方法,即通过优化器的var_list参数。然后我检查了用于训练的变量是否保持不变:

the_var_np = sess.run(tf.get_default_graph().get_tensor_by_name('the_var:0'))
assert np.equal(the_var_np, pretrained_weights['the_var']).all()

pretrained_weights是由np.load('some_file.npz')返回的词典,我用它将预先训练过的权重存储到磁盘上。

如果你需要它,以下是如何用给定值覆盖张量:

value = pretrained_weights['the_var']
variable = tf.get_default_graph().get_tensor_by_name('the_var:0')
sess.run(tf.assign(variable, value))