我使用tensorflow训练CNN模型进行角色分类。我使用tf.train.Saver()对象保存了我最好的模型。要在应用程序中进行分类,我使用一个函数,如下所示。
def classify_chars(images):
# Create the model
x = tf.placeholder(tf.float32, [None, 400])
# Build the graph for the deep net
y_conv, keep_prob = _deepnn(x)
# Define classification
letter_class = tf.argmax(y_conv, 1)
confidence = tf.reduce_max(tf.nn.softmax(y_conv), 1)
# Enable saving and loading of variables
saver = tf.train.Saver()
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, os.path.join(M_PATH, "char_model.ckpt"))
print("Model restored.")
images_classes = letter_class.eval(feed_dict={x: images})
images_confidences = confidence.eval(feed_dict={x: images})
return images_classes, images_confidences
此函数加载已保存的模型并使用它来对函数输入进行分类。调用一次时,该功能与预期完全一致。但是,如果我在同一次执行期间多次调用它,则会失败,抛出:
tensorflow.python.framework.errors_impl.NotFoundError: Key Variable_9 not found in checkpoint
现在,如果在使用该功能时发生这种情况,这对我来说是有意义的,我认为我的模型保存可能有问题。但是在这里似乎更像是函数有状态持续存在于其之外,这阻止了第二次运行。但是,在查看我的代码时,我看不出这个状态是什么。我没有覆盖我的检查点文件,所以从理论上讲,加载它应该没有任何问题。
有谁知道我做错了什么?
答案 0 :(得分:0)
tf.saver将全局步骤添加到检查点名称,您可能应该检查是否存在问题
答案 1 :(得分:0)
由于未在函数调用之间重置默认图形而导致此问题。 TensorFlow已经有了一个图表,尝试再次加载相同的图表会导致此错误。这可以通过添加
来缓解tf.reset_default_graph()
恢复图表之前。