Tensorflow:保存和恢复变量问题

时间:2016-12-12 14:54:28

标签: python tensorflow

如何在张量流中保存和恢复变量?

我遇到了问题。我的代码:

import tensorflow as tf

v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(v1)
    save_path = saver.save(sess, 'model.ckpt')
    print "model saved in file:", save_path
    v1 = v1 + 1
    print sess.run(v1)
    saver = tf.train.import_meta_graph('model.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    print sess.run(v1)

结果:

[[ 0.  0.]
 [ 0.  0.]]

[[ 1.  1.]
 [ 1.  1.]]

[[ 1.  1.]
 [ 1.  1.]]

我希望得到:

[[ 0.  0.]
 [ 0.  0.]]

[[ 1.  1.]
 [ 1.  1.]]

[[ 0.  0.]
 [ 0.  0.]]

我犯了什么错误?

请帮助我理解。

3 个答案:

答案 0 :(得分:4)

您的代码中存在两个主要问题:

  1. v1 = v1 + 1创建一个新的TensorFlow Tensor并将其绑定到Python变量v1,但不会更改TensorFlow {{1}中的值您使用名称Variable创建的。因此,当您稍后调用"v1"时,您正在评估将原始变量加1的新张量,而不是从张量中读取值。

    相反,要将变量添加到变量,您应该使用以下内容:

    sess.run(v1)
  2. tf.train.import_meta_graph()调用重新创建原始图表,并在此过程中向图表添加新节点,包括新的tf.train.Saver。当您尚未构建图形(或者没有程序可用于执行该图形)时,它非常有用。由于您已经构建了图表,因此只需使用increment_op = v1.assign_add(tf.ones([2, 2])) sess.run(increment_op)

  3. 以下程序应该产生您预期的行为:

    saver.restore(sess, 'model.ckpt')

答案 1 :(得分:1)

虽然,所选答案告诉我们应该做些什么,但并不能解释为什么你得到了意想不到的答案。我正在为稍后来这里的人解释。

在Tensorflow中,如果你已经有了一个图表并且在保存之后再次导入相同的图形,那么你的图形操作将不会被替换,而是它们Tensorflow被设计为通过添加后缀如_1,_2等来创建新变量。例如,在您的情况下,在您执行之前:     saver = tf.train.import_meta_graph(' model.ckpt.meta')     saver.restore(sess,tf.train.latest_checkpoint(' ./')) 您的图表有一个名为v1的变量。导入相同的图形后,您的变量v1将不会被替换,而是将新的变量v1_1添加到图形中。因此,图表的大小将加倍。由于v1没有通过加载图表而改变,你仍然得到v1的旧值(全1)。

如果要重置图形,则必须使用tf.reset_default_graph()  在再次导入图表之前,如documentation中所述。如果你在此之后导入并打印v1,你将获得全0 0。

答案 2 :(得分:0)

docs可能会对此有所了解。我运行了一两个修改文件:

import tensorflow as tf

v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()

tf.add_to_collection('v1', v1)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print sess.run(v1)
    save_path = saver.save(sess, 'model.ckpt')
    print "model saved in file:", save_path
    v1 = v1 + 1
    print sess.run(v1)
    saver = tf.train.import_meta_graph('model.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    print sess.run(v1)

注意tf.add_to_collection电话。在那之后,我跑了这个:

import tensorflow as tf

sess = tf.Session()
saver = tf.train.import_meta_graph('model.ckpt.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run(tf.get_collection('v1')[0])

输出:

[[ 0.  0.]
 [ 0.  0.]]

看起来恢复的东西不会实际修改你当前的计算图,你需要使用集合来获得你想要的东西。