TensorFlow中Variable和get_variable之间的区别

时间:2016-05-08 09:57:29

标签: python tensorflow

据我所知,Variable是制作变量的默认操作,而get_variable主要用于分配重量。

一方面,有些人建议您在需要变量时使用get_variable而不是原始Variable操作。另一方面,我只看到在TensorFlow的官方文档和演示中使用get_variable

因此,我想知道如何正确使用这两种机制的一些经验法则。有没有“标准”原则?

4 个答案:

答案 0 :(得分:84)

我建议您始终使用tf.get_variable(...) - 如果您需要随时共享变量,则可以更轻松地重构代码,例如在multi-gpu设置中(参见multi-gpu CIFAR示例)。它没有任何缺点。

tf.Variable是较低级别的;在某些时候tf.get_variable()不存在,所以一些代码仍然使用低级方式。

答案 1 :(得分:64)

tf.Variable是一个类,有几种方法可以创建tf.Variable,包括tf.Variable .__ init__和tf.get_variable。

tf.Variable .__ init__:使用 initial_value 创建一个新变量。

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable:使用这些参数获取现有变量或创建一个新变量。您也可以使用初始化程序。

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

使用xavier_initializer等初始化程序非常有用:

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())

https://www.tensorflow.org/versions/r0.8/api_docs/python/state_ops.html#Variable的更多信息。

答案 2 :(得分:39)

我可以找到两者之间的两个主要区别:

  1. 首先,tf.Variable将始终创建一个新变量,tf.get_variable是否从图中获取带有这些参数的现有变量,如果它不存在,则会创建一个新变量

  2. tf.Variable要求指定初始值。

  3. 重要的是要澄清函数tf.get_variable在名称前加上当前变量作用域以执行重用检查。例如:

    with tf.variable_scope("one"):
        a = tf.get_variable("v", [1]) #a.name == "one/v:0"
    with tf.variable_scope("one"):
        b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
    with tf.variable_scope("one", reuse = True):
        c = tf.get_variable("v", [1]) #c.name == "one/v:0"
    
    with tf.variable_scope("two"):
        d = tf.get_variable("v", [1]) #d.name == "two/v:0"
        e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"
    
    assert(a is c)  #Assertion is true, they refer to the same object.
    assert(a is d)  #AssertionError: they are different objects
    assert(d is e)  #AssertionError: they are different objects
    

    最后一个断言错误很有趣:在同一范围内具有相同名称的两个变量应该是相同的变量。但是,如果您测试变量de的名称,您会发现Tensorflow更改了变量e的名称:

    d.name   #d.name == "two/v:0"
    e.name   #e.name == "two/v_1:0"
    

答案 3 :(得分:2)

另一个区别在于,一个位于('variable_store',)集合中,而另一个不在。

请参见源code

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

让我说明一下:

import tensorflow as tf
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32) 
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except:
            print('\t%d: %s' % (k, str(store)))
    print('')

输出:

  

collection ('__variable_store',): 0: {'word_embeddings_2': <tf.Variable 'word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}