是否有可能使可训练的变量无法训练?

时间:2016-05-19 14:17:20

标签: tensorflow pre-trained-model

我在范围内创建了 可训练变量。稍后,我输入了相同的范围,将范围设置为reuse_variables,并使用get_variable检索相同的范围。但是,我无法将变量的可训练属性设置为False。我的get_variable行就像:

weight_var = tf.get_variable('weights', trainable = False)

但变量'weights'仍然在tf.trainable_variables的输出中。

我可以使用trainable将共享变量False标记设置为get_variable吗?

我想这样做的原因是我试图在我的模型中重用从VGG网络预训练的低级过滤器,我想像之前一样构建图形,检索权重变量,并将VGG过滤器值分配给权重变量,然后在以下训练步骤中将其保持固定。

4 个答案:

答案 0 :(得分:28)

在查看文档和代码后,我 能够找到从TRAINABLE_VARIABLES中删除变量的方法。

以下是发生的事情:

  • 第一次调用tf.get_variable('weights', trainable=True)时,该变量会添加到TRAINABLE_VARIABLES列表中。
  • 第二次调用tf.get_variable('weights', trainable=False)时,您获得相同的变量,但参数trainable=False无效,因为变量已存在于TRAINABLE_VARIABLES列表中(并且存在<强>没办法从那里删除

第一个解决方案

调用优化器的minimize方法(参见doc.)时,可以将var_list=[...]作为参数传递给您想要优化的变量。

例如,如果要冻结除最后两个之外的所有VGG图层,可以在var_list中传递最后两个图层的权重。

第二种解决方案

您可以使用tf.train.Saver()保存变量并稍后恢复(请参阅this tutorial)。

  • 首先,您使用所有可训练的变量训练整个VGG模型。您可以通过调用saver.save(sess, "/path/to/dir/model.ckpt")将其保存在检查点文件中。
  • 然后(在另一个文件中)您使用不可训练的变量训练第二个版本。您加载以前使用saver.restore(sess, "/path/to/dir/model.ckpt")存储的变量。

或者,您可以决定仅在检查点文件中保存一些变量。有关详细信息,请参阅doc

答案 1 :(得分:9)

当您只想训练或优化预训练网络的某些层时,您需要了解这一点。

TensorFlow的minimize方法采用可选参数var_list,这是一个通过反向传播调整的变量列表。

如果您未指定var_list,则优化程序可以调整图表中的任何TF变量。当您在var_list中指定一些变量时,TF会保持所有其他变量不变。

以下是jonbruner及其合作者使用的脚本示例。

tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

这会找到之前定义的所有变量,它们具有&#34; g _&#34;在变量名中,将它们放入列表中,并在它们上运行ADAM优化器。

您可以在Quora

上找到相关答案

答案 2 :(得分:5)

为了从可训练变量列表中删除变量,您可以首先通过以下方式访问集合: trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES) 其中trainable_collection包含对可训练变量集合的引用。如果从此列表中弹出元素,例如进行trainable_collection.pop(0),则将从可训练变量中删除相应的变量,因此将不训练该变量。

尽管此方法适用于pop,但我仍在努力寻找正确使用带有正确参数的remove的方法,因此我们不依赖变量的索引。

编辑:假设您在图形中拥有变量的名称(可以通过检查图形protobuf或通过使用Tensorboard来轻松获得),就可以使用它来循环通过可训练变量列表,然后从可训练集合中删除变量。 例如:说我想训练名称为"batch_normalization/gamma:0""batch_normalization/beta:0" NOT 的变量,但是它们已经被添加到TRAINABLE_VARIABLES集合中。我能做的是: `

#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
    #uses the attribute 'name' of the variable
    if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
        variables_to_remove.append(vari)
for rem in variables_to_remove:
    trainable_collection.remove(rem)

` 这样将成功从集合中删除这两个变量,并且不再对它们进行训练。

答案 3 :(得分:0)

您可以使用tf.get_collection_ref而不是tf.get_collection来获取集合的引用

相关问题