来自TensorFlow服务器的GRPC响应错误

时间:2019-05-10 16:02:14

标签: python tensorflow object-detection grpc

我正在尝试提供tensorflow对象检测api,但我的grpc客户端和Rest客户端针对同一图像给出不同的结果。我不知道为什么GRPC响应为:

  

[2.8745510860517243e-08,2.476179972177306e-08,1.955560691158098e-08,   1.1536828381508712e-08、1.0335512889980691e-08、9.3943972601096e-09、7.994729381260918e-09、6.33630348190195e-09、5.411928682974974576e-09、4.907114270480406e-09、4.88628841536179179e-09、4.271269560263136e-09、3.88036875872-795-09 3.550610694347256e-09、3.171058082784839e-09、3.1009288470329466e-09、2.9455364813912865e-09、2.471733706599366e-09、2.4317983182697844e-09、2.048162306422796e-09]

REST客户端响应为:

  

[0.996831,0.000675639,0.000323685,0.000144642,0.000137603,   0.000134516,0.000104812,0.000104108,9.99449e-05,8.9907e-05,8.72486e-05,6.28879e-05,6.16111e-05,6.06435e-05,5.47078e-05,4.88681e-05,4.87645e-05 ,4.73167e-05,4.13763e-05,4.01956e-05]

很显然,GRPC客户端无法检测到任何东西。这是我的GRPC客户端

<div *ngIf="formControl; then trueBlock; else falseBlock"></div>
<ng-template #trueBlock>
  <div>
    <input [formControl]="myFormControl">
  </div>
</ng-template>
<ng-template #falseBlock>
  <div [fromGroup]="myForm">
    <input [formControlName]="formControlName">
  </div>
</ng-template>

和REST客户端:

from tensorflow.core.framework import tensor_pb2  
from tensorflow.core.framework import tensor_shape_pb2  
from tensorflow.core.framework import types_pb2

from grpc.beta import implementations
from tensorflow_serving.apis import predict_pb2  
from tensorflow_serving.apis import prediction_service_pb2

import helper

parser = argparse.ArgumentParser(description='incetion grpc client flags.')
parser.add_argument('--host', default='0.0.0.0', help='inception serving host')
parser.add_argument('--port', default='8500', help='inception serving port')
parser.add_argument('--image', dest='image', type=str,
                        help='Path to the jpeg image directory')
FLAGS = parser.parse_args()

def main(file):  
  print("\n\ninput file {}...\n".format(file))
  channel = implementations.insecure_channel(FLAGS.host, int(FLAGS.port))
  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

  request = predict_pb2.PredictRequest()
  request.model_spec.name = 'vedNet'
  # request.model_spec.signature_name = 'serving_default'

  img = cv2.imread(file).astype(np.uint8)
  tensor_shape = [1]+list(img.shape)
  dims = [tensor_shape_pb2.TensorShapeProto.Dim(size=dim) for dim in tensor_shape]  
  tensor_shape = tensor_shape_pb2.TensorShapeProto(dim=dims) 

  tensor = tensor_pb2.TensorProto(  
                dtype=types_pb2.DT_UINT8,
                tensor_shape=tensor_shape,
                float_val=list(img.reshape(-1)))
  request.inputs['inputs'].CopyFrom(tensor)  

  resp = stub.Predict(request, 30.0)
  f.close()

  # print(resp.outputs['detection_scores'].float_val)
  print(resp.outputs['detection_scores'].float_val)

if __name__ == '__main__':
  main(FLAGS.image)

任何帮助,请...

0 个答案:

没有答案