TensorFlow:按名称变量

时间:2016-02-28 04:57:44

标签: tensorflow

使用TensorFlow Python API时,我创建了一个变量(未在构造函数中指定其name),其name属性的值为"Variable_23:0"。当我尝试使用tf.get_variable("Variable23")选择此变量时,会创建一个名为"Variable_23_1:0"的新变量。如何正确选择"Variable_23"而不是创建新的?{1}}?

我想要做的是按名称选择变量,然后重新初始化它以便我可以微调权重。

4 个答案:

答案 0 :(得分:34)

get_variable()函数创建一个新变量或返回get_variable()之前创建的变量。它不会返回使用tf.Variable()创建的变量。这是一个简单的例子:

>>> with tf.variable_scope("foo"):
...   bar1 = tf.get_variable("bar", (2,3)) # create
... 
>>> with tf.variable_scope("foo", reuse=True):
...   bar2 = tf.get_variable("bar")  # reuse
... 

>>> with tf.variable_scope("", reuse=True): # root variable scope
...   bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
... 
>>> (bar1 is bar2) and (bar2 is bar3)
True

如果您没有使用tf.get_variable()创建变量,则有几个选项。首先,您可以使用tf.global_variables()(正如@mrry建议的那样):

>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
>>> bar1 is bar2
True

或者你可以这样使用tf.get_collection()

>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
>>> bar1 is bar2
True

修改

您还可以使用get_tensor_by_name()

>>> bar1 = tf.Variable(0.0, name="bar")
>>> graph = tf.get_default_graph()
>>> bar2 = graph.get_tensor_by_name("bar:0")
>>> bar1 is bar2
False, bar2 is a Tensor througn convert_to_tensor on bar1. but bar1 equal 
bar2 in value.

回想一下张量是一个操作的输出。它与操作具有相同的名称,加上:0。如果操作具有多个输出,则它们与操作加:0:1:2等名称相同。

答案 1 :(得分:33)

按名称获取变量的最简单方法是在tf.global_variables()集合中搜索它:

var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]

这适用于现有变量的临时重用。当您想要在模型的多个部分之间共享变量时,更加结构化的方法将在Sharing Variables tutorial中介绍。

答案 2 :(得分:0)

如果要从模型中获取任何存储的变量,请使用tf.train.load_variable("model_folder_name","Variable name")

答案 3 :(得分:0)

根据@mrry的回答,我认为创建和使用以下函数会更好,因为还有局部变量和其他不在全局变量中的变量(它们在不同的集合中):

def get_var_by_name(query_name, var_list):
    """
    Get Variable by name

    e.g.
    local_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES)
    the_var = get_var_by_name(local_vars, 'accuracy/total:0')
    """
    target_var = None
    for var in var_list:
        if var.name==query_name:
            target_var = var
            break
    return target_var