为什么TensorFlow无法恢复由常量初始化的变量?

时间:2017-10-30 14:41:07

标签: tensorflow

我是TensorFlow的新手。当我阅读tensorflow saving and restoring variables手册时,我遇到了一个问题。我保存了一个由常量初始化的变量,但是我无法恢复变量。代码如下:

a = tf.get_variable("name_a", initializer=[1,2,3])
op1 = a.assign(a+1)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    op1.op.run()
    print(a.eval())
    saver.save(sess,"log1/model.ckpt")

然后我恢复它。

a = tf.get_variable("name_a", shape=[3])
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "log1/model.ckpt")
    print(a.eval())

我希望获得[2,3,4]之类的输出,但我得到了[ 2.80259693e-45 4.20389539e-45 5.60519386e-45]。这都是零。

但是,当我将第一个代码段中的第一行修改为

a = tf.get_variable("name_a", initializer=tf.zeros([3]))

我可以获得正确的恢复变量:[ 1. 1. 1.]

我想知道这种情况的原因。

2 个答案:

答案 0 :(得分:1)

我不是百分百肯定,但看起来原因在于你的两个变量:

  • tf.get_variable("name_a", initializer=[1,2,3])

  • tf.get_variable("name_a", shape=[3])

不等同,并且不能轻易地相互使用(更新dtype不同,感谢@BlueSun注意到这一点)

如果您在恢复代码中定义张量,就像保存:a = tf.get_variable("name_a", initializer=[1,2,3])一样,您将获得稳定的输出。但是,更好的方法是直接使用保存的图形:

saver = tf.train.import_meta_graph('log1/model.ckpt.meta')
with tf.Session() as sess:
  saver.restore(sess, "log1/model.ckpt")
  saved = sess.graph.get_tensor_by_name('name_a:0')
  print(sess.run(saved))

无论您如何定义初始化程序,它都能正常工作。

答案 1 :(得分:1)

您必须使用相同的数据类型定义变量a。如果您没有指定它并且没有任何初始值设定项,默认情况下dtype将为tf.float32,并且tf.int32的加载将失败。只需将数据类型设置为int32即可解决问题:

a = tf.get_variable("name_a", shape=[3], dtype=tf.int32)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "log1/model.ckpt")
    print(a.eval())

使用a = tf.get_variable("name_a", initializer=tf.zeros([3]))有效,因为tf.zeros([3])[2, 3, 4]的dtype相同。每当您创建变量时始终设置dtype会更安全。