关闭和重新打开会话时,Tensorflow恢复无法找到张量

时间:2016-09-16 02:54:07

标签: tensorflow

我正在训练一个看起来像这样的模型:

# Training the model
relevant_tensors = build_model(params)
with tf.Sesion() as sess:
    # do steps for model training...
    saver = tf.train.Saver()
    saver.save(sess, "mymodel.ckpt")

# Prediction
relevant_tensors = build_model(params)  # params here the same as previously, so we build the same model
with tf.Session() as sess:              # New session
    saver = tf.train.Saver()
    saver.load(sess, "mymodel.ckpt")

当我尝试按该顺序加载检查点时,我收到Not Found错误。我正在使用函数build_model以相同的方式构建图形,并且我已经验证两个调用的参数是相同的。

如果我注释掉训练步骤,预测步骤将从上一次运行加载模型就好了。但是当我尝试执行这两个步骤时,我在加载检查点时失败了。

有人在这看到逻辑吗?

1 个答案:

答案 0 :(得分:1)

看起来您正在同一个(默认)tf.Graph中构建模型两次,因此在第二次调用build_model()时创建的节点将获得不同的名称,这些名称与检查点中变量的名称。

一个简单的解决方案是为训练和预测创建不同的tf.Graph个对象,例如:

with tf.Graph().as_default():  # One graph for training the model...
    relevant_tensors = build_model(params)
    with tf.Session() as sess:
        # do steps for model training...
        saver = tf.train.Saver()
        saver.save(sess, "mymodel.ckpt")

with tf.Graph().as_default():  # Another graph for prediction....
    relevant_tensors = build_model(params) 
    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, "mymodel.ckpt")

另一种方法是更改​​build_model(),以便为培训和预测构建单个图表,然后您可以为这两个任务使用相同的图形和会话。