如何从本地运行SavedModel的推理?

时间:2019-06-05 01:08:39

标签: python tensorflow

我想在本地运行模型。我正在尝试从网络课程中训练和预测模型:

https://github.com/GoogleCloudPlatform/tensorflow-without-a-phd/blob/master/tensorflow-planespotting/trainer_yolo/main.py

已使用上述代码训练了模型。这是YOLO对象检测模型,用于检测使用tf.estimator建造的飞机。使用提供的代码成功完成了培训,但是我不知道如何推断模型。

import tensorflow as tf

# DATA
DATA = './samples/airplane_sample.png'

# Model: This directory contains saved_model.pb and variables
SAVED_MODEL_DIR = './1559196417/'

def decode_image():
    img_bytes = tf.read_file(DATA)
    decoded = tf.image.decode_image(img_bytes, channels=3)
    return tf.cast(decoded, dtype=tf.uint8)

def main1():
    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], SAVED_MODEL_DIR)
        img = decode_image()
        result = sess.run(['classes'], feed_dict={'input': img})
        print(result)

def main2():
    model = tf.contrib.predictor.from_saved_model(SAVED_MODEL_DIR)
    pred = model({'image_bytes': [decode_image()], 'square_size': [tf.placeholder(tf.int32)]})
    print(pred)

if __name__ == "__main__":
    main2()

以上是我编写的代码,但无效。甚至我都不知道有什么问题。输入类型不正确? API不当?你能给我一些建议吗?

2 个答案:

答案 0 :(得分:0)

也许这会起作用:

import tensorflow as tf
import cv2

with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile('./1559196417/saved_model.pb', 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        image = cv2.imread('./samples/airplane_sample.png')

        rgb_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        rgb_img_expanded = np.expand_dims(rgb_img, axis=0)

        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        classes = detection_graph.get_tensor_by_name('classes:0')
        result = sess.run([classes],feed_dict={image_tensor: rgb_img_expanded})

答案 1 :(得分:0)

首先在python之外的终端中运行saved_model_cli show --all --dir SAVED_MODEL_DIR,以检查保存的模型并检查其是否具有正确的标签,输入和输出。从那里开始,需要花费一些精力才能从API中获取必要的信息。

def extract_tensors(signature_def, graph):
    output = dict()

    for key in signature_def:
        value = signature_def[key]

        if isinstance(value, tf.TensorInfo):
            output[key] = graph.get_tensor_by_name(value.name)

    return output

def extract_tags(signature_def, graph):
    output = dict()

    for key in signature_def:
        output[key] = dict()
        output[key]['inputs'] = extract_tensors(
            signature_def[key].inputs, graph)
        output[key]['outputs'] = extract_tensors(
            signature_def[key].outputs, graph)

    return output

with tf.Session(graph=tf.Graph()) as session:
    serve = tf.saved_model.load(
        session, tags=['serve'], export_dir=SAVED_MODEL_DIR)

    tags = extract_tags(serve.signature_def, session.graph)
    model = tags['serving_default']

您可以从那里尝试print(model['inputs'], model['outputs'])来查看导出了哪些输入和输出,如果它们与saved_model_cli一致,如果您需要另一个标签,只需将其替换为serving_default