我在范围内创建了 可训练变量。稍后,我输入了相同的范围,将范围设置为reuse_variables
,并使用get_variable
检索相同的范围。但是,我无法将变量的可训练属性设置为False
。我的get_variable
行就像:
weight_var = tf.get_variable('weights', trainable = False)
但变量'weights'
仍然在tf.trainable_variables
的输出中。
我可以使用trainable
将共享变量False
标记设置为get_variable
吗?
我想这样做的原因是我试图在我的模型中重用从VGG网络预训练的低级过滤器,我想像之前一样构建图形,检索权重变量,并将VGG过滤器值分配给权重变量,然后在以下训练步骤中将其保持固定。
答案 0 :(得分:28)
在查看文档和代码后,我 能够找到从TRAINABLE_VARIABLES
中删除变量的方法。
tf.get_variable('weights', trainable=True)
时,该变量会添加到TRAINABLE_VARIABLES
列表中。tf.get_variable('weights', trainable=False)
时,您获得相同的变量,但参数trainable=False
无效,因为变量已存在于TRAINABLE_VARIABLES
列表中(并且存在<强>没办法从那里删除 调用优化器的minimize
方法(参见doc.)时,可以将var_list=[...]
作为参数传递给您想要优化的变量。
例如,如果要冻结除最后两个之外的所有VGG图层,可以在var_list
中传递最后两个图层的权重。
您可以使用tf.train.Saver()
保存变量并稍后恢复(请参阅this tutorial)。
saver.save(sess, "/path/to/dir/model.ckpt")
将其保存在检查点文件中。saver.restore(sess, "/path/to/dir/model.ckpt")
存储的变量。或者,您可以决定仅在检查点文件中保存一些变量。有关详细信息,请参阅doc。
答案 1 :(得分:9)
当您只想训练或优化预训练网络的某些层时,您需要了解这一点。
TensorFlow的minimize
方法采用可选参数var_list
,这是一个通过反向传播调整的变量列表。
如果您未指定var_list
,则优化程序可以调整图表中的任何TF变量。当您在var_list
中指定一些变量时,TF会保持所有其他变量不变。
以下是jonbruner及其合作者使用的脚本示例。
tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)
这会找到之前定义的所有变量,它们具有&#34; g _&#34;在变量名中,将它们放入列表中,并在它们上运行ADAM优化器。
您可以在Quora
上找到相关答案答案 2 :(得分:5)
为了从可训练变量列表中删除变量,您可以首先通过以下方式访问集合:
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
其中trainable_collection
包含对可训练变量集合的引用。如果从此列表中弹出元素,例如进行trainable_collection.pop(0)
,则将从可训练变量中删除相应的变量,因此将不训练该变量。
尽管此方法适用于pop
,但我仍在努力寻找正确使用带有正确参数的remove
的方法,因此我们不依赖变量的索引。
编辑:假设您在图形中拥有变量的名称(可以通过检查图形protobuf或通过使用Tensorboard来轻松获得),就可以使用它来循环通过可训练变量列表,然后从可训练集合中删除变量。
例如:说我想训练名称为"batch_normalization/gamma:0"
和"batch_normalization/beta:0"
NOT 的变量,但是它们已经被添加到TRAINABLE_VARIABLES
集合中。我能做的是:
`
#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
#uses the attribute 'name' of the variable
if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
variables_to_remove.append(vari)
for rem in variables_to_remove:
trainable_collection.remove(rem)
` 这样将成功从集合中删除这两个变量,并且不再对它们进行训练。
答案 3 :(得分:0)
您可以使用tf.get_collection_ref而不是tf.get_collection来获取集合的引用