如何让tensorflow模型持续加载?

时间:2017-10-27 06:41:46

标签: python tensorflow

现在我想写两个函数:

  • 1用于加载我已经训练的模型,
  • 第二个是使用该模型进行分类。

但是这两个函数都需要相同的会话,所以我将会话作为参数,以便将它播种到下一个函数。但是我收到了一个错误。

这是我的代码。第一种方法是加载模型,第二种方法是使用模型来预测某些东西,但是在初始化会话时我遇到了一些问题

def callmodel():
    with tf.Graph().as_default():
        #saver = tf.train.Saver()
        model_path = 'E:/MyProject/MachineLearning/callTFModel/model/'
        ckpt = tf.train.get_checkpoint_state(model_path)
        sess = tf.Session()
        saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(model_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("load model successful!")
            return sess
        else:
            print("failed to load model!")


def test_one_image(sess,test_dir):
    global p, logits
    image = Image.open(test_dir)
    image = image.resize([32, 32])
    image_array = np.array(image)
    image = tf.cast(image_array, tf.float32)
    image = tf.reshape(image, [1, 32, 32, 3])  # 调整image的形状
    p = mmodel(image, 1)
    logits = tf.nn.softmax(p)
    x = tf.placeholder(tf.float32, shape=[32, 32, 3])
    prediction = sess.run(logits, feed_dict={x: image_array})
    max_index = np.argmax(prediction)
    if max_index == 0:
        print('probability of good: %.6f' % prediction[:, 0])
    else:
        print('probability of Lack of glue: %.6f' % prediction[:, 1])



#######//test
sess=callmodel
path="c:/test/1001.jpg"
test_one_image(sess,path)

it occurs error:

 File "E:/MyProject/python/C+pythonModel/test.py", line 175, in <module>
    test_one_image(sess,path)
  File "E:/MyProject/python/C+pythonModel/test.py", line 164, in test_one_image
    prediction = sess.run(logits, feed_dict={x: image_array})
  File "D:\study\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 895, in run
    run_metadata_ptr)
  File "D:\study\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1071, in _run
    + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(32, 32, 3), dtype=float32) is not an element of this graph.

1 个答案:

答案 0 :(得分:0)

问题不在于使用会话作为参数,而在于如何恢复图表的输入和输出节点:当你写的时候

p = mmodel(image, 1)
logits = tf.nn.softmax(p)
x = tf.placeholder(tf.float32, shape=[32, 32, 3])

您没有恢复会话图中的相应节点,而是创建新节点。你应该使用:

x= sess.graph().get_tensor_by_name("your_x_placeholder_name")
logits= sess.graph().get_tensor_by_name("your_logits_placeholder_name")

然后prediction = sess.run(logits, feed_dict={x: image_array})

另外,您可能需要检查imageimage_array之间是否有任何错误(现在您正在重塑image,而不是数组,如果没有用,如果你用image_array ...)

喂食