Tensorflow变量范围:如果存在变量则重用

时间:2016-07-23 18:52:32

标签: python tensorflow

我想要一段代码,如果它不存在则在范围内创建一个变量,并且如果该变量已经存在则访问该变量。我需要它是相同的代码,因为它将被多次调用。

但是,Tensorflow需要我指定是否要创建或重用变量,如下所示:

with tf.variable_scope("foo"): #create the first time
    v = tf.get_variable("v", [1])

with tf.variable_scope("foo", reuse=True): #reuse the second time
    v = tf.get_variable("v", [1])

如何让它弄清楚是否自动创建或重复使用它?即,我希望上面两个代码块是相同并运行程序。

4 个答案:

答案 0 :(得分:32)

在创建新变量并且未声明形状或在变量创建期间违反重用时,ValueError中会引发get_variable()。因此,你可以试试这个:

def get_scope_variable(scope_name, var, shape=None):
    with tf.variable_scope(scope_name) as scope:
        try:
            v = tf.get_variable(var, shape)
        except ValueError:
            scope.reuse_variables()
            v = tf.get_variable(var)
    return v

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v')
assert v1 == v2

请注意,以下内容也有效:

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 == v2

更新。新API现在支持自动重复使用:

def get_scope_variable(scope, var, shape=None):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        v = tf.get_variable(var, shape)
    return v

答案 1 :(得分:13)

尽管使用"尝试......除了......"如果子句工作,我认为更优雅和可维护的方式是将变量初始化过程与" reuse"分开。过程

def initialize_variable(scope_name, var_name, shape):
    with tf.variable_scope(scope_name) as scope:
        v = tf.get_variable(var_name, shape)
        scope.reuse_variable()

def get_scope_variable(scope_name, var_name):
    with tf.variable_scope(scope_name, reuse=True):
        v = tf.get_variable(var_name)
    return v

由于我们通常只需要初始化变量,但多次重用/共享它,将两个进程分开使代码更清晰。也是这样,我们不需要经历"尝试"每次检查变量是否已经创建的子句。

答案 2 :(得分:13)

新的AUTO_REUSE选项可以解决问题。

tf.variable_scope API docs:if reuse=tf.AUTO_REUSE,我们创建变量(如果它们不存在),否则返回它们。

共享变量AUTO_REUSE的基本示例:

def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2

答案 3 :(得分:1)

我们可以在tf.varaible_scope上编写抽象,而不是在第一次调用时使用reuse=None,并在后续调用中使用reuse=True

def variable_scope(name_or_scope, *args, **kwargs):
  if isinstance(name_or_scope, str):
    scope_name = tf.get_variable_scope().name + '/' + name_or_scope
  elif isinstance(name_or_scope, tf.Variable):
    scope_name = name_or_scope.name

  if scope_name in variable_scope.scopes:
    kwargs['reuse'] = True
  else:
    variable_scope.scopes.add(scope_name)

  return tf.variable_scope(name_or_scope, *args, **kwargs)
variable_scope.scopes = set()

用法:

with variable_scope("foo"): #create the first time
    v = tf.get_variable("v", [1])

with variable_scope("foo"): #reuse the second time
    v = tf.get_variable("v", [1])