我无法理解函数中的变量重用行为。 代码是:
def fa(y):
x = tf.get_variable("x", initializer=tf.constant([3, 4, 5])) #without reusing
x = tf.scatter_add(x,[0,2],y) #fine updating x
return x
def fb():
x = tf.get_variable("x", initializer=tf.constant([3, 4, 5]))
x = x.assign([1,2,3])
return x
with tf.Graph().as_default(),tf.device('/cpu:0'):
with tf.variable_scope('ns1',reuse=False):
stepa = fa([1,1])
# stepb = fb() #would cause an error
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print('fa',sess.run(stepa))
print('fa',sess.run(stepa))
print('fa',sess.run(stepa))
print('init')
sess.run(tf.global_variables_initializer())
print('fa',sess.run(stepa))
print('fa',sess.run(stepa))
# print('fb',sess.run(stepb))
# print('fb',sess.run(stepb))
print('fa',sess.run(stepa))
因此函数fa被调用3次,更新x。 在相同的函数fa中,get_variable似乎重用了变量x ,但在不同的函数fb中,get_variable'x'会导致错误。为什么?谢谢你的帮助!
答案 0 :(得分:0)
stepb = fb()#会导致错误
当您尝试创建图中已存在name
的变量时,它会导致错误。此外,您不会重复使用已存在的变量。
如果要重用名称中存在的变量,可以使用以下修复
with tf.variable_scope('ns1',reuse=False):
stepa = fa([1,1])
# reuse the variable named as ns1/x
with tf.variable_scope('ns1',reuse=True):
stepb = fb()