Cloud ml的编码问题

时间:2018-11-01 21:25:46

标签: python tensorflow google-cloud-ml

我正在运行此tutorial中编写的代码,当我尝试使用JSON进行在线预测时,该JSON包含以b64编码的图片,我得到的消息是期望uint8并得到一个十六进制字符串代替。 我已经检查了JSON,它的格式还可以。可能是什么问题?

我正在使用Google的CLI进行预测,而我的模型版本是tensorflow 1.10以及我的运行时版本。我使用tf API通过fast-rcnn进行对象检测

将图像转换为b64以及JSON和tfrecord的代码

import base64
import io
import json
from PIL import Image
import tensorflow as tf
width = 1024
height = 768
predict_instance_json = "inputs.json"
predict_instance_tfr = "inputs.tfr"
with tf.python_io.TFRecordWriter(predict_instance_tfr) as tfr_writer:
  with open(predict_instance_json, "wb") as fp:
    for image in ["image1.jpg", "image2.jpg"]:
      img = Image.open(image)
      img = img.resize((width, height), Image.ANTIALIAS)
      output_str = io.BytesIO()
      img.save(output_str, "JPEG")
      fp.write(
          json.dumps({"b64": base64.b64encode(output_str.getvalue())}) + "\n")
      tfr_writer.write(output_str.getvalue())
      output_str.close()

预测命令:

gcloud ml-engine predict --model=${YOUR_MODEL} --version=${YOUR_VERSION} --json-instances=inputs.json

我已经在本地测试了模型并创建了一个带有tensorflow服务的docker容器,它工作正常,但在Cloud ml上没有成功。

错误提示:

"error": "Prediction failed: Error processing input: Expected uint8, got '\\xff\\xd8\\...

\\xff\\xd9' of type 'str' instead."

1 个答案:

答案 0 :(得分:1)

问题在于图形的导出方式,在调用脚本export_inference_graph.py时添加标志--input_type非常重要,否则API模型的输入为UINT8而不是字符串。

--input_type encoded_image_string_tensor