TensorFlow: Unsetting 'reuse' in a variable scope

Date: 2016-08-25 19:18:53

Tags: python tensorflow

Is it possible to unset reuse in a variable scope? I tried the following commands:

In [1]: import tensorflow as tf

In [2]: tf.get_variable_scope().reuse
Out[2]: False

In [3]: tf.get_variable_scope().reuse_variables
Out[3]: <bound method VariableScope.reuse_variables of <tensorflow.python.ops.variable_scope.VariableScope object at 0x7fd9cc46c4d0>>

In [4]: tf.get_variable_scope().reuse_variables()

In [5]: tf.get_variable_scope().reuse
Out[5]: True

In [6]: tf.get_variable_scope().reuse_variables()

In [7]: tf.get_variable_scope().reuse
Out[7]: True

In [8]: tf.get_variable_scope().reuse_variables(False)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-ba19ed12625c> in <module>()
----> 1 tf.get_variable_scope().reuse_variables(False)

TypeError: reuse_variables() takes exactly 1 argument (2 given)

In [9]: tf.get_variable_scope().reuse = False
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-9-ddd0b37e4f0e> in <module>()
----> 1 tf.get_variable_scope().reuse = False

AttributeError: can't set attribute

In [10]: tf.get_variable_scope().reuse_variables = False

In [11]: tf.get_variable_scope().reuse_variables()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-11-d03dc0fb25b6> in <module>()
----> 1 tf.get_variable_scope().reuse_variables()

TypeError: 'bool' object is not callable

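As you can see, I can't unset reuse by calling reuse_variables() again, I can't set it with the = operator, and for some reason I can assign anything I like to reuse_variables in place of the function (is this a bug?).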
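
All three observations follow from ordinary Python semantics: `reuse` behaves like a getter-only property, and `reuse_variables` is a plain method, so assigning to that name merely shadows the method on this one instance rather than exposing a bug. The sketch below reproduces the session with a hypothetical `ToyScope` class; it assumes only that `VariableScope.reuse` is a property without a setter (which the `can't set attribute` error points to) and is not TensorFlow's source.

```python
class ToyScope:
    """Hypothetical stand-in for VariableScope (not TensorFlow's code)."""

    def __init__(self):
        self._reuse = False  # private storage behind the read-only property

    @property
    def reuse(self):
        # Getter only, no setter: reading works, assigning raises AttributeError.
        return self._reuse

    def reuse_variables(self):
        # One-way switch: there is no parameter for turning reuse back off.
        self._reuse = True


scope = ToyScope()
scope.reuse_variables()
scope.reuse_variables()  # calling it again changes nothing
print(scope.reuse)  # True

try:
    scope.reuse_variables(False)  # the method accepts no extra argument
except TypeError as err:
    print("TypeError:", err)

try:
    scope.reuse = False  # no setter on the property
except AttributeError as err:
    print("AttributeError:", err)

# Assigning to the method's name just stores an instance attribute that
# shadows the method; the reuse flag itself is untouched.
scope.reuse_variables = False
print(scope.reuse)  # still True
try:
    scope.reuse_variables()
except TypeError as err:
    print("TypeError:", err)  # 'bool' object is not callable
```

Because the shadowing attribute lives on the instance, any code that later fetches the same scope object and calls reuse_variables() will hit the same TypeError.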

1 Answer:

Answer 0 (score: 3)

Once you are inside a scope, I think it is by design that you can't change the reuse state; you have to reopen the scope and set the flag differently:

import tensorflow as tf

with tf.variable_scope("scope"):
    a = tf.get_variable("var_a", 1)
    print(a.name)

with tf.variable_scope("scope", reuse=True):
    b = tf.get_variable("var_a")  # returns the existing scope/var_a
    print(b.name)
    # c = tf.get_variable("var_b")  # won't work:
    # you can't reuse a variable that doesn't exist yet,
    # probably enforced so you don't create variables unintentionally

with tf.variable_scope("scope"):  # reuse=False
    # c = tf.get_variable("var_a")  # won't work:
    # a variable with the same name already exists,
    # which keeps you from overriding the previous variable
    c = tf.get_variable("var_b", 2)
    print(c.name)
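
The two rules in the comments above (reuse=True only hands back variables that already exist; reuse=False refuses a name that is already taken) can be illustrated with a small pure-Python model. `ToyVariableStore` and its methods are hypothetical and only mimic the behavior described in this answer, not TensorFlow's implementation; nested scopes, graphs and the `:0` suffix of real variable names are left out.

```python
from contextlib import contextmanager


class ToyVariableStore:
    """Hypothetical model of the reuse rules (not TensorFlow code)."""

    def __init__(self):
        self._vars = {}   # full variable name -> shape
        self._stack = []  # (scope_name, reuse) for each open scope

    @contextmanager
    def variable_scope(self, name, reuse=False):
        # The flag is fixed when the scope is opened and dropped on exit,
        # so the only way to change it is to (re)open a scope.
        self._stack.append((name, reuse))
        try:
            yield
        finally:
            self._stack.pop()

    def get_variable(self, name, shape=None):
        # Must be called inside a scope in this toy model.
        scope_name, reuse = self._stack[-1]
        full_name = scope_name + "/" + name
        if reuse:
            # reuse=True: the variable has to exist already.
            if full_name not in self._vars:
                raise ValueError(full_name + " does not exist")
            return full_name
        # reuse=False: never silently overwrite an existing variable.
        if full_name in self._vars:
            raise ValueError(full_name + " already exists")
        self._vars[full_name] = shape
        return full_name


store = ToyVariableStore()
with store.variable_scope("scope"):
    a = store.get_variable("var_a", 1)
with store.variable_scope("scope", reuse=True):
    b = store.get_variable("var_a")  # same variable as a
with store.variable_scope("scope"):
    c = store.get_variable("var_b", 2)  # a new variable
print(a, b, c)  # scope/var_a scope/var_a scope/var_b
```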

You can hack around it this way:

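import tensorflow as tf

with tf.variable_scope("scope") as scope:  # keep a handle to the reuse=False scope
    a = tf.get_variable("var_a", 1)
    print(a.name)

with tf.variable_scope("scope", reuse=True):
    b = tf.get_variable("var_a")
    print(b.name)
    # re-enter the captured scope object, which was created with reuse=False
    with tf.variable_scope(scope):
        c = tf.get_variable("var_b", 1)
        print(c.name)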

I think you can also do this:

print(tf.get_variable_scope().reuse) #False
tf.get_variable_scope().reuse_variables()
print(tf.get_variable_scope().reuse) #True

with tf.variable_scope(tf.get_variable_scope(), reuse=False):
    print(tf.get_variable_scope().reuse) #False