I trained a linear classifier using a canned estimator. It has 6 input features (tf.feature_column.numeric_column) and performs binary classification.
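I exported the model with estimator.export_savedmodel, but I still don't fully understand how to set up the serving_input_receiver_fn function correctly so that I can request predictions for new examples (e.g. a 6x1 array).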
So far, I have defined the following function:
import tensorflow as tf  # TensorFlow 1.x API

def serving_input_receiver_fn():
    serialized_tf_example = tf.placeholder(dtype=tf.string,
                                           shape=6)
    receiver_tensors = {'input_feature': serialized_tf_example}
    # feature_columns: the 6 numeric columns the estimator was trained with
    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
    features = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
Running export_savedmodel appears to succeed, and I get this output:
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['classification', 'serving_default']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from graphs/linear\model.ckpt-3000
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: export/linear_trial\temp-b'1536667847'\saved_model.pb
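The problem appears when I request a prediction for a new example (e.g. data = [1,0,0.5,1,0.5,1]) over gRPC. Creating the channel, stub and credentials was straightforward, but I have not managed to get the request itself to work.
Here is the code: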
data_str = [str(i) for i in data]
request = predict_pb2.PredictRequest()
request.model_spec.name = MODEL_NAME
request.model_spec.signature_name = 'predict'
request.inputs['input_feature'].CopyFrom( tf.contrib.util.make_tensor_proto(data_str, shape=[6]))
stub.Predict(request, 120)
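For comparison: the request above sends six number strings, whereas tf.parse_example in the serving function treats every element of that string tensor as one serialized tf.train.Example. With TensorFlow available, the usual way to build such an element is tf.train.Example(features=tf.train.Features(feature={...})).SerializeToString(). The minimal sketch below instead hand-encodes the protobuf wire format in plain Python, only to show what one such string contains. It assumes FloatList features only, and the feature names f0 to f5 are hypothetical placeholders for whatever keys feature_columns actually uses. The client would then send one string per example; with the placeholder declared as shape=6, exactly six serialized examples would be expected, which is why such placeholders are commonly declared with shape=[None].

```python
import struct


def _varint(n: int) -> bytes:
    """Encode a non-negative int as a protobuf base-128 varint."""
    out = bytearray()
    while True:
        byte = n & 0x7F
        n >>= 7
        if n:
            out.append(byte | 0x80)
        else:
            out.append(byte)
            return bytes(out)


def _len_delimited(field: int, payload: bytes) -> bytes:
    """A wire-type-2 (length-delimited) field: tag, length, payload."""
    return _varint((field << 3) | 2) + _varint(len(payload)) + payload


def serialize_float_example(features: dict) -> bytes:
    """Sketch: serialize {name: [float, ...]} as a tf.train.Example (FloatList only)."""
    entries = b""
    for name, values in features.items():
        packed = b"".join(struct.pack("<f", float(v)) for v in values)
        float_list = _len_delimited(1, packed)      # FloatList.value (packed floats)
        feature = _len_delimited(2, float_list)     # Feature.float_list
        # One map entry of Features.feature: key (field 1), value (field 2)
        entry = _len_delimited(1, name.encode("utf-8")) + _len_delimited(2, feature)
        entries += _len_delimited(1, entry)         # Features.feature
    return _len_delimited(1, entries)               # Example.features


# Hypothetical feature names; the real ones come from feature_columns.
FEATURE_NAMES = ["f0", "f1", "f2", "f3", "f4", "f5"]
data = [1, 0, 0.5, 1, 0.5, 1]
serialized = serialize_float_example({n: [v] for n, v in zip(FEATURE_NAMES, data)})

# The request would then carry one string per example, e.g. (TF 1.x):
# request.inputs['input_feature'].CopyFrom(
#     tf.contrib.util.make_tensor_proto([serialized], shape=[1]))
```

Each feature is stored as a one-element FloatList, matching the dense_shapes=[[1], ...] that appear in the ParseExample node of the error below.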
And the output:
File "[...]\client_app\venv_trial\lib\site-packages\grpc\_channel.py", line 448, in _end_unary_response_blocking
raise _Rendezvous(state, None, None, deadline)
grpc._channel._Rendezvous: <_Rendezvous of RPC that terminated with:
status = StatusCode.INTERNAL
details = "invalid header field value "Could not parse example input, value:
'1'\n\t [[Node: ParseExample/ParseExample = ParseExample[Ndense=6,
Nsparse=0, Tdense=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT,
DT_FLOAT, DT_FLOAT], _output_shapes=[[6,1], [6,1], [6,1],
[6,1], [6,1], [6,1]]
dense_shapes=[[1], [1], [1], [1], [1],
[1]], sparse_types=[],
_device=\"/job:localhost/replica:0/task:0/cpu:0\"](_arg_Placeholder_0_0,
ParseExample/ParseExample/names, ParseExample/ParseExample/dense_keys_0, ParseExample/ParseExample/dense_keys_1, ParseExample/ParseExample/dense_keys_2, ParseExample/ParseExample/dense_keys_3, ParseExample/ParseExample/dense_keys_4,
ParseExample/ParseExample/dense_keys_5,
ParseExample/Const,
ParseExample/Const_1, ParseExample/Const_2, ParseExample/Const_3, ParseExample/Const_4,
ParseExample/Const_5)]]""
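From what I can tell from the output, the tensor_proto is in the wrong format, so parsing of the model's features fails entirely. Any help would be much appreciated.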
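One possible reading of the error message: each element of the string tensor is handed to ParseExample as a serialized tf.train.Example, so the first element, the one-byte string '1', is parsed as protobuf wire data. Byte 0x31 decodes as a tag for field 6 with wire type 1 (a fixed 64-bit value), which would need eight more bytes that are not there. The sketch below, in plain Python without TensorFlow, decodes only that first tag byte. It assumes a single-byte tag and illustrates the wire-format arithmetic; it is not a model of TensorFlow's actual parser.

```python
def describe_first_tag(serialized: bytes) -> tuple:
    """Decode the first protobuf tag byte: (field_number, wire_type, truncated)."""
    tag = serialized[0]
    if tag & 0x80:
        raise ValueError("multi-byte tag; not handled in this sketch")
    field_number = tag >> 3
    wire_type = tag & 0x07
    # Fixed-size payloads: wire type 1 is 64-bit, wire type 5 is 32-bit.
    fixed_size = {1: 8, 5: 4}.get(wire_type)
    truncated = fixed_size is not None and len(serialized) - 1 < fixed_size
    return field_number, wire_type, truncated


print(describe_first_tag(b"1"))  # (6, 1, True)
```

By contrast, a real serialized tf.train.Example starts with byte 0x0A, i.e. field 1 (Example.features) with wire type 2 (length-delimited).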