使用tensorflow 1.6和ml-engine在模型执行期间预测失败。它们应该兼容

时间:2018-04-05 00:03:32

标签: python tensorflow google-cloud-ml inference object-detection-api

我使用tensorflow1.6从头开始训练SSD-inception-v2模型。没有警告或错误。然后我使用以下标志导出模型:

- input_type encoded_image_string_tensor --pipeline_config_path experiments / ssd_inception_v2 / ssd_inception_v2.config --trained_checkpoint_prefix experiments / ssd_inception_v2 / train / model.ckpt-400097 --output_directory experiments / ssd_inception_v2 / frozen_graphs /

之后,我将saved_mode.pb上传到Google存储桶,在ml-engine中创建了一个模型并创建了一个版本(我确实使用了--runtime-version = 1.6)。

最后,我使用gcloud命令询问在线预测,但获得了以下错误:

{ "错误":"预测失败:模型执行期间出错:AbortionError(代码= StatusCode.INVALID_ARGUMENT,details ="第二个输入必须是标量,但它有形状[1] \ n \吨[[节点:地图/同时/ decode_image / cond_jpeg / cond_png / DecodePng /开关=开关[T = DT_STRING,_class = [" LOC:/ TensorArrayReadV3&#34],_device =" / job:localhost / replica:0 / task:0 / device:CPU:0"](map / while / TensorArrayReadV3,map / while / decode_image / is_jpeg)]]")" }

日志描述了模型执行时出现的问题。

1 个答案:

答案 0 :(得分:0)

预测请求的格式为(cf the official docs):

{
  "instances": [
    ...
  ]
}

根据关于对象检测的this blog post,标志--input_type encoded_image_string_tensor生成一个名为inputs的单个输入的模型,该输入接受一批JPG或PNG图像。那些图像必须是base64编码的。所以把它们放在一起,实际的请求应该是这样的:

{
  "instances": [
    {
      "inputs": {
        "b64": "..."
      }
    }
  ]
}

由于只有一个输入,我们可以使用一个简写,一个对象/字典的实例{"输入":{" b64":...}}是只是字典的价值,即{" b64":...}:

{
  "instances": [
    {
      "b64": "..."
    }
  ]
}

请注意,如果模型只有一个输入,则其中任何一个都是可以接受的。

即使以上是服务接受的请求的格式,gcloud命令行工具实际上并不期望请求的全部内容。它期待实际的"实例",即JSON中[]之间的事物,由换行符分隔。这意味着您的文件应如下所示:

{"b64": "..."}

或者

{"inputs": {"b64": "..."}}

如果要发送多个图像,则文件中每行都有一个图像。

尝试使用以下代码生成输出:

json_data = []
for index, image in enumerate(images, 1):
    with open(image, "rb") as open_file:
        byte_content = open_file.read()
    # Convert to base64
    base64_bytes = b64encode(byte_content)
    # Decode bytes to text
    base64_string = base64_bytes.decode("utf-8")
    # Create dictionary
    raw_data = {"b64": base64_string}

    # Put data to json
    json_data.append(json.dumps(raw_data))

# Write to the file
with open(predict_instance_json, "w") as fp:
    fp.write('\n'.join(json_data))