最小代码示例:
with tf.variable_scope("initializer_test"):
s = tf.get_variable("scalar", initializer=tf.constant(2))
with tf.variable_scope("initializer_test", reuse=True):
s = tf.get_variable("scalar")
# ValueError: Trying to share variable initializer_test/scalar, but specified dtype float32 and found dtype int32_ref.
我的解决方案:
只需阅读错误消息即可提供简单的解决方案:
with tf.variable_scope("initializer_test"):
s = tf.get_variable("scalar", initializer=tf.constant(2))
with tf.variable_scope("initializer_test", reuse=True):
s = tf.get_variable("scalar", dtype=tf.int32) # Just add the required dtype
有更好的方法吗?我希望不必(查看错误消息以找出dtype)或(在我第一次声明时手动为s
设置dtype)。
答案 0 :(得分:0)
将AUTO_REUSE
作为重用模式添加到变量作用域。此模式修改get_variable()
的行为以创建所请求的变量(如果不存在)或返回它们(如果存在)。
现在可以编写以下代码:
def call_f():
with tf.variable_scope("initializer_test", reuse=tf.AUTO_REUSE):
v = tf.get_variable("scalar", initializer=tf.constant(2))
return v
v1 = call_f() # Creates v.
v2 = call_f() # Gets the same, existing v.
print(v1)
print(v2)
输出:
<tf.Variable 'initializer_test/scalar:0' shape=() dtype=int32, numpy=2>
<tf.Variable 'initializer_test/scalar:0' shape=() dtype=int32, numpy=2>