我保存了一个模型,现在我试图在两个分支中还原它,如下所示:
我写了这段代码,它引发了ValueError: The same saveable will be restored with two names
。
如何从同一个变量恢复两个变量?
restore_variables = {}
for varr in tf.global_variables()
if varr.op.name in checkpoint_variables:
restore_variables[varr.op.name.split("_red")[0]] = varr
restore_variables[varr.op.name.split("_blue")[0]] = varr
init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
答案 0 :(得分:1)
在TF 1.15上测试
基本上,该错误是说它正在restore_variables
字典中找到对同一变量的多个引用。解决方法很简单。使用tf.Variable(varr)
为其中一个引用创建变量的副本,如下所示。
我认为可以安全地假设您不是要在此处查找对同一变量的多个引用,而是要查找两个单独的变量。 (我假设这样做是因为,如果您想多次使用同一变量,则可以多次使用单个变量。)
with tf.Session() as sess:
saver.restore(sess, './vars/vars.ckpt-0')
restore_variables = {}
checkpoint_variables=['b']
for varr in tf.global_variables():
if varr.op.name in checkpoint_variables:
restore_variables[varr.op.name.split("_red")[0]] = varr
restore_variables[varr.op.name.split("_blue")[0]] = tf.Variable(varr)
print(restore_variables)
init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
在下面,您可以找到一个完整的代码,使用一个玩具示例来复制问题。本质上,我们有两个变量a
和b
,除此之外,我们正在创建b_red
和b_blue
变量。
# Saving the variables
import tensorflow as tf
import numpy as np
a = tf.placeholder(shape=[None, 3], dtype=tf.float64)
w1 = tf.Variable(np.random.normal(size=[3,2]), name='a')
out = tf.matmul(a, w1)
w2 = tf.Variable(np.random.normal(size=[2,3]), name='b')
out = tf.matmul(out, w2)
saver = tf.train.Saver([w1, w2])
with tf.Session() as sess:
tf.global_variables_initializer().run()
saved_path = saver.save(sess, './vars/vars.ckpt', global_step=0)
# Restoring the variables
with tf.Session() as sess:
saver.restore(sess, './vars/vars.ckpt-0')
restore_variables = {}
checkpoint_variables=['b']
for varr in tf.global_variables():
if varr.op.name in checkpoint_variables:
restore_variables[varr.op.name+"_red"] = varr
# Fixing the issue: Instead of varr, do tf.Variable(varr)
restore_variables[varr.op.name+"_blue"] = varr
print(restore_variables)
init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
答案 1 :(得分:0)
我可能无法正确理解问题,但是您不能只创建两个保护程序对象吗?像这样:
import tensorflow as tf
# Make checkpoint
with tf.Graph().as_default(), tf.Session() as sess:
a = tf.Variable([1., 2.], name='a')
sess.run(a.initializer)
b = tf.Variable([3., 4., 5.], name='b')
sess.run(b.initializer)
saver = tf.train.Saver([a, b])
saver.save(sess, 'tmp/vars.ckpt')
# Restore checkpoint
with tf.Graph().as_default(), tf.Session() as sess:
# Red
a_red = tf.Variable([0., 0.], name='a_red')
b_red = tf.Variable([0., 0., 0.], name='b_red')
saver_red = tf.train.Saver({'a': a_red, 'b': b_red})
saver_red.restore(sess, 'tmp1/vars.ckpt')
print(a_red.eval())
# [1. 2.]
print(b_red.eval())
# [3. 4. 5.]
# Blue
a_blue = tf.Variable([0., 0.], name='a_blue')
b_blue = tf.Variable([0., 0., 0.], name='b_blue')
saver_blue = tf.train.Saver({'a': a_blue, 'b': b_blue})
saver_blue.restore(sess, 'tmp/vars.ckpt')
print(a_blue.eval())
# [1. 2.]
print(b_blue.eval())
# [3. 4. 5.]