Tensorflow模型无法加载多次

时间:2017-04-27 22:58:45

标签: tensorflow

我使用tensorflow训练CNN模型进行角色分类。我使用tf.train.Saver()对象保存了我最好的模型。要在应用程序中进行分类,我使用一个函数,如下所示。

def classify_chars(images):
    # Create the model
    x = tf.placeholder(tf.float32, [None, 400])

    # Build the graph for the deep net
    y_conv, keep_prob = _deepnn(x)

    # Define classification
    letter_class = tf.argmax(y_conv, 1)
    confidence = tf.reduce_max(tf.nn.softmax(y_conv), 1)

    # Enable saving and loading of variables
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # Restore variables from disk.
        saver.restore(sess, os.path.join(M_PATH, "char_model.ckpt"))
        print("Model restored.")

        images_classes = letter_class.eval(feed_dict={x: images})
        images_confidences = confidence.eval(feed_dict={x: images})

    return images_classes, images_confidences

此函数加载已保存的模型并使用它来对函数输入进行分类。调用一次时,该功能与预期完全一致。但是,如果我在同一次执行期间多次调用它,则会失败,抛出:

tensorflow.python.framework.errors_impl.NotFoundError: Key Variable_9 not found in checkpoint

现在,如果在使用该功能时发生这种情况,这对我来说是有意义的,我认为我的模型保存可能有问题。但是在这里似乎更像是函数有状态持续存在于其之外,这阻止了第二次运行。但是,在查看我的代码时,我看不出这个状态是什么。我没有覆盖我的检查点文件,所以从理论上讲,加载它应该没有任何问题。

有谁知道我做错了什么?

2 个答案:

答案 0 :(得分:0)

tf.saver将全局步骤添加到检查点名称,您可能应该检查是否存在问题

答案 1 :(得分:0)

由于未在函数调用之间重置默认图形而导致此问题。 TensorFlow已经有了一个图表,尝试再次加载相同的图表会导致此错误。这可以通过添加

来缓解
tf.reset_default_graph()

恢复图表之前。