如何在张量流中保存和恢复变量?
我遇到了问题。我的代码:
import tensorflow as tf
v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(v1)
save_path = saver.save(sess, 'model.ckpt')
print "model saved in file:", save_path
v1 = v1 + 1
print sess.run(v1)
saver = tf.train.import_meta_graph('model.ckpt.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run(v1)
结果:
[[ 0. 0.]
[ 0. 0.]]
[[ 1. 1.]
[ 1. 1.]]
[[ 1. 1.]
[ 1. 1.]]
我希望得到:
[[ 0. 0.]
[ 0. 0.]]
[[ 1. 1.]
[ 1. 1.]]
[[ 0. 0.]
[ 0. 0.]]
我犯了什么错误?
请帮助我理解。
答案 0 :(得分:4)
您的代码中存在两个主要问题:
行v1 = v1 + 1
创建一个新的TensorFlow Tensor
并将其绑定到Python变量v1
,但不会更改TensorFlow {{1}中的值您使用名称Variable
创建的。因此,当您稍后调用"v1"
时,您正在评估将原始变量加1的新张量,而不是从张量中读取值。
相反,要将变量添加到变量,您应该使用以下内容:
sess.run(v1)
tf.train.import_meta_graph()
调用重新创建原始图表,并在此过程中向图表添加新节点,包括新的tf.train.Saver
。当您尚未构建图形(或者没有程序可用于执行该图形)时,它非常有用。由于您已经构建了图表,因此只需使用increment_op = v1.assign_add(tf.ones([2, 2]))
sess.run(increment_op)
。
以下程序应该产生您预期的行为:
saver.restore(sess, 'model.ckpt')
答案 1 :(得分:1)
虽然,所选答案告诉我们应该做些什么,但并不能解释为什么你得到了意想不到的答案。我正在为稍后来这里的人解释。
在Tensorflow中,如果你已经有了一个图表并且在保存之后再次导入相同的图形,那么你的图形操作将不会被替换,而是它们Tensorflow被设计为通过添加后缀如_1,_2等来创建新变量。例如,在您的情况下,在您执行之前: saver = tf.train.import_meta_graph(' model.ckpt.meta') saver.restore(sess,tf.train.latest_checkpoint(' ./')) 您的图表有一个名为v1的变量。导入相同的图形后,您的变量v1将不会被替换,而是将新的变量v1_1添加到图形中。因此,图表的大小将加倍。由于v1没有通过加载图表而改变,你仍然得到v1的旧值(全1)。
如果要重置图形,则必须使用tf.reset_default_graph() 在再次导入图表之前,如documentation中所述。如果你在此之后导入并打印v1,你将获得全0 0。
答案 2 :(得分:0)
docs可能会对此有所了解。我运行了一两个修改文件:
import tensorflow as tf
v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()
tf.add_to_collection('v1', v1)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print sess.run(v1)
save_path = saver.save(sess, 'model.ckpt')
print "model saved in file:", save_path
v1 = v1 + 1
print sess.run(v1)
saver = tf.train.import_meta_graph('model.ckpt.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run(v1)
注意tf.add_to_collection
电话。在那之后,我跑了这个:
import tensorflow as tf
sess = tf.Session()
saver = tf.train.import_meta_graph('model.ckpt.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run(tf.get_collection('v1')[0])
输出:
[[ 0. 0.]
[ 0. 0.]]
看起来恢复的东西不会实际修改你当前的计算图,你需要使用集合来获得你想要的东西。