如何运行gRPC客户端来测试tensorflow模型?

时间:2018-08-21 15:25:59

标签: python tensorflow machine-learning deep-learning

我已经在这里停留了一段时间。我创建了一个简单的神经网络,可以从视频游戏销售数据集中预测收入。在训练了2000个时期之后,我将模型导出为:

    model_builder = tf.saved_model.builder.SavedModelBuilder("exported_model/001")

    inputs = {
        'input': tf.saved_model.utils.build_tensor_info(X)
    }

    outputs = {
        'earnings': tf.saved_model.utils.build_tensor_info(prediction)
    }

    signature_def = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs,
        outputs=outputs,
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
    )

    model_builder.add_meta_graph_and_variables(
        session,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
        }
    )

    model_builder.save()

保存的模型如下:

enter image description here

然后我尝试通过运行以下命令,使用tensorflow model serve在本地提供模型:

tensorflow_model_server --port=4000 --model_name=mymodel --model_base_path=/home/suhail/tensorflow-stubs/exported_model

这以2018-08-21 20:31:56.623311: I tensorflow_serving/model_servers/main.cc:327] Running ModelServer at 0.0.0.0:4000 ...的日志启动了模型服务器

现在我正尝试通过使用gRPC客户端(如此处tfserving-python-predict-client

所述)来进行预测

这是我的预报。

import numpy as np
from predict_client.prod_client import ProdClient
import random

HOST = '0.0.0.0:4000'
# a good idea is to place this global variables in a shared file
MODEL_NAME = 'mymodel'
MODEL_VERSION = 1

client = ProdClient(HOST, MODEL_NAME, MODEL_VERSION)

req_data = [{'in_tensor_name': 'inputs', 'in_tensor_dtype': 'DT_FLOAT', 'data': np.random.random_integers(1,200, size=(1,12))}]

prediction = client.predict(req_data, request_timeout=10)

print(prediction)

但是在运行predict.py时出现错误消息:

<_Rendezvous of RPC that terminated with (StatusCode.INVALID_ARGUMENT,
   input tensor alias not found in signature: inputs. Inputs expected
       to be in the set {input}.)>

Prediction failed!
{}

这是什么错误?我在这里做什么错了?

这是导出模型的complete code for my training script

1 个答案:

答案 0 :(得分:0)

只是简单的检查:您是否尝试过将'in_tensor_name'替换为'input'而不是'inputs'?输入的名称似乎不正确。

顺便说一句,我前一阵子已经实现了一些TF Serving的客户端,您可以使用PredictRequest对象,它看起来更直观,恕我直言。