我正在尝试保存一些变量,看看以后是否可以恢复它。 这是我的保存代码:
import tensorflow as tf;
my_a = tf.Variable(2,name = "my_a");
my_b = tf.Variable(3,name = "my_b");
my_c = tf.Variable(4,name = "my_c");
my_c = tf.add(my_a,my_b);
with tf.Session() as sess:
init = tf.initialize_all_variables();
sess.run(init);
print("my_c = ",sess.run(my_c));
saver = tf.train.Saver();
saver.save(sess,"test.ckpt");
打印出来:
my_c = 5
当我恢复它时:
import tensorflow as tf;
c = tf.Variable(3100,dtype = tf.int32);
with tf.Session() as sess:
sess.run(tf.initialize_all_variables());
saver = tf.train.Saver({"my_c":c});
saver.restore(sess, "test.ckpt");
cc= sess.run(c);
print(cc);
这给了我:
4
my_c的恢复值应为5,因为它是my_a和my_b的总和。但是它给了我4,这是my_c的初始化值。任何人都可以解释为什么会发生这种情况,以及如何将更改保存到变量?
答案 0 :(得分:2)
在原始代码中,您尚未真正将名为my_c
的变量(请注意,TensorFlow name
)分配给my_a + my_b
。
通过编写my_c = tf.add(my_a,my_b)
,python变量my_c
现在与tf.Variable
name='my_c'
不同。
执行sess.run()
时,您只是执行操作,而不是更新该变量。
如果您希望此代码正确运行,请使用此代码 - (请参阅更改注释)
import tensorflow as tf
my_a = tf.Variable(2,name = "my_a")
my_b = tf.Variable(3,name = "my_b")
my_c = tf.Variable(4,name="my_c")
# Use the assign() function to set the new value
add = my_c.assign(tf.add(my_a,my_b))
with tf.Session() as sess:
init = tf.initialize_all_variables()
sess.run(init)
# Execute the add operator
sess.run(add)
print("my_c = ",sess.run(my_c))
saver = tf.train.Saver()
saver.save(sess,"test.ckpt")