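Using tensorflow.contrib.predictor with a locally saved model

Time: 2018-12-18 19:03:11

Tags: python-3.x tensorflow machine-learning deep-learning tensorflow-serving

How can I use tensorflow.contrib.predictor to run predictions locally on the model saved by the code below?

I want to make sure everything works before I run it in a Docker container.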

import tensorflow as tf
import numpy as np
import os

rng = np.random

with tf.Session() as sess:
    serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
    feature_configs = {'X': tf.FixedLenFeature(shape=[], dtype=tf.float32), }
    tf_example = tf.parse_example(serialized_tf_example, feature_configs)
    input_tensor = tf.identity(tf_example['X'][0], name='X')

    W = tf.Variable(rng.randn(), name="weight")
    b = tf.Variable(rng.randn(), name="bias")
    pred = tf.add(tf.multiply(input_tensor, W), b, name="pred")

    saver = tf.train.Saver()
    saver.restore(sess, os.path.join('ModelOutput', "model.ckpt-1001"))
    print("Model","restored.")

    builder = tf.saved_model.builder.SavedModelBuilder('Ser_Model/1/')
    tensor_info_input = tf.saved_model.utils.build_tensor_info(input_tensor)
    tensor_info_output = tf.saved_model.utils.build_tensor_info(pred)

    prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs={'input_x': tensor_info_input},
                outputs={'output_y': tensor_info_output},
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

    builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict_output':
                    prediction_signature,
            })

    builder.save(as_text=True)
    print('Done exporting!')
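One way to narrow the problem down, before involving predictor at all, is to load the export back with the plain SavedModel loader and run the predict_output signature through the tensor names it records. The sketch below is a hypothetical, self-contained check: instead of restoring model.ckpt-1001, export_toy_model rebuilds the same graph and signature with made-up fixed weights (W = 2.0, b = 1.0) and writes it to a temporary directory, and run_signature then reloads it. It assumes TensorFlow 1.13 or later (for tf.compat.v1); both function names are illustrative, not part of any library.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1


def export_toy_model(export_dir, w=2.0, b=1.0):
    """Hypothetical stand-in for the export script: pred = X * w + b.

    Fixed weights replace the restored checkpoint; the signature layout
    ('predict_output', 'input_x', 'output_y') matches the question.
    """
    graph = tf.Graph()
    with graph.as_default(), tf1.Session(graph=graph) as sess:
        serialized = tf1.placeholder(tf.string, name='tf_example')
        features = tf1.parse_example(
            serialized, {'X': tf1.FixedLenFeature(shape=[], dtype=tf.float32)})
        # As in the question, only element [0] of the batch is used, so X is a scalar.
        input_tensor = tf.identity(features['X'][0], name='X')
        weight = tf1.Variable(w, name='weight')
        bias = tf1.Variable(b, name='bias')
        pred = tf.add(tf.multiply(input_tensor, weight), bias, name='pred')
        sess.run(tf1.global_variables_initializer())

        builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
        signature = tf1.saved_model.signature_def_utils.build_signature_def(
            inputs={'input_x': tf1.saved_model.utils.build_tensor_info(input_tensor)},
            outputs={'output_y': tf1.saved_model.utils.build_tensor_info(pred)},
            method_name=tf1.saved_model.signature_constants.PREDICT_METHOD_NAME)
        builder.add_meta_graph_and_variables(
            sess, [tf1.saved_model.tag_constants.SERVING],
            signature_def_map={'predict_output': signature})
        builder.save(as_text=True)


def run_signature(export_dir, x, signature_key='predict_output'):
    """Load the SavedModel and evaluate one signature for a scalar input x."""
    graph = tf.Graph()
    with tf1.Session(graph=graph) as sess:
        meta_graph = tf1.saved_model.loader.load(
            sess, [tf1.saved_model.tag_constants.SERVING], export_dir)
        signature = meta_graph.signature_def[signature_key]
        # Signature entries store tensor names, here 'X:0' and 'pred:0'.
        input_name = signature.inputs['input_x'].name
        output_name = signature.outputs['output_y'].name
        return sess.run(output_name, feed_dict={input_name: x})


if __name__ == '__main__':
    # SavedModelBuilder refuses to write into an existing directory,
    # so export into a fresh '1' subdirectory of a temp dir.
    export_dir = os.path.join(tempfile.mkdtemp(), '1')
    export_toy_model(export_dir)
    print(run_signature(export_dir, 10.0))  # 2.0 * 10.0 + 1.0
```

Pointing run_signature at Ser_Model/1/ instead of the toy export checks the real model; if that works but the predictor call does not, the problem lies in how predictor is called rather than in the export.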

Below is the snippet I use to check it by feeding an input, but I am doing something wrong and cannot work out what the problem is:

from tensorflow.contrib import predictor

predict_fn = predictor.from_saved_model(export_dir='Ser_Model/1/',
                                        signature_def_key='predict_output',
                                        signature_def='predict_output',
                                        input_names='{input_x: [10]}',
                                        output_names='output_y')

0 answers