How to parse CSV examples for inference with serving_input_receiver_fn in TensorFlow

Posted: 2018-08-24 09:24:19

Tags: python tensorflow

I built a tf.estimator.Estimator model that takes 21 float features in total, and I am now trying to get predictions from the restored SavedModel:

def serving_input_receiver_fn():
    """build the serving inputs."""

    example_bytestring = tf.placeholder(
        shape=[None],
        dtype=tf.string,  
        name='input_example_tensor')

    features = tf.parse_example(
        serialized=example_bytestring,
        features=tf.feature_column.make_parse_example_spec(feature_columns))

    return tf.estimator.export.ServingInputReceiver(
        features, {'samples': example_bytestring})


# export the SavedModel
export_dir = lr_classifier_test.export_savedmodel(
    export_dir_base=FLAGS.export_dir_base,
    serving_input_receiver_fn=serving_input_receiver_fn)

text_example = [0.8461,4.74107,7235,-3.5643,2.0,127,6336.44,118.653,0.002786,0.00151,-0.00063594,-0.0005746,1.00103,1.00118,0.56842,0.087231,136.32418,39534.0136,0.06375,1490,5.76061]

# restore
g1 = tf.Graph()

with tf.Session(graph = g1) as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess=sess, 
        tags=[tf.saved_model.tag_constants.SERVING], 
        export_dir='tf_exports/1535096104')

    graph_def_nodes = [n.name for n in meta_graph_def.graph_def.node]
    input_example = sess.graph.get_tensor_by_name('input_example_tensor:0')
    W = sess.graph.get_tensor_by_name('logits/kernel:0')
    proba = sess.graph.get_tensor_by_name('prediction/probabilities:0')
    # print(sess.run(W))
    print(sess.run(proba, {input_example: text_example}))

But it raises an error:

InvalidArgumentError: Could not parse example input, value: '0.8461,4.74107,7235,-3.5643,2.0,127,6336.44,118.653,0.002786,0.00151,-0.00063594,-0.0005746,1.00103,1.00118,0.56842,0.087231,136.32418,39534.0136,0.06375,1490,5.76061' [[Node: ParseExample/ParseExample = ParseExample[Ndense=22, Nsparse=0, Tdense=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], 
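The error comes from `tf.parse_example` inside `serving_input_receiver_fn`: that op expects every string fed to `input_example_tensor` to be a serialized `tf.train.Example` protocol buffer whose keys are the feature names, while the value shown is a raw CSV line. A minimal sketch of building such a string from the 21 floats follows; the names `f0` … `f20`, `FEATURE_NAMES` and `to_serialized_example` are hypothetical and must be replaced by the actual keys of `feature_columns`.

```python
import tensorflow as tf

# Hypothetical feature names: they must match the keys of feature_columns.
FEATURE_NAMES = ['f{}'.format(i) for i in range(21)]


def to_serialized_example(values, feature_names=FEATURE_NAMES):
    """Pack one row of floats into a serialized tf.train.Example (bytes)."""
    if len(values) != len(feature_names):
        raise ValueError('expected %d values, got %d'
                         % (len(feature_names), len(values)))
    # One single-element FloatList per named feature.
    feature = {
        name: tf.train.Feature(float_list=tf.train.FloatList(value=[float(v)]))
        for name, v in zip(feature_names, values)
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()
```

The placeholder has `shape=[None]`, so the feed would be a batch such as `{input_example: [to_serialized_example(text_example)]}`. Note also that the failing node reports `Ndense=22` while the row has 21 values, so the parse spec apparently contains one more dense column (perhaps the label, though that is only a guess); a dense feature declared without a `default_value` must be present in every Example, so that key would need a value as well.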

How do I parse an example containing 21 float features and make a prediction on it?
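If clients should send raw CSV lines instead of serialized Examples, one option is to decode the CSV inside the serving function and re-export the model with it. A sketch, assuming `tf.io.decode_csv` and `tf.compat.v1.placeholder` are available (TF 1.13+ or 2.x; older 1.x releases have `tf.decode_csv` and `tf.placeholder`); `FEATURE_NAMES`, `csv_to_features`, `csv_serving_input_receiver_fn` and the receiver key `csv_rows` are hypothetical, and the names must cover every key the model's feature columns expect.

```python
import tensorflow as tf

NUM_FEATURES = 21
# Hypothetical feature names: they must match the keys of feature_columns.
FEATURE_NAMES = ['f{}'.format(i) for i in range(NUM_FEATURES)]


def csv_to_features(csv_rows):
    """Split a batch of CSV strings into a {name: float32 tensor} dict."""
    # One float default per column; an empty field becomes 0.0.
    columns = tf.io.decode_csv(csv_rows,
                               record_defaults=[[0.0]] * NUM_FEATURES)
    return dict(zip(FEATURE_NAMES, columns))


def csv_serving_input_receiver_fn():
    """Serving input fn that accepts raw CSV lines (graph mode)."""
    csv_rows = tf.compat.v1.placeholder(dtype=tf.string, shape=[None],
                                        name='csv_rows')
    return tf.estimator.export.ServingInputReceiver(
        csv_to_features(csv_rows), {'csv_rows': csv_rows})
```

After exporting with `serving_input_receiver_fn=csv_serving_input_receiver_fn`, the restored graph would be fed something like `{'csv_rows:0': [','.join(str(v) for v in text_example)]}`. Defaulting empty fields to 0.0 is a convenience choice; it hides missing values rather than raising an error.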

0 answers:

No answers yet