TensorFlow服务:如何获得预测,缺少ModelSpec

时间:2018-03-22 09:39:22

标签: python tensorflow tensorflow-serving

我创建了一个简单的多层感知器来测试TensorFlow服务。它只需要两个数字,并应预测这两个数字的总和。这是我写的代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import logging
import numpy as np
import tensorflow as tf

tf.app.flags.DEFINE_integer("model_version", 1, "version number of the model.")
tf.app.flags.DEFINE_string("work_dir", "/tmp", "Working directory.")
FLAGS = tf.app.flags.FLAGS

def main():

    # DATA PREPARATION
    # ================================================
    data_train = np.array(
        [[1, 3], [2, 5],
         [3, 1], [3, 3],
         [4, 2], [7, 1],
         [8, 1], [2, 2],
         [5, 1], [1, 7],
         [0, 1], [0, 5],
         [0, 7], [0, 8],
         [1, 1], [1, 2],
         [0, 0], [1, 8]]
    )

    labels_train = np.array(
        [4, 7, 4, 6, 6, 8, 9, 4, 6, 8, 1, 5, 7, 8, 2, 3, 0, 9]
    )

    # specify that all features have real-value data
    feature_columns = [tf.feature_column.numeric_column("x", shape=[2])]

    # CREATE CLASSIFIER
    # ================================================
    # build 3 layer dnn with 10, 20, 10 units respectively.
    classifier = tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[10, 20, 10],
        n_classes=10,
        model_dir="/tmp/addition_model"
    )

    # TRAINING PHASE
    # ================================================
    # define the training inputs
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": data_train},
        y=labels_train,
        num_epochs=None,
        shuffle=True
    )

    # train model
    classifier.train(input_fn=train_input_fn, steps=10000)

    # PREDICT NEW SAMPLES (LOCAL)
    # ================================================
    # new_samples = np.array(
    #     [[0, 6]],
    #     dtype=np.float32
    # )
    #
    # predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    #     x={"x": new_samples},
    #     num_epochs=1,
    #     shuffle=False
    # )
    #
    # predictions = list(classifier.predict(input_fn=predict_input_fn))
    # predicted_classes = [p["classes"] for p in predictions]
    #
    # print("Predictions: {}".format(predicted_classes))

    # BUILD AND SAVE MODEL
    # ================================================
    export_path_base = sys.argv[-1]
    export_path = os.path.join(
        tf.compat.as_bytes(export_path_base),
        tf.compat.as_bytes(str(FLAGS.model_version))
    )

    feature_spec = {"x": tf.FixedLenFeature([2], tf.float32)}

    def serving_input_receiver_fn():
        serialized_tf_example = tf.placeholder(dtype=tf.string,
                                           shape=[None],
                                           name="input_tensors")
        receiver_tensors = {"inputs": serialized_tf_example}
        features = tf.parse_example(serialized_tf_example, feature_spec)

        return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

    classifier.export_savedmodel(export_path, serving_input_receiver_fn)

if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    main()

这段代码完美无缺,它给了我预测。下一步,我导出模型并将其部署在TensorFlow服务器上(也可以)。我遇到了客户端代码的一些问题:

from grpc.beta import implementations
from tensorflow_serving.apis import prediction_service_pb2
from tensorflow_serving.apis import classification_pb2

host_port = "localhost:9000"
host, port = host_port.split(":")
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

request = classification_pb2.ClassificationRequest()
example = request.input.example_list.examples.add()
example.features.feature["x"].float_list.value.extend([2, 6])

result = stub.Classify(request, 10.0)  # 10 secs timeout
print(result)

在这种情况下,我想预测2 + 6的总和。但它一直告诉我在调用stub.Classify(请求,10.0)时缺少ModelSpec,但我不知道如何指定它(和我在文档中找不到任何内容):

grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.INVALID_ARGUMENT, details="Missing ModelSpec")

有人有想法吗? 我也不确定出口部分是否正确......我们非常感谢任何改进建议。

非常感谢您的支持

1 个答案:

答案 0 :(得分:1)

我刚刚找到解决方案,我有点盲目:D

有时它就像那样简单:

request.model_spec.name = "addition"