I trained a model with the TF Object Detection API. This automatically exports a variables folder and a saved_model.pb file. I ran
saved_model_cli show --dir <saved_model_path> --all
and got the following output:
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['serialized_example'] tensor_info:
        dtype: DT_STRING
        shape: ()
        name: tf_example:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['detection_boxes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 40, 4)
        name: detection_boxes:0
    outputs['detection_classes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 40)
        name: detection_classes:0
    outputs['detection_scores'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 40)
        name: detection_scores:0
    outputs['num_detections'] tensor_info:
        dtype: DT_FLOAT
        shape: (1)
        name: num_detections:0
    outputs['raw_detection_boxes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 204800, 4)
        name: raw_detection_boxes:0
    outputs['raw_detection_scores'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 204800, 2)
        name: raw_detection_scores:0
  Method name is: tensorflow/serving/predict

signature_def['tensorflow/serving/predict']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['serialized_example'] tensor_info:
        dtype: DT_STRING
        shape: ()
        name: tf_example:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['detection_boxes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 40, 4)
        name: detection_boxes:0
    outputs['detection_classes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 40)
        name: detection_classes:0
    outputs['detection_scores'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 40)
        name: detection_scores:0
    outputs['num_detections'] tensor_info:
        dtype: DT_FLOAT
        shape: (1)
        name: num_detections:0
    outputs['raw_detection_boxes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 204800, 4)
        name: raw_detection_boxes:0
    outputs['raw_detection_scores'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 204800, 2)
        name: raw_detection_scores:0
  Method name is: tensorflow/serving/predict
Then I created a predictor function with:
predict_fn = tf.contrib.predictor.from_saved_model(<saved_model_path>)
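For reference, a minimal sketch of how the resulting predict_fn is typically called, assuming (as tf.contrib.predictor predictors do) that it takes a dict keyed by the signature's input names and returns a dict keyed by its output names. The helper name detect and the min_score threshold are hypothetical, and serialized_example is assumed to already hold the bytes of one serialized tf.train.Example. Since the input shape is (), the value is presumably a single byte string rather than a list.

```python
from typing import Any, Callable, Dict, List, Tuple

# (box, class_id, score) for one kept detection.
Detection = Tuple[List[float], int, float]


def detect(predict_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
           serialized_example: bytes,
           min_score: float = 0.5) -> List[Detection]:
    """Hypothetical helper: run one example through predict_fn and keep
    the detections whose score is at least min_score."""
    # The dict key must match the signature's input name.
    outputs = predict_fn({"serialized_example": serialized_example})
    # Every output has a leading batch dimension of 1; [0] removes it.
    n = int(outputs["num_detections"][0])
    boxes = outputs["detection_boxes"][0][:n]
    classes = outputs["detection_classes"][0][:n]
    scores = outputs["detection_scores"][0][:n]
    results: List[Detection] = []
    for box, cls, score in zip(boxes, classes, scores):
        if score >= min_score:
            # Classes are exported as DT_FLOAT, so cast them to int.
            results.append((list(box), int(cls), float(score)))
    return results
```

The boxes should follow the Object Detection API's normalized [ymin, xmin, ymax, xmax] order; only the first num_detections of the 40 padded slots are meaningful.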
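Now I'm stuck on constructing the image payload to send to it. In particular, I'm confused by the input name serialized_example, the input dtype DT_STRING, and the input shape (). Could someone demonstrate how to take a numpy array or a JPEG image file and form an appropriate payload for this model?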