如何使用 tensorflow.contrib.predictor 在本地对下面导出的 SavedModel 进行预测?
在通过 Docker 容器部署这个模型之前,我想先确认一切正常。下面是导出模型的代码:
import tensorflow as tf
import numpy as np
import os
# Export a trained one-feature linear model (pred = X * W + b) from the
# checkpoint ModelOutput/model.ckpt-1001 as a SavedModel in Ser_Model/1/,
# so it can be loaded locally with tensorflow.contrib.predictor or served
# (e.g. by TensorFlow Serving) under the signature key 'predict_output'.
rng = np.random

with tf.Session() as sess:
    # Input side of the graph: serialized tf.train.Example protos, each
    # carrying a single float feature named "X".
    serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
    feature_configs = {'X': tf.FixedLenFeature(shape=[], dtype=tf.float32)}
    tf_example = tf.parse_example(serialized_tf_example, feature_configs)
    # Only element [0] of the parsed feature is used, so the exported input
    # tensor "X:0" is a single value: callers must feed a scalar, not a list.
    input_tensor = tf.identity(tf_example['X'][0], name='X')

    # The random initial values are irrelevant: saver.restore() below
    # overwrites them with the trained values from the checkpoint.
    W = tf.Variable(rng.randn(), name="weight")
    b = tf.Variable(rng.randn(), name="bias")
    pred = tf.add(tf.multiply(input_tensor, W), b, name="pred")

    # A Saver with no var_list matches variables by name, so "weight" and
    # "bias" must be the names used when the checkpoint was written.
    saver = tf.train.Saver()
    saver.restore(sess, os.path.join('ModelOutput', "model.ckpt-1001"))
    print("Model", "restored.")

    # SavedModelBuilder raises if the export directory already exists with
    # content; bump the version directory (or remove it) before re-running.
    builder = tf.saved_model.builder.SavedModelBuilder('Ser_Model/1/')

    # Public aliases of the signature: clients feed 'input_x' (-> X:0) and
    # read 'output_y' (-> pred:0).
    tensor_info_input = tf.saved_model.utils.build_tensor_info(input_tensor)
    tensor_info_output = tf.saved_model.utils.build_tensor_info(pred)
    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'input_x': tensor_info_input},
            outputs={'output_y': tensor_info_output},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

    # Tag the graph with SERVING ("serve") and register the signature under
    # the key 'predict_output'; loaders select it by that key.
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'predict_output': prediction_signature,
        })
    # as_text=True writes a human-readable saved_model.pbtxt instead of
    # the binary saved_model.pb.
    builder.save(as_text=True)
    print('Done exporting!')
下面是我用来在本地测试的代码:我尝试向模型提供输入来检查预测结果,但某处写错了,一直找不到问题出在哪里。
from tensorflow.contrib import predictor
# Load the exported SavedModel for local prediction.
#
# from_saved_model() accepts at most ONE way of choosing inputs/outputs:
# a signature_def_key, OR a SignatureDef proto (signature_def), OR an
# (input_names, output_names) pair of {alias: tensor_name} dicts. Passing
# several of them raises ValueError, and strings are not valid values for
# signature_def / input_names. Select the signature by its key only; the
# default tag ("serve") matches tag_constants.SERVING used at export time.
predict_fn = predictor.from_saved_model(
    export_dir='Ser_Model/1/',
    signature_def_key='predict_output',
)

# Input VALUES are supplied when calling the predictor, keyed by the
# signature's input alias. The exported input tensor was built from a single
# element (tf_example['X'][0]), so it is a scalar: feed 10.0, not [10].
# The call returns a dict keyed by the signature's output aliases.
predictions = predict_fn({'input_x': 10.0})
print(predictions['output_y'])