我正在使用keras,flask开发用于预测图像标签的转移学习模型网络api。由于图像大约有10个标签。我需要加载每个标签模型以预测标签。我想在json请求之前加载模型,因为keras模型加载需要很多时间。但是在获取json请求后,使用加载的模型预测图像标签时会出错。
我的代码:
img_dim = (299, 299, 3)
img_size = (299, 299)
num_label = 2
model_image = load_resnet_model()
industry = '100'
lst_model = []
for t_image in lst_main_image:
lst_model = []
ml = model_image
for i in range(1, 6):
ml.load_weights('../model/{}/main_image/{}_aug_inception.fold_{}{}.hdf5'.format(industry, industry, i, t_image))
ml.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
lst_model.append(ml)
print('finish image:', t_image)
dict_model[t_image] = lst_model
def load_resnet_model():
img_dim = (299, 299, 3)
img_size = (299, 299)
num_label = 2
print('begin to get model')
input_tensor = Input(shape=img_dim)
base_model = InceptionResNetV2(include_top=False, input_shape=img_dim, weights='imagenet')
x = input_tensor
x = Lambda(preprocess_input, name='preprocessing')(x)
x = base_model(x)
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
x = Dense(num_label, activation='softmax', name='softmax')(x)
model_image = Model(input_tensor, x)
print('finish loading model')
return model_image
@app.route("/api/", methods=["POST"])
def predict_tag():
print('beginning to prediction')
# model_image = load_resnet_model(img_dim, num_label)
len_test = validation_batch.shape[0]
for t in lst_main_image:
n_fold = 5
preds_test = np.zeros((len_test, 2), dtype=np.float)
print('t_image:', t)
tag_i_time = time.time()
lst_t_model = dict_model[t]
for m in lst_t_model:
test_prob = m.predict(validation_batch)
preds_test += test_prob
tag_i_e = time.time()
print('each tag the times:', t, tag_i_e - tag_i_time)
preds_test /= n_fold
y_pred = preds_test.argmax(axis=-1)
lst_result_image.append(list(y_pred))
print('finish predict the tag:', t)
但是,当我将加载模型放入predict_tag函数中时,没有这种错误。我想在发送请求之前放置加载模型。
127.0.0.1 - - [09/Jul/2019 15:15:38] "POST /api/ HTTP/1.1" 500 -
Traceback (most recent call last):
File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1997, in __call__
return self.wsgi_app(environ, start_response)
File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1985, in wsgi_app
response = self.handle_exception(e)
File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1540, in handle_exception
reraise(exc_type, exc_value, tb)
File "/anaconda3/lib/python3.6/site-packages/flask/_compat.py", line 33, in reraise
raise value
File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1982, in wsgi_app
response = self.full_dispatch_request()
File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1614, in full_dispatch_request
rv = self.handle_user_exception(e)
File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1517, in handle_user_exception
reraise(exc_type, exc_value, tb)
File "/anaconda3/lib/python3.6/site-packages/flask/_compat.py", line 33, in reraise
raise value
File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1612, in full_dispatch_request
rv = self.dispatch_request()
File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1598, in dispatch_request
return self.view_functions[rule.endpoint](**req.view_args)
File "/Users/k.den/PycharmProjects/Banner-Tag-API/src/app792.py", line 188, in predict_tag
test_prob = m.predict(validation_batch)
File "/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1164, in predict
self._make_predict_function()
File "/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 554, in _make_predict_function
**kwargs)
File "/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2744, in function
return Function(inputs, outputs, updates=updates, **kwargs)
File "/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2546, in __init__
with tf.control_dependencies(self.outputs):
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 5028, in control_dependencies
return get_default_graph().control_dependencies(control_inputs)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 4528, in control_dependencies
c = self.as_graph_element(c)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3478, in as_graph_element
return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3557, in _as_graph_element_locked
raise ValueError("Tensor %s is not an element of this graph." % obj)