我是TensorFlow的新手。当我阅读tensorflow saving and restoring variables手册时,我遇到了一个问题。我保存了一个由常量初始化的变量,但是我无法恢复变量。代码如下:
a = tf.get_variable("name_a", initializer=[1,2,3])
op1 = a.assign(a+1)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
op1.op.run()
print(a.eval())
saver.save(sess,"log1/model.ckpt")
然后我恢复它。
a = tf.get_variable("name_a", shape=[3])
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "log1/model.ckpt")
print(a.eval())
我希望获得[2,3,4]
之类的输出,但我得到了[ 2.80259693e-45 4.20389539e-45 5.60519386e-45]
。这都是零。
但是,当我将第一个代码段中的第一行修改为
时a = tf.get_variable("name_a", initializer=tf.zeros([3]))
我可以获得正确的恢复变量:[ 1. 1. 1.]
我想知道这种情况的原因。
答案 0 :(得分:1)
我不是百分百肯定,但看起来原因在于你的两个变量:
tf.get_variable("name_a", initializer=[1,2,3])
tf.get_variable("name_a", shape=[3])
不等同,并且不能轻易地相互使用(更新:dtype
不同,感谢@BlueSun注意到这一点)
如果您在恢复代码中定义张量,就像保存:a = tf.get_variable("name_a", initializer=[1,2,3])
一样,您将获得稳定的输出。但是,更好的方法是直接使用保存的图形:
saver = tf.train.import_meta_graph('log1/model.ckpt.meta')
with tf.Session() as sess:
saver.restore(sess, "log1/model.ckpt")
saved = sess.graph.get_tensor_by_name('name_a:0')
print(sess.run(saved))
无论您如何定义初始化程序,它都能正常工作。
答案 1 :(得分:1)
您必须使用相同的数据类型定义变量a
。如果您没有指定它并且没有任何初始值设定项,默认情况下dtype将为tf.float32,并且tf.int32的加载将失败。只需将数据类型设置为int32即可解决问题:
a = tf.get_variable("name_a", shape=[3], dtype=tf.int32)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "log1/model.ckpt")
print(a.eval())
使用a = tf.get_variable("name_a", initializer=tf.zeros([3]))
有效,因为tf.zeros([3])
与[2, 3, 4]
的dtype相同。每当您创建变量时始终设置dtype
会更安全。