TensorFlow无法恢复现有变量的已保存值

时间:2017-03-21 03:26:46

标签: tensorflow

我尝试保存一个在初始化后值改变的变量。然后我恢复了模型,但恢复的值仍然是这个变量的初始值,为什么在保存模型时它不能恢复变量的当前值?

这是我的代码:

import tensorflow as tf
import numpy as np
a = tf.Variable(tf.zeros([3, 2]), name='a')
saver = tf.train.Saver()
tf_config = tf.ConfigProto(allow_soft_placement=True)
tf_config.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=tf_config)
tf.global_variables_initializer().run()

print(a.eval())
#[[ 0.  0.]
# [ 0.  0.]
# [ 0.  0.]]
a = a + np.random.randint(0, 5, (3, 2))
print(a.eval())
#[[ 0.  3.]
# [ 2.  3.]
# [ 0.  3.]]
saver.save(sess, './model/model.ckpt')
a = a - np.random.randint(0, 5, (3, 2))
print(a.eval())
#[[-1.  3.]
# [ 2.  2.]
# [-4.  3.]]
saver.restore(sess, './model/model.ckpt')
print(a.eval())
#[[-1.  3.]
# [ 2.  2.]
# [-4.  3.]]

显然,这里恢复的价值不是我想要保存的价值。

1 个答案:

答案 0 :(得分:1)

您对TF的运作方式存在误解。您在第3行定义的TF变量在初始化后不会更改其值。

package sbp

import (
    "net/http"
    "time"

    "github.com/pressly/chi/render"
)

// AddInvestment Adds an investment
func AddInvestment(w http.ResponseWriter, r *http.Request) {
    ...
}

// GetInvestments Gets list of investments
func GetInvestments(w http.ResponseWriter, r *http.Request) {
    ...
}

这不会更改TF变量,它会向图形添加节点以创建张量,该张量是变量和您创建的numpy数组的总和。 TF变量的值仍为零,并且python变量现在指向作为总和的张量(因此python变量a = a + np.random.randint(0, 5, (3, 2)) 不再指向TF变量)。

要更改TF变量的值,请使用类似

的内容
a

这样,你用新值更新TF变量,python变量仍然指向TF变量。