我有两个图,我想它们是独立训练的,这意味着我有两个不同的优化器,但同时其中一个使用另一个图的张量值。因此,我需要能够在训练其中一个图表时停止更新特定的张量。我已经为我的张量分配了两个不同的名称范围,并使用此代码来控制不同优化器的张量更新:
mentor_training_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "mentor")
train_op_mentor = mnist.training(loss_mentor, FLAGS.learning_rate, mentor_training_vars)
mentee_training_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "mentee")
train_op_mentee = mnist.training(loss_mentee, FLAGS.learning_rate, mentee_training_vars)
在mnist对象的训练方法中,如下所示使用vars变量:
def training(loss, learning_rate, var_list):
# Add a scalar summary for the snapshot loss.
tf.summary.scalar('loss', loss)
# Create the gradient descent optimizer with the given learning rate.
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# Create a variable to track the global step.
global_step = tf.Variable(0, name='global_step', trainable=False)
# Use the optimizer to apply the gradients that minimize the loss
# (and also increment the global step counter) as a single training step.
train_op = optimizer.minimize(loss, global_step=global_step, var_list=var_list)
return train_op
我正在使用优化程序类的 var_list 属性来控制优化程序正在更新的变量。
现在我很困惑我是否已经做了我应该做的事情,即使有任何方法要检查是否有任何优化器只会更新部分图表?
如果有人能帮我解决这个问题,我将不胜感激。
谢谢!
答案 0 :(得分:1)
我遇到了类似的问题并使用了与您相同的方法,即通过优化器的var_list
参数。然后我检查了用于训练的变量是否保持不变:
the_var_np = sess.run(tf.get_default_graph().get_tensor_by_name('the_var:0'))
assert np.equal(the_var_np, pretrained_weights['the_var']).all()
pretrained_weights
是由np.load('some_file.npz')
返回的词典,我用它将预先训练过的权重存储到磁盘上。
如果你需要它,以下是如何用给定值覆盖张量:
value = pretrained_weights['the_var']
variable = tf.get_default_graph().get_tensor_by_name('the_var:0')
sess.run(tf.assign(variable, value))