我有一个Keras CNN模型,我正在使用该模型使用加速度计数据来预测活动。我正在尝试使用TF Serving tutorial为TensorFlow Serving保存此模型。我用以下代码保存了模型
sess = tf.Session()
K.set_session(sess)
K.set_learning_phase(0)
x = model.input
y = model.output
prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"inputs": x}, {"prediction":y})
valid_prediction_signature = tf.saved_model.signature_def_utils.is_valid_signature(prediction_signature)
if(valid_prediction_signature == False):
raise ValueError("Error: Prediction signature not valid!")
export_path = 'models/cnn_v1/1'
# export_path is a directory in which the model will be created
builder = saved_model_builder.SavedModelBuilder(export_path)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
# Initialize global variables and the model
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init_op)
# Add the meta_graph and the variables to the builder
builder.add_meta_graph_and_variables(
sess, [tag_constants.SERVING],
signature_def_map={
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
prediction_signature,
},
legacy_init_op=legacy_init_op)
# save the graph
builder.save()
与输出
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: models/cnn_v1/1/saved_model.pb
b'cnn_v1/saved_model.pb'
这是我的grpc客户代码
tf.app.flags.DEFINE_string('server', 'localhost:8500', 'PredictionService host:port')
FLAGS = tf.app.flags.FLAGS
def main(_):
inputs = np.array([[-0.8099115393684377, 1.114737069020815, 0.11360972373296188],[-0.6940617608960894, 1.2277026768965522, 0.07879808254788291], [-0.645282906802469, 1.1942313856741114, -0.05174557189616266],[-0.8099115393684377, 1.1314727146320354, 0.043986441362803934], [-0.8099115393684377, 1.2779096137302133, 0.043986441362803934], [-0.7184511879428995, 1.2444383225077726, -0.025636841007354014], [-0.6940617608960894, 1.2109670312853318, -0.1561804954513996], [-0.7428406149897097, 1.164944005854476, -0.12136885426632062], [-0.79161946908333, 1.164944005854476, -0.05174557189616266], [-0.79161946908333, 1.2444383225077726, -0.05174557189616266], [-0.79161946908333, 1.2277026768965522, -0.12136885426632062], [-0.7672300420365199, 1.2109670312853318, -0.1561804954513996], [-0.8586903934620581, 1.1314727146320354, -0.22580377782155753], [-0.834300966415248, 1.2444383225077726, -0.1561804954513996], [-0.8099115393684377, 1.3783234873975356, -0.08655721308124163], [-0.6940617608960894, 1.3448521961750948, -0.19099213663647857], [-0.7184511879428995, 1.1816796514656964, -0.1561804954513996], [-0.9562481016492987, 1.2611739681189928, -0.19099213663647857], [-0.9562481016492987, 1.395059133008756, -0.2606154190066365], [-0.8099115393684377, 1.4243465128283916, -0.19099213663647857], [-0.9562481016492987, 1.3783234873975356, -0.2606154190066365], [-1.0538058098365388, 1.3448521961750948, -0.19099213663647857], [-0.980637528696109, 1.491289095273273, -0.22580377782155753], [-0.8830798205088682, 1.541496032106934, -0.1561804954513996], [-0.9074692475556785, 1.4745534496620525, -0.1561804954513996], [-0.980637528696109, 1.4076108672171712, -0.19099213663647857], [-0.9318586746024886, 1.441082158439612, -0.2606154190066365], [-0.9562481016492987, 1.395059133008756, -0.1561804954513996], [-1.0781952368833492, 1.4578178040508325, -0.12136885426632062], [-1.0050269557429192, 1.6084386145518155, -0.08655721308124163], [-0.980637528696109, 1.6209903487602308, -0.33023870137679445], [-1.0294163827897287, 1.4578178040508325, -0.33023870137679445], [-1.3403815776365586, 1.4745534496620525, -0.5304056381909984], [-1.3647710046833688, 1.5582316777181544, -0.5652172793760774], [-1.3159921505897483, 1.6084386145518155, -0.5304056381909984], [-1.1940450153556976, 1.671197285593892, -0.4955939970059194], [-1.1696555883088873, 1.7214042224275532, -0.4346736249320314], [-0.8830798205088682, 1.671197285593892, -0.5652172793760774], [-1.0538058098365388, 1.6084386145518155, -0.46948526611711033], [-1.3647710046833688, 1.7046685768163325, -0.5652172793760774], [-1.4562313561089064, 1.788346804872434, -0.4955939970059194], [-1.413549858776989, 1.788346804872434, -0.6696522029313142], [-1.1025846639301593, 1.8050824504836547, -0.4346736249320314], [-1.1025846639301593, 1.754875513649994, -0.5652172793760774], [-1.0050269557429192, 1.8678411215257313, -0.7740871264865508], [-0.9562481016492987, 1.7381398680387732, -1.1831239104112286], [-1.3403815776365586, 1.591702968940595, -1.3484792060403534], [-1.389160431730179, 1.771611159261214, -1.3223704751515444], [-1.1696555883088873, 1.8343698303032905, -1.2875588339664654], [-1.1696555883088873, 1.754875513649994, -1.2179355515963075], [-1.2184344424025078, 1.8050824504836547, -1.2875588339664654], [-1.389160431730179, 1.8050824504836547, -0.9742540633007547], [-1.4562313561089064, 1.771611159261214, -0.9046307809305971], [-1.4318419290620963, 1.771611159261214, -1.0438773456709127], [-1.2916027235429381, 1.9515193495818328, -1.0438773456709127], [-1.2184344424025078, 1.9682549951930528, -0.8785220500417877], [-1.389160431730179, 1.934783703970612, -0.9046307809305971], [-1.3403815776365586, 1.934783703970612, -0.9046307809305971], [-1.3403815776365586, 1.8845767671369513, -1.1135006280410706], [-1.3647710046833688, 1.8218180960948747, -1.1483122692261496], [-1.413549858776989, 1.7381398680387732, -1.2179355515963075], [-1.3647710046833688, 1.918048058359392, -1.2179355515963075], [-1.3403815776365586, 1.8845767671369513, -1.1483122692261496], [-1.5050102102025267, 1.918048058359392, -1.0090657044858338], [-1.0538058098365388, 1.8343698303032905, -1.0090657044858338], [-0.7428406149897097, 1.8511054759145105, -1.1483122692261496], [-1.0294163827897287, 1.6544616399826715, -1.1135006280410706], [-1.2184344424025078, 1.541496032106934, -0.8785220500417877], [-1.2184344424025078, 1.7046685768163325, -0.9394424221156761], [-0.9562481016492987, 1.6544616399826715, -0.9046307809305971], [-0.980637528696109, 1.7214042224275532, -0.7392754853014718], [-1.267213296496128, 1.441082158439612, -0.5652172793760774], [-1.1452661612620771, 1.6377259943714513, -0.6696522029313142], [-0.7672300420365199, 1.7046685768163325, -0.5304056381909984], [-0.645282906802469, 1.591702968940595, -0.5652172793760774], [-1.0781952368833492, 1.2779096137302133, -0.5652172793760774], [-1.3647710046833688, 1.4578178040508325, -0.6000289205611563], [-1.2184344424025078, 1.6084386145518155, -0.6696522029313142], [-0.8830798205088682, 1.491289095273273, -0.6000289205611563], [-0.8099115393684377, 1.5080247408844931, -0.5304056381909984], [-0.8099115393684377, 1.4076108672171712, -0.46948526611711033], [-0.79161946908333, 1.5749673233293746, -0.2954270601917155], [-0.79161946908333, 1.6084386145518155, -0.19099213663647857], [-0.8099115393684377, 1.3448521961750948, -0.2606154190066365], [-0.7672300420365199, 1.3783234873975356, -0.4955939970059194], [-1.120876734215267, 1.2611739681189928, -0.1561804954513996], [-1.0050269557429192, 1.4076108672171712, -0.4955939970059194], [-0.980637528696109, 1.4243465128283916, -0.2954270601917155], [-0.79161946908333, 1.2444383225077726, -0.6696522029313142], [-0.6940617608960894, 1.3281165505638746, -0.33023870137679445], [-0.79161946908333, 1.3281165505638746, -0.5652172793760774], [-0.9074692475556785, 1.2444383225077726, -0.19099213663647857], [-1.1452661612620771, 1.2779096137302133, -0.05174557189616266], [-0.8099115393684377, 1.5247603864957135, -0.19099213663647857], [-0.6696723338492792, 1.4578178040508325, -0.025636841007354014], [-0.5965040527088487, 1.2109670312853318, -0.19099213663647857], [-0.4989463445216086, 1.164944005854476, -0.1561804954513996], [-0.4074859930960703, 1.5080247408844931, 0.11360972373296188], [-0.2611494308152093, 1.4076108672171712, 0.3572912120285147], [-0.1209102252960507, 1.1942313856741114, 0.2528562884732778], [-0.02335251710881002, 1.0980014234095945, 0.18323300610311982], [-0.2611494308152093, 1.0143231953534928, 0.00917480017772496], [-0.5965040527088487, 1.114737069020815, -0.46948526611711033], [-0.645282906802469, 0.8846219418665351, -0.4346736249320314], [-0.4989463445216086, 0.7549206883795772, -0.8088987676716297], [0.23883382364439887, 0.9013575874777555, -0.7392754853014718], [0.19005496955077855, 0.7549206883795772, -1.1135006280410706], [-0.07213137120243035, 0.6210355234898145, -0.6348405617462353], [-0.18798114967477933, 0.39092039633553427, -0.5652172793760774], [-0.047741944155620726, 0.26121914284857667, -0.4346736249320314], [0.1229840451720499, 0.32816172529345816, -0.12136885426632062], [0.07420519107842957, 0.3114260796822378, 0.00917480017772496], [-0.1209102252960507, 0.12733397795881365, -0.05174557189616266], [-0.2611494308152093, 0.014368370083076223, -0.12136885426632062], [-0.28553885786202, 0.19427656040369515, -0.3650503425618734], [-0.047741944155620726, 0.32816172529345816, -0.8785220500417877], [0.19005496955077855, 0.2277478516261359, -0.6348405617462353], [0.45224131030398745, 0.14406962357003403, -0.3998619837469524], [0.3851703859252599, 0.26121914284857667, 0.043986441362803934], [0.43394924001888024, 0.2946904340710174, -0.12136885426632062], [0.4095598129720695, 0.2277478516261359, -0.19099213663647857], [0.3607809588784492, 0.1608052691812544, -0.2606154190066365], [0.1412761154571582, 0.11478224375039847, -0.3998619837469524], [0.16566554250396784, 0.04783966130551697, -0.3998619837469524], [0.19005496955077855, 0.031104015694296598, -0.3998619837469524]])
inputs = inputs.astype(float)
inputs = inputs.reshape((1,125,3))
print (inputs.dtype)
if inputs.shape != (1,125,3):
return
channel = grpc.insecure_channel(FLAGS.server)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
request = predict_pb2.PredictRequest()
request.model_spec.name = 'cnn_v1'
request.model_spec.signature_name = 'serving_default'
request.inputs['inputs'].CopyFrom(tf.contrib.util.make_tensor_proto(inputs, shape=[1,125,3]))
result = stub.Predict(request, 30.0)
print (result)
if __name__ == '__main__':
tf.app.run()
基本上,我输入一个(None,125,3)形状的numpy数组,该数组由x,y,z轴上的值组成,并以25Hz捕获5秒钟。但是运行grpc客户端时出现以下错误
File "activity_tf_client.py", line 34, in <module>
tf.app.run()
File "/anaconda3/envs/tf2/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "activity_tf_client.py", line 30, in main
result = stub.Predict(request, 30.0)
File "/anaconda3/envs/tf2/lib/python2.7/site-packages/grpc/_channel.py", line 532, in __call__
return _end_unary_response_blocking(state, call, False, None)
File "/anaconda3/envs/tf2/lib/python2.7/site-packages/grpc/_channel.py", line 466, in _end_unary_response_blocking
raise _Rendezvous(state, None, None, deadline)
grpc._channel._Rendezvous: <_Rendezvous of RPC that terminated with:
status = StatusCode.INVALID_ARGUMENT
details = "Expects arg[0] to be float but double is provided"
debug_error_string = "{"created":"@1536737041.281428000","description":"Error received from peer","file":"src/core/lib/surface/call.cc","file_line":1099,"grpc_message":"Expects arg[0] to be float but double is provided","grpc_status":3}"
我怎样才能做到这一点?
答案 0 :(得分:1)
感谢@sdcbr,您的解决方案有效,还有另一种方法。使用make_tensor_proto
时,将数据类型输入为dtype=types_pb2.DT_FLOAT
。
request.inputs['inputs'].CopyFrom(tf.contrib.util.make_tensor_proto(inputs, dtype=types_pb2.DT_FLOAT, shape=[1,125,3]))
两种解决方案都可以解决此错误。