我正在训练一个看起来像这样的模型:
# Training the model
relevant_tensors = build_model(params)
with tf.Sesion() as sess:
# do steps for model training...
saver = tf.train.Saver()
saver.save(sess, "mymodel.ckpt")
# Prediction
relevant_tensors = build_model(params) # params here the same as previously, so we build the same model
with tf.Session() as sess: # New session
saver = tf.train.Saver()
saver.load(sess, "mymodel.ckpt")
当我尝试按该顺序加载检查点时,我收到Not Found错误。我正在使用函数build_model以相同的方式构建图形,并且我已经验证两个调用的参数是相同的。
如果我注释掉训练步骤,预测步骤将从上一次运行加载模型就好了。但是当我尝试执行这两个步骤时,我在加载检查点时失败了。
有人在这看到逻辑吗?
答案 0 :(得分:1)
看起来您正在同一个(默认)tf.Graph
中构建模型两次,因此在第二次调用build_model()
时创建的节点将获得不同的名称,这些名称与检查点中变量的名称。
一个简单的解决方案是为训练和预测创建不同的tf.Graph
个对象,例如:
with tf.Graph().as_default(): # One graph for training the model...
relevant_tensors = build_model(params)
with tf.Session() as sess:
# do steps for model training...
saver = tf.train.Saver()
saver.save(sess, "mymodel.ckpt")
with tf.Graph().as_default(): # Another graph for prediction....
relevant_tensors = build_model(params)
with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, "mymodel.ckpt")
另一种方法是更改build_model()
,以便为培训和预测构建单个图表,然后您可以为这两个任务使用相同的图形和会话。