使用TensorFlow Python API时,我创建了一个变量(未在构造函数中指定其name
),其name
属性的值为"Variable_23:0"
。当我尝试使用tf.get_variable("Variable23")
选择此变量时,会创建一个名为"Variable_23_1:0"
的新变量。如何正确选择"Variable_23"
而不是创建新的?{1}}?
我想要做的是按名称选择变量,然后重新初始化它以便我可以微调权重。
答案 0 :(得分:34)
get_variable()
函数创建一个新变量或返回get_variable()
之前创建的变量。它不会返回使用tf.Variable()
创建的变量。这是一个简单的例子:
>>> with tf.variable_scope("foo"):
... bar1 = tf.get_variable("bar", (2,3)) # create
...
>>> with tf.variable_scope("foo", reuse=True):
... bar2 = tf.get_variable("bar") # reuse
...
>>> with tf.variable_scope("", reuse=True): # root variable scope
... bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
...
>>> (bar1 is bar2) and (bar2 is bar3)
True
如果您没有使用tf.get_variable()
创建变量,则有几个选项。首先,您可以使用tf.global_variables()
(正如@mrry建议的那样):
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
>>> bar1 is bar2
True
或者你可以这样使用tf.get_collection()
:
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
>>> bar1 is bar2
True
修改强>
您还可以使用get_tensor_by_name()
:
>>> bar1 = tf.Variable(0.0, name="bar")
>>> graph = tf.get_default_graph()
>>> bar2 = graph.get_tensor_by_name("bar:0")
>>> bar1 is bar2
False, bar2 is a Tensor througn convert_to_tensor on bar1. but bar1 equal
bar2 in value.
回想一下张量是一个操作的输出。它与操作具有相同的名称,加上:0
。如果操作具有多个输出,则它们与操作加:0
,:1
,:2
等名称相同。
答案 1 :(得分:33)
按名称获取变量的最简单方法是在tf.global_variables()
集合中搜索它:
var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]
这适用于现有变量的临时重用。当您想要在模型的多个部分之间共享变量时,更加结构化的方法将在Sharing Variables tutorial中介绍。
答案 2 :(得分:0)
如果要从模型中获取任何存储的变量,请使用tf.train.load_variable("model_folder_name","Variable name")
答案 3 :(得分:0)
根据@mrry的回答,我认为创建和使用以下函数会更好,因为还有局部变量和其他不在全局变量中的变量(它们在不同的集合中):>
def get_var_by_name(query_name, var_list):
"""
Get Variable by name
e.g.
local_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES)
the_var = get_var_by_name(local_vars, 'accuracy/total:0')
"""
target_var = None
for var in var_list:
if var.name==query_name:
target_var = var
break
return target_var