为什么tensorflow train.Saver()保存初始变量值而不是修改值?

时间:2017-06-03 10:33:52

标签: tensorflow

我有以下代码来初始化变量v3和v4。在初始化变量v3和v4之后,我正在修改这些变量并将其保存在检查点文件中:

import tensorflow as tf
sess = tf.Session()
v3 = tf.Variable(tf.random_uniform([4,2]), name="v3")
init_op = tf.global_variables_initializer()
sess.run(init_op)
with sess.as_default():
    print(v3.eval())
    v3 = tf.transpose(v3)
    print(v3.eval())
    sess.run(v3)
    v4 = tf.Variable(v3+3, name="v4")
    v4 = v4 + 5
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    print(v4.eval())
    saver = tf.train.Saver()
    saver.save(sess, "pktest_ckpt")

这将打印以下v3值,转置v3和v4值:

[[ 0.90765333  0.61777163]
 [ 0.5102632   0.45610023]
 [ 0.36511779  0.5465256 ]
 [ 0.61696458  0.86357415]]
[[ 0.90765333  0.5102632   0.36511779  0.61696458]
 [ 0.61777163  0.45610023  0.5465256   0.86357415]]
[[ 8.96951866  8.24961662  8.30669975  8.54586029]
 [ 8.55886841  8.16989517  8.48039341  8.06889534]]

从检查点文件恢复变量后,我看到变量值是初始化的值而不是修改后的值:

tf.reset_default_graph()
mg = tf.train.import_meta_graph("pktest_ckpt.meta")
with tf.Session() as sess:
    for v in tf.global_variables():
        print(v)
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, tf.train.latest_checkpoint("./"))  
    print(sess.run('v4:0'))

打印:

<tf.Variable 'v3:0' shape=(4, 2) dtype=float32_ref>
<tf.Variable 'v4:0' shape=(2, 4) dtype=float32_ref>
INFO:tensorflow:Restoring parameters from ./pktest_ckpt
[[ 3.75337863  3.52812386  3.97137022  3.76210618]
 [ 3.81927872  3.41938591  3.82610369  3.20377684]]

我的期望是获得v3的形状(2,4)和v4值     [[8.96951866 8.24961662 8.30669975 8.54586029]      [8.55886841 8.16989517 8.48039341 8.06889534]]

任何人都可以解释为什么会这样吗?

2 个答案:

答案 0 :(得分:0)

我想我得到了答案。我应该这样保存变量:

import tensorflow as tf
sess = tf.Session()
v3 = tf.Variable(tf.random_uniform([4,2]), name="v3")
init_op = tf.global_variables_initializer()
sess.run(init_op)
with sess.as_default():
    v4 = tf.Variable(v3+3, name="v4")
    init_op = tf.variables_initializer([v3,v4])
    sess.run(init_op)
    do_add = v4.assign(tf.add(v4,5))
    sess.run(do_add)

    print(v4.eval())
    saver = tf.train.Saver()
    saver.save(sess, "pktest_ckpt")

答案 1 :(得分:0)

这可能是因为你在第二次初始化之后并没有真正运行这些操作。尝试在第二个session.run([v3, v4])

之后添加session.run(init_op)