我想要一段代码,如果它不存在则在范围内创建一个变量,并且如果该变量已经存在则访问该变量。我需要它是相同的代码,因为它将被多次调用。
但是,Tensorflow需要我指定是否要创建或重用变量,如下所示:
with tf.variable_scope("foo"): #create the first time
v = tf.get_variable("v", [1])
with tf.variable_scope("foo", reuse=True): #reuse the second time
v = tf.get_variable("v", [1])
如何让它弄清楚是否自动创建或重复使用它?即,我希望上面两个代码块是相同并运行程序。
答案 0 :(得分:32)
在创建新变量并且未声明形状或在变量创建期间违反重用时,ValueError
中会引发get_variable()
。因此,你可以试试这个:
def get_scope_variable(scope_name, var, shape=None):
with tf.variable_scope(scope_name) as scope:
try:
v = tf.get_variable(var, shape)
except ValueError:
scope.reuse_variables()
v = tf.get_variable(var)
return v
v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v')
assert v1 == v2
请注意,以下内容也有效:
v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 == v2
更新。新API现在支持自动重复使用:
def get_scope_variable(scope, var, shape=None):
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
v = tf.get_variable(var, shape)
return v
答案 1 :(得分:13)
尽管使用"尝试......除了......"如果子句工作,我认为更优雅和可维护的方式是将变量初始化过程与" reuse"分开。过程
def initialize_variable(scope_name, var_name, shape):
with tf.variable_scope(scope_name) as scope:
v = tf.get_variable(var_name, shape)
scope.reuse_variable()
def get_scope_variable(scope_name, var_name):
with tf.variable_scope(scope_name, reuse=True):
v = tf.get_variable(var_name)
return v
由于我们通常只需要初始化变量,但多次重用/共享它,将两个进程分开使代码更清晰。也是这样,我们不需要经历"尝试"每次检查变量是否已经创建的子句。
答案 2 :(得分:13)
新的AUTO_REUSE选项可以解决问题。
从tf.variable_scope API docs:if reuse=tf.AUTO_REUSE
,我们创建变量(如果它们不存在),否则返回它们。
共享变量AUTO_REUSE的基本示例:
def foo():
with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
v = tf.get_variable("v", [1])
return v
v1 = foo() # Creates v.
v2 = foo() # Gets the same, existing v.
assert v1 == v2
答案 3 :(得分:1)
我们可以在tf.varaible_scope
上编写抽象,而不是在第一次调用时使用reuse=None
,并在后续调用中使用reuse=True
:
def variable_scope(name_or_scope, *args, **kwargs):
if isinstance(name_or_scope, str):
scope_name = tf.get_variable_scope().name + '/' + name_or_scope
elif isinstance(name_or_scope, tf.Variable):
scope_name = name_or_scope.name
if scope_name in variable_scope.scopes:
kwargs['reuse'] = True
else:
variable_scope.scopes.add(scope_name)
return tf.variable_scope(name_or_scope, *args, **kwargs)
variable_scope.scopes = set()
用法:
with variable_scope("foo"): #create the first time
v = tf.get_variable("v", [1])
with variable_scope("foo"): #reuse the second time
v = tf.get_variable("v", [1])