人口普查和鲜花样本展示了如何使用谷歌的ml引擎预测班级标签。
我们可以部署自己的模型来生成图片标题吗?如果是,预测如何运作?预测响应的格式是什么?
更具体地说,在下面显示的附件中,概率子数组给出了每个类的索引和机会。如果我们使用图像标题模型,预测响应将如何显示?
附件:http://boaloysius.me/sites/default/files/inline-images/predict1_0.png
答案 0 :(得分:0)
Cloud ML Engine允许您部署几乎任何能够导出的TensorFlow模型。对于任何此类模型,您可以定义输入和输出,这就是请求和响应的形式。
我认为通过一个例子来理解这一点可能很有用。您可以想象导出这样的模型:
def my_model(image_in):
# Construct an inference graph for predicting captions
#
# image_in is a tensor/array with shape=(None,) and dtype=string
# Meaning, a batch of raw image bytes.
... Do your interesting stuff here ...
# caption_out is a tensor/matrix with shape=(None, MAX_WORDS) and
# dtype=tf.string), that is, you will be returning a batch
# of captions, one per input image, one word per column with
# padding when the number of words to output is < MAX_WORDS
return caption_out
image_in = tf.placeholder(shape=(None,), dtype=tf.string)
caption_out = my_model(image_in)
inputs = {"image_bytes": tf.saved_model.utils.build_tensor_info(image_in}
outputs = {"caption": tf.saved_model.utils.build_tensor_info(caption_out)}
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs=inputs,
outputs=outputs,
method_name='tensorflow/serving/predict'
)
导出此模型后(参见this post),您将构建一个类似的JSON请求:
{
"instances": [
{
"image_bytes": {
"b64": <base64_encoded_image1>
}
},
{
"image_bytes": {
"b64": <base64_encoded_image2>
}
}
]
}
让我们分析一下请求。首先,我们将向服务发送一批图像。所有请求都是一个JSON对象,其数组值属性称为“实例”;数组中的每个条目都是一个实例,用于提供图形以进行预测。请注意,这就是为什么我们需要在导出的模型上将最外层维度设置为None
- 它们需要能够处理可变大小的批次。
数组中的每个条目本身都是一个JSON对象,其中属性是我们在导出模型时定义的input
字典的键。在这种情况下,我们只定义了image_bytes
。因为image_bytes
是一个字节字符串,所以我们需要对数据进行base64编码,这是通过传递{"b64": <data>}
形式的JSON对象来实现的。如果我们想要向服务发送多个图像,我们可以为每个图像添加一个类似的条目到instance
数组。
现在,此示例的示例JSON响应可能如下所示:
{
"predictions": [
{
"caption": [
"the",
"quick",
"brown",
"",
"",
""
]
},
{
"caption": [
"A",
"person",
"on",
"the",
"beach",
""
]
}
]
}
所有响应都是具有名为“predictions”的数组值属性的JSON对象。数组的每个元素都是与请求中instances
数组中相应输入相关联的预测。
数组中的每个条目都是一个JSON对象,其属性由我们之前导出的outputs
dict的键确定。在这种情况下,每个输入都有一个名为caption
的输出。请注意,标题caption_out
的源张量实际上是一个矩阵,其中行数等于发送到服务的实例数以及我们定义为某个常量的列数。但是,服务不是返回矩阵,而是独立地返回矩阵的每一行作为prediction
数组中的条目。矩阵的第二维是一些常数,并且可能模型本身会将额外的单词填充为空字符串(如上所述)。
一个非常重要的注意事项:在上面的示例中,我展示了原始JSON请求/响应主体。从您的帖子中可以看出,您正在使用Google的通用客户端,该客户端正在解析响应并在其周围添加结构,具体而言,您要打印的对象会在嵌套字段[data]
和[modelData:protected]
中包含预测。
我个人的建议是不要将该客户端用于该服务,而是使用通用请求/响应库(以及Google的身份验证库),但是由于您有工作的东西,欢迎您使用适合您的任何内容。