如何在cloud-ml中正确预测jpeg图像

时间:2016-12-21 11:25:48

标签: python google-cloud-ml

我想预测cloud-ml中的jpeg图像。

我的训练模型是初始模型,我想将输入发送到图的第一层:'DecodeJpeg/contents:0'(我必须发送jpeg图像)。我通过添加retrain.py

将此图层设置为可能的输入
inputs = {'image_bytes': 'DecodeJpeg/contents:0'}
tf.add_to_collection('inputs', json.dumps(inputs))

然后我将训练结果保存在两个文件(export和export.meta)中:

saver.save(sess, os.path.join(output_directory,'export'))

我使用这些文件在cloud-ml中创建了一个模型。

正如Google云端官方博客中的某些帖子(hereherehere中所建议的那样)我正在尝试使用

进行预测
gcloud beta ml predict --json-instances=request.json --model=MODEL

其中实例是以base64格式解码的jpeg图像:

python -c 'import base64, sys, json; img = base64.b64encode(open(sys.argv[1], "rb").read()); print json.dumps({"key":"0", "image_bytes": {"b64": img}})' image.jpg &> request.json

然而请求还给我:

error: 'Prediction failed: '

我的手术有什么问题?你有什么建议吗? 我特别从this帖子开始,我认为当用image_bytes读取请求时,cloud-ml会自动转换jpeg格式的base64图像。这是对的吗?否则我该怎么办?

3 个答案:

答案 0 :(得分:0)

CloudML要求您使用批量图像来提供图表。

我很确定这是重新使用retrain.py的问题。看到该代码的sess.run line;它一次送入一张图片。与flowers sample中的批处理jpeg占位符进行比较。

答案 1 :(得分:0)

请注意,需要构建三个略有不同的TF图:训练,评估和预测。有关详细信息,请参阅this recent blog post。训练和预测图直接使用预处理嵌入,因此它们不包含初始图。对于预测,我们需要将图像字节作为输入并使用Inception来提取嵌入。

对于在线预测,您需要导出预测图。您还应该指定输出和输入键。

构建预测图(the code):

def build_prediction_graph(self):
   """Builds prediction graph and registers appropriate endpoints."""
   tensors = self.build_graph(None, 1, GraphMod.PREDICT)
   keys_placeholder = tf.placeholder(tf.string, shape=[None])
   inputs = {
      'key': keys_placeholder.name,
      'image_bytes': tensors.input_jpeg.name
   }

   tf.add_to_collection('inputs', json.dumps(inputs))

   # To extract the id, we need to add the identity function.
   keys = tf.identity(keys_placeholder)
   outputs = {
       'key': keys.name,
       'prediction': tensors.predictions[0].name,
       'scores': tensors.predictions[1].name
   }
   tf.add_to_collection('outputs', json.dumps(outputs))

导出预测图:

def export(self, last_checkpoint, output_dir):
  # Build and save prediction meta graph and trained variable values.
  with tf.Session(graph=tf.Graph()) as sess:        
    self.build_prediction_graph()
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    self.restore_from_checkpoint(sess, self.inception_checkpoint_file,
                                 last_checkpoint)
    saver = tf.train.Saver()
    saver.export_meta_graph(filename=os.path.join(output_dir, 'export.meta'))
    saver.save(sess, os.path.join(output_dir, 'export'), write_meta_graph=False)

last_checkpoint必须指向培训中的最新检查点文件:

self.model.export(tf.train.latest_checkpoint(self.train_path), self.model_path)

答案 2 :(得分:0)

在您的帖子中,您表明您的输入集合只有“image_bytes”张量别名。但是,在构成请求的代码中,您包含2个输入:一个是“键”,另一个是“image_bytes”。因此,我的建议是从请求中删除“密钥”或将“密钥”添加到输入集合。

第二个问题是DecodeJpeg / contents的形状:0',是()。对于Cloud ML,您需要具有类似(无)的形状,以便您可以将其输入。

在这里您的问题的其他答案中有一些建议,关于您如何能够关注公开帖子来修改图表,但是我可以告诉他们这两个问题。

如果您遇到任何其他问题,请告诉我们。