Deep learning model Flask API: ValueError: Tensor Tensor("softmax/Softmax:0", shape=(?, 2), dtype=float32) is not an element of this graph

Posted: 2019-07-09 06:26:15

Tags: python flask keras transfer-learning

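I am using Keras and Flask to build a web API around a transfer-learning model that predicts image tags. There are about 10 tags, and each tag has its own models that I need to load to predict it. Because loading Keras models takes a long time, I want to load all of them before any JSON request arrives. However, when a JSON request comes in and I use the preloaded models to predict the image tags, I get an error.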

My code:

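import time

import numpy as np
from flask import Flask, jsonify
from keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from keras.layers import Input, Lambda, GlobalAveragePooling2D, Dropout, Dense
from keras.models import Model
from keras.optimizers import Adam

app = Flask(__name__)

img_dim = (299, 299, 3)
img_size = (299, 299)
num_label = 2
industry = '100'

# lst_main_image (the list of tag names) and validation_batch (the preprocessed
# image batch) are defined elsewhere in the app and are not shown here.


def load_resnet_model():
    """Build InceptionResNetV2 (ImageNet weights) with a new 2-class softmax head."""
    print('begin to get model')
    input_tensor = Input(shape=img_dim)
    base_model = InceptionResNetV2(include_top=False, input_shape=img_dim, weights='imagenet')
    x = Lambda(preprocess_input, name='preprocessing')(input_tensor)
    x = base_model(x)
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.5)(x)
    x = Dense(num_label, activation='softmax', name='softmax')(x)
    model_image = Model(input_tensor, x)
    print('finish loading model')
    return model_image


# Load every model once at startup, before any request arrives.
# load_resnet_model must be defined before this call, otherwise it raises NameError.
model_image = load_resnet_model()
dict_model = {}  # tag name -> list of the 5 fold models

for t_image in lst_main_image:
    lst_model = []
    # NOTE: ml is the same object as model_image (no copy is made), so every entry
    # in lst_model, and in dict_model, ends up holding the weights loaded last.
    ml = model_image
    for i in range(1, 6):
        ml.load_weights('../model/{}/main_image/{}_aug_inception.fold_{}{}.hdf5'.format(industry, industry, i, t_image))
        ml.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
        lst_model.append(ml)
    print('finish image:', t_image)
    dict_model[t_image] = lst_model


@app.route("/api/", methods=["POST"])
def predict_tag():
    print('beginning to prediction')
    # model_image = load_resnet_model()  # loading here, inside the handler, avoids the error
    lst_result_image = []
    len_test = validation_batch.shape[0]
    n_fold = 5
    for t in lst_main_image:
        preds_test = np.zeros((len_test, num_label), dtype=float)
        print('t_image:', t)
        tag_i_time = time.time()
        # Average the softmax probabilities of the 5 fold models for this tag.
        for m in dict_model[t]:
            test_prob = m.predict(validation_batch)  # raises the ValueError shown below
            preds_test += test_prob
        print('each tag the times:', t, time.time() - tag_i_time)
        preds_test /= n_fold
        y_pred = preds_test.argmax(axis=-1)
        lst_result_image.append(y_pred.tolist())
        print('finish predict the tag:', t)
    # A Flask view must return a response.
    return jsonify(lst_result_image)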
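Separately from the graph error, the loading loop assigns `ml = model_image`, which binds a second name to the same Keras model instead of copying it, so every `load_weights` call overwrites the weights of that one object. The toy sketch below shows the effect; `ToyModel` and `build_model` are hypothetical stand-ins (not Keras), and `load_weights` here only records the path.

```python
class ToyModel:
    """Hypothetical stand-in for a Keras model; load_weights only records the path."""

    def __init__(self):
        self.weights = None

    def load_weights(self, path):
        self.weights = path


def build_model():
    # Hypothetical stand-in for load_resnet_model(): returns a brand-new object.
    return ToyModel()


paths = ["fold_%d.hdf5" % i for i in range(1, 6)]

# Pattern from the question: one model object reused for every fold.
model_image = build_model()
shared = []
for p in paths:
    ml = model_image  # a second name for the same object, not a copy
    ml.load_weights(p)
    shared.append(ml)

# Alternative: build a separate model per fold, then load that fold's weights.
separate = []
for p in paths:
    ml = build_model()
    ml.load_weights(p)
    separate.append(ml)

print([m.weights for m in shared])    # every entry is 'fold_5.hdf5'
print([m.weights for m in separate])  # 'fold_1.hdf5' through 'fold_5.hdf5'
```

With real Keras models, building one model per fold and tag gives independent predictions, at the cost of keeping a full InceptionResNetV2 network in memory for each of them.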

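However, when I move the model loading into the predict_tag function, this error does not occur. I want to load the models before any request arrives, though. The traceback: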

127.0.0.1 - - [09/Jul/2019 15:15:38] "POST /api/ HTTP/1.1" 500 -
Traceback (most recent call last):
  File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1997, in __call__
    return self.wsgi_app(environ, start_response)
  File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1985, in wsgi_app
    response = self.handle_exception(e)
  File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1540, in handle_exception
    reraise(exc_type, exc_value, tb)
  File "/anaconda3/lib/python3.6/site-packages/flask/_compat.py", line 33, in reraise
    raise value
  File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1982, in wsgi_app
    response = self.full_dispatch_request()
  File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1614, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1517, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "/anaconda3/lib/python3.6/site-packages/flask/_compat.py", line 33, in reraise
    raise value
  File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1612, in full_dispatch_request
    rv = self.dispatch_request()
  File "/anaconda3/lib/python3.6/site-packages/flask/app.py", line 1598, in dispatch_request
    return self.view_functions[rule.endpoint](**req.view_args)
  File "/Users/k.den/PycharmProjects/Banner-Tag-API/src/app792.py", line 188, in predict_tag
    test_prob = m.predict(validation_batch)
  File "/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1164, in predict
    self._make_predict_function()
  File "/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 554, in _make_predict_function
    **kwargs)
  File "/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2744, in function
    return Function(inputs, outputs, updates=updates, **kwargs)
  File "/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2546, in __init__
    with tf.control_dependencies(self.outputs):
  File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 5028, in control_dependencies
    return get_default_graph().control_dependencies(control_inputs)
  File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 4528, in control_dependencies
    c = self.as_graph_element(c)
  File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3478, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3557, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
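A likely explanation, assuming the app runs on TensorFlow 1.x: the TF1 documentation for `tf.Graph.as_default` notes that the default graph is a property of the current thread. Keras builds its predict function lazily on the first `predict()` call (the `_make_predict_function` frame above), and `tf.control_dependencies` then looks the softmax tensor up in `get_default_graph()`. If the request is handled in a different thread from the one that built the models at startup, that lookup hits the request thread's own default graph, which does not contain the tensor; loading inside `predict_tag` works because the model is then built in the request thread itself. The toy sketch below models this with `threading.local`; `ToyGraph`, `get_default_graph`, `as_default` and `run_in_new_thread` are hypothetical stand-ins, not TensorFlow's API. The commonly reported TF1 workaround follows the same pattern: keep `graph = tf.get_default_graph()` after loading the models and call `predict` inside `with graph.as_default():` in the handler (not tested against this app).

```python
import threading
from contextlib import contextmanager


class ToyGraph:
    """Hypothetical stand-in for a TF1 graph: it only records tensor names."""

    def __init__(self):
        self.tensors = set()

    def add_tensor(self, name):
        self.tensors.add(name)
        return name

    def as_graph_element(self, name):
        # Same kind of check that fails at the bottom of the traceback.
        if name not in self.tensors:
            raise ValueError("Tensor %s is not an element of this graph." % name)
        return name


class _GraphStack(threading.local):
    # threading.local: __init__ runs once per thread, so each thread gets its
    # own stack and its own fallback default graph.
    def __init__(self):
        self.stack = []
        self.fallback = ToyGraph()


_graphs = _GraphStack()


def get_default_graph():
    """Innermost graph entered with as_default(), else this thread's fallback."""
    return _graphs.stack[-1] if _graphs.stack else _graphs.fallback


@contextmanager
def as_default(graph):
    """Make `graph` the current thread's default graph inside the block."""
    _graphs.stack.append(graph)
    try:
        yield graph
    finally:
        _graphs.stack.pop()


def run_in_new_thread(fn):
    """Run fn in a fresh thread, like a threaded server handling a request."""
    out = {}

    def target():
        try:
            out["result"] = fn()
        except ValueError as exc:
            out["result"] = "error: %s" % exc

    worker = threading.Thread(target=target)
    worker.start()
    worker.join()
    return out["result"]


# Startup, main thread: "load the model" into the main thread's default graph.
graph = get_default_graph()
softmax = graph.add_tensor("softmax/Softmax:0")

# 1. Request thread uses the preloaded model as-is: the lookup fails.
plain = run_in_new_thread(lambda: get_default_graph().as_graph_element(softmax))


# 2. Request thread builds its own model (like loading inside predict_tag): works.
def build_inside_handler():
    t = get_default_graph().add_tensor("softmax/Softmax:0")
    return get_default_graph().as_graph_element(t)


inside = run_in_new_thread(build_inside_handler)


# 3. Request thread re-enters the graph captured at startup: works.
def reenter_graph():
    with as_default(graph):
        return get_default_graph().as_graph_element(softmax)


reentered = run_in_new_thread(reenter_graph)

print(plain)      # error: Tensor softmax/Softmax:0 is not an element of this graph.
print(inside)     # softmax/Softmax:0
print(reentered)  # softmax/Softmax:0
```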

0 answers

No answers yet.