当dtypes不同时,如何重用TensorFlow中的现有变量?

时间:2019-06-21 07:13:01

标签: python tensorflow

最小代码示例:

with tf.variable_scope("initializer_test"):
    s = tf.get_variable("scalar", initializer=tf.constant(2)) 

with tf.variable_scope("initializer_test", reuse=True):
    s = tf.get_variable("scalar")
# ValueError: Trying to share variable initializer_test/scalar, but specified dtype float32 and found dtype int32_ref. 

我的解决方案:

只需阅读错误消息即可提供简单的解决方案:

with tf.variable_scope("initializer_test"):
    s = tf.get_variable("scalar", initializer=tf.constant(2)) 

with tf.variable_scope("initializer_test", reuse=True):
    s = tf.get_variable("scalar", dtype=tf.int32) # Just add the required dtype

有更好的方法吗?我希望不必(查看错误消息以找出dtype)或(在我第一次声明时手动为s设置dtype)。

1 个答案:

答案 0 :(得分:0)

AUTO_REUSE作为重用模式添加到变量作用域。此模式修改get_variable()的行为以创建所请求的变量(如果不存在)或返回它们(如果存在)。

现在可以编写以下代码:

def call_f():
    with tf.variable_scope("initializer_test", reuse=tf.AUTO_REUSE):
        v = tf.get_variable("scalar", initializer=tf.constant(2))
    return v

v1 = call_f()  # Creates v.
v2 = call_f()  # Gets the same, existing v.

print(v1)
print(v2)

输出:

<tf.Variable 'initializer_test/scalar:0' shape=() dtype=int32, numpy=2>
<tf.Variable 'initializer_test/scalar:0' shape=() dtype=int32, numpy=2>