keras无法多次调用model.predict_classes

时间:2018-11-18 16:28:02

标签: python tensorflow machine-learning keras prediction

def predictOne(imgPath):

    model = load_model("withImageMagic.h5")
    image = read_image(imgPath)
    test_sample = preprocess(image)
    predicted_class = model.predict_classes(([test_sample]))
    return predicted_class

我已经训练了一个模型。在此功能中,我加载模型,读取新图像,进行一些预处理并最终预测其标签。

当我运行main.py文件时,将调用此函数,一切都会顺利进行。但是,几秒钟后,此函数将再次用另一张图像调用,并且出现此错误:

'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
     

TypeError:无法将feed_dict键解释为张量:Tensor Tensor(“ Placeholder:0”,shape =(5,5,1,32),dtype = float32)不是该图的元素。

该功能仅在第一次使用时很奇怪。我测试了多张图像,并得到了相同的行为。

Windows 10-具有keras的tensorflow-gpu

2 个答案:

答案 0 :(得分:2)

尝试从函数外部的文件加载模型,并将模型对象作为函数library.zip的参数。这也将更快,因为不需要在每次需要预测时从磁盘加载权重。

如果要继续在函数内加载模型,请导入后端:

def predictOne(imgPath, model)

然后

from keras import backend as K

在加载模型之前。

答案 1 :(得分:0)

class one_model:
    session = None
    graph = None 
    loadModel = None
    __instance = None
    @staticmethod
    def getInstance(modelPath):
        """ Static access method. """
        if one_model.__instance == None:
            one_model.__instance = one_model(modelPath)
        return one_model.__instance
        
    def __init__(self, modelPath):
        self.modelPath = modelPath
        self.session = tf.Session(graph=tf.Graph())
        self.loadOneModel()
            
    def loadOneModel(self):
        try:
            with self.session.graph.as_default():
                K.set_session(self.session)
                self.loadModel = keras.models.load_model(self.modelPath)               
        except Exception as e:
            logging.error(str(e))
            print(str(e))
                        
    def getPredictionOne(self, input_file_path): 
        #Predict the data once the model is loaded
        if self.loadModel is not None and self.session is not None: 
            try:
                image = load_img(input_file_path, target_size=inputShape)
                image = img_to_array(image)
                image = np.expand_dims(image, axis=0)
                image = preprocess(image)
                with self.session.graph.as_default():
                    K.set_session(self.session)
                    preds = self.loadModel.predict(image)
                    return preds
            except Exception as e:
                logging.error(str(e))
        
        return -1


if __name__== "__main__": 
    #First Model 
    data = web.input()
        fileapth = data.imagefilepath  
        modelfilepath = data.modelfilepath
        one_modelObj = one_model.getInstance(modelfilepath)        
        value = one_modelObj.getPredictionOne(fileapth)