tensorflow仅保存初始值

时间:2016-12-20 21:16:45

标签: python machine-learning tensorflow

我正在尝试保存一些变量,看看以后是否可以恢复它。 这是我的保存代码:

   import tensorflow as tf;
   my_a = tf.Variable(2,name = "my_a");
   my_b = tf.Variable(3,name = "my_b");
   my_c = tf.Variable(4,name = "my_c");
   my_c = tf.add(my_a,my_b);

   with tf.Session() as sess:
       init = tf.initialize_all_variables();
       sess.run(init);
       print("my_c =  ",sess.run(my_c));
       saver = tf.train.Saver();
       saver.save(sess,"test.ckpt");

打印出来:

    my_c =   5

当我恢复它时:

   import tensorflow as tf;
   c = tf.Variable(3100,dtype = tf.int32);
   with tf.Session() as sess:
       sess.run(tf.initialize_all_variables());
       saver = tf.train.Saver({"my_c":c});
       saver.restore(sess, "test.ckpt");
       cc= sess.run(c);
       print(cc);

这给了我:

    4

my_c的恢复值应为5,因为它是my_a和my_b的总和。但是它给了我4,这是my_c的初始化值。任何人都可以解释为什么会发生这种情况,以及如何将更改保存到变量?

1 个答案:

答案 0 :(得分:2)

在原始代码中,您尚未真正将名为my_c的变量(请注意,TensorFlow name)分配给my_a + my_b

通过编写my_c = tf.add(my_a,my_b),python变量my_c现在与tf.Variable name='my_c'不同。

执行sess.run()时,您只是执行操作,而不是更新该变量。

如果您希望此代码正确运行,请使用此代码 - (请参阅更改注释)

import tensorflow as tf
my_a = tf.Variable(2,name = "my_a")
my_b = tf.Variable(3,name = "my_b")
my_c = tf.Variable(4,name="my_c")
# Use the assign() function to set the new value
add = my_c.assign(tf.add(my_a,my_b))

with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)
    # Execute the add operator
    sess.run(add)
    print("my_c =  ",sess.run(my_c))
    saver = tf.train.Saver()
    saver.save(sess,"test.ckpt")