我们可以使用google-ml-engine预测图片标题吗?

时间:2017-07-10 11:10:49

标签: google-cloud-ml-engine

人口普查和鲜花样本展示了如何使用谷歌的ml引擎预测班级标签。

我们可以部署自己的模型来生成图片标题吗?如果是,预测如何运作?预测响应的格式是什么?

更具体地说,在下面显示的附件中,概率子数组给出了每个类的索引和机会。如果我们使用图像标题模型,预测响应将如何显示?

附件:http://boaloysius.me/sites/default/files/inline-images/predict1_0.png

1 个答案:

答案 0 :(得分:0)

Cloud ML Engine允许您部署几乎任何能够导出的TensorFlow模型。对于任何此类模型,您可以定义输入和输出,这就是请求和响应的形式。

我认为通过一个例子来理解这一点可能很有用。您可以想象导出这样的模型:

def my_model(image_in):
  # Construct an inference graph for predicting captions
  #
  # image_in is a tensor/array with shape=(None,) and dtype=string
  # Meaning, a batch of raw image bytes.

  ... Do your interesting stuff here ...      

  # caption_out is a tensor/matrix with shape=(None, MAX_WORDS) and 
  # dtype=tf.string), that is, you will be returning a batch
  # of captions, one per input image, one word per column with
  # padding when the number of words to output is < MAX_WORDS
  return caption_out

image_in = tf.placeholder(shape=(None,), dtype=tf.string)
caption_out = my_model(image_in)

inputs = {"image_bytes": tf.saved_model.utils.build_tensor_info(image_in}
outputs = {"caption": tf.saved_model.utils.build_tensor_info(caption_out)}

signature = tf.saved_model.signature_def_utils.build_signature_def(
      inputs=inputs,
      outputs=outputs,
      method_name='tensorflow/serving/predict'
)

导出此模型后(参见this post),您将构建一个类似的JSON请求:

{
  "instances": [
    {
      "image_bytes": {
        "b64": <base64_encoded_image1>
      }
    },
    {
      "image_bytes": {
        "b64": <base64_encoded_image2>
      }
    }
  ]
}

让我们分析一下请求。首先,我们将向服务发送一批图像。所有请求都是一个JSON对象,其数组值属性称为“实例”;数组中的每个条目都是一个实例,用于提供图形以进行预测。请注意,这就是为什么我们需要在导出的模型上将最外层维度设置为None - 它们需要能够处理可变大小的批次。

数组中的每个条目本身都是一个JSON对象,其中属性是我们在导出模型时定义的input字典的键。在这种情况下,我们只定义了image_bytes。因为image_bytes是一个字节字符串,所以我们需要对数据进行base64编码,这是通过传递{"b64": <data>}形式的JSON对象来实现的。如果我们想要向服务发送多个图像,我们可以为每个图像添加一个类似的条目到instance数组。

现在,此示例的示例JSON响应可能如下所示:

{
  "predictions": [
    {
      "caption": [
        "the",
        "quick",
        "brown",
        "",
        "",
        ""
      ]
    },
    {
      "caption": [
        "A",
        "person",
        "on",
        "the",
        "beach",
        ""
      ]
    }
  ]
}

所有响应都是具有名为“predictions”的数组值属性的JSON对象。数组的每个元素都是与请求中instances数组中相应输入相关联的预测。

数组中的每个条目都是一个JSON对象,其属性由我们之前导出的outputs dict的键确定。在这种情况下,每个输入都有一个名为caption的输出。请注意,标题caption_out的源张量实际上是一个矩阵,其中行数等于发送到服务的实例数以及我们定义为某个常量的列数。但是,服务不是返回矩阵,而是独立地返回矩阵的每一行作为prediction数组中的条目。矩阵的第二维是一些常数,并且可能模型本身会将额外的单词填充为空字符串(如上所述)。

一个非常重要的注意事项:在上面的示例中,我展示了原始JSON请求/响应主体。从您的帖子中可以看出,您正在使用Google的通用客户端,该客户端正在解析响应并在其周围添加结构,具体而言,您要打印的对象会在嵌套字段[data][modelData:protected]中包含预测。

我个人的建议是不要将该客户端用于该服务,而是使用通用请求/响应库(以及Google的身份验证库),但是由于您有工作的东西,欢迎您使用适合您的任何内容。