函数内的get_variable

时间:2017-08-23 09:46:40

标签: tensorflow

我无法理解函数中的变量重用行为。 代码是:

def fa(y):
    x = tf.get_variable("x", initializer=tf.constant([3, 4, 5])) #without reusing
    x = tf.scatter_add(x,[0,2],y) #fine updating x
    return x

def fb():
    x = tf.get_variable("x", initializer=tf.constant([3, 4, 5]))
    x = x.assign([1,2,3])
    return x

with tf.Graph().as_default(),tf.device('/cpu:0'):
    with tf.variable_scope('ns1',reuse=False):
        stepa = fa([1,1])
#         stepb = fb() #would cause an error

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print('fa',sess.run(stepa))
        print('fa',sess.run(stepa))
        print('fa',sess.run(stepa))

        print('init')
        sess.run(tf.global_variables_initializer())
        print('fa',sess.run(stepa))
        print('fa',sess.run(stepa))
#         print('fb',sess.run(stepb))
#         print('fb',sess.run(stepb))
        print('fa',sess.run(stepa))

因此函数fa被调用3次,更新x。 在相同的函数fa中,get_variable似乎重用了变量x ,但在不同的函数fb中,get_variable'x'会导致错误。为什么?谢谢你的帮助!

1 个答案:

答案 0 :(得分:0)

  

stepb = fb()#会导致错误

当您尝试创建图中已存在name的变量时,它会导致错误。此外,您不会重复使用已存在的变量。

如果要重用名称中存在的变量,可以使用以下修复

with tf.variable_scope('ns1',reuse=False):
    stepa = fa([1,1])
# reuse the variable named as ns1/x
with tf.variable_scope('ns1',reuse=True):
    stepb = fb()