我正在使用flask运行Web服务器,当我尝试使用vgg16时出现错误,这是keras的全局变量'预先培训的VGG16型号。我不知道为什么这个错误上升或者它是否与Tensorflow后端有任何关系。 这是我的代码:
vgg16 = VGG16(weights='imagenet', include_top=True)
def getVGG16Prediction(img_path):
global vgg16
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
pred = vgg16.predict(x)
return x, sort(decode_predictions(pred, top=3)[0])
@app.route("/uploadMultipleImages", methods=["POST"])
def uploadMultipleImages():
uploaded_files = request.files.getlist("file[]")
for file in uploaded_files:
path = os.path.join(STATIC_PATH, file.filename)
pInput, result = getVGG16Prediction(path)
非常感谢任何评论或建议。谢谢。
答案 0 :(得分:2)
在this github issue上查看avital
的答案。在此引用相关部分:
在加载或构建模型后,立即保存TensorFlow图:
graph = tf.get_default_graph()
在另一个线程中(或者可能在异步事件处理程序中),执行:
global graph with graph.as_default(): (... do inference here ...)
我对此进行了一些修改,并将图形存储在我的应用程序的配置对象中,而不是将其设置为全局。
get_default_graph
的{{3}}解释了为什么这是必要的:
注意:默认图表是当前线程的属性。如果您创建一个新线程,并希望在该线程中使用默认图形,则必须在该线程的函数中显式添加g.as_default():.