我正在使用tf.train.Saver
save
和restore
保存并恢复TensorFlow模型。在恢复过程中,我正在加载新的输入数据。 restore
方法抛出此错误:
InvalidArgumentError(参见上面的回溯):Assign需要形状 两个张量匹配。 lhs shape = [1334,3] rhs shape = [1246,3] [[节点:save / Assign_6 =分配[T = DT_FLOAT,_class = [“loc:@ Variable_2”], use_locking = true,validate_shape = true, _device =“/ job:localhost / replica:0 / task:0 / cpu:0”](Variable_2,save / RestoreV2_6)]]
这似乎表明问题出在Variable_2
,但是如何确定代码中哪个变量对应Variable_2
?
答案 0 :(得分:1)
a = tf.placeholder("float", [3, 3], name="tensor_a")
代码中的快速教程:
# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
inc_v1.op.run()
dec_v2.op.run()
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in file: %s" % save_path)
tf.reset_default_graph()
# Create some variables.
d1 = tf.get_variable("v1", shape=[3])
d2 = tf.get_variable("v2", shape=[5])
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# Check the values of the variables
print("v1 : %s" % d1.eval())
print("v2 : %s" % d2.eval())
如果您注意到上面的代码d1和v1现在具有相同的形状,如果您更改任何变量的形状,它将引发一个错误,类似于您得到的错误
答案 1 :(得分:-1)
当您创建一个新变量时,它将获得一个唯一的名称。 Saver.restore在检查点中查看相同的名称。如果您需要使用不同的名称从不同的检查点初始化某些变量,请查看tf.contrib.framework.init_from_checkpoint
。