导出自定义Keras模型以用于Cloud ML Engine进行预测

时间:2018-07-06 21:50:21

标签: tensorflow keras google-cloud-ml google-prediction

我很难导出经过Keras培训的自定义VGG-Net(并非完全来自Keras),因此可以将其用于Google Cloud Predict API。我正在用Keras加载模型。

sess = tf.Session()
K.set_session(sess)

model = load_model(model.h5)

我要分类的图像被编码为base64字符串。因此,为了进行预测任务,我必须使用在一个Google示例中找到的一些代码对其进行解码。

channels = 3
height = 96
width = 96

def decode_and_resize(image_str_tensor):
   """Decodes jpeg string, resizes it and returns a uint8 tensor."""
   image = tf.image.decode_jpeg(image_str_tensor, channels=channels)
   image = tf.expand_dims(image, 0)
   image = tf.image.resize_bilinear(
       image, [height, width], align_corners=False)
   image = tf.squeeze(image, squeeze_dims=[0])
   image = tf.cast(image, dtype=tf.uint8)
   return image

image_str_tensor = tf.placeholder(tf.string, shape=[None])
key_input = tf.placeholder(tf.string, shape=[None]) 
key_output = tf.identity(key_input)

input_tensor = tf.map_fn(
    decode_and_resize, image_str_tensor, back_prop=False, dtype=tf.uint8)
input_tensor = tf.image.convert_image_dtype(image, dtype=tf.float32)

但是在此之后,我不再知道如何进行。现在如何将这个输入张量放入模型中,并获得正确的输出张量,以便能够定义SignatureDef,然后将图形导出为SavedModel?

任何帮助将不胜感激。

2 个答案:

答案 0 :(得分:0)

免责声明:尽管我是Cloud ML Engine预测服务的专家,并且对TensorFlow相当了解,但我对Keras并不十分了解。我只是将来自其他地方的信息拼凑在一起,尤其是this samplethis answer。我只能想象有更好的方法可以做到这一点,希望人们能发表这样的话。同时,我希望这能满足您的需求。

此特定答案假定您已经保存了模型。代码加载模型,然后将其导出为SavedModel。

基本思想是开始为输入(输入占位符,图像解码,调整大小和批处理等)构建“原始” TensorFlow模型,然后通过“重建”将其“连接”为Keras VGG模型VGG模型的结构,最后将保存的权重恢复到新建的模型中。然后,将这个版本的模型另存为SavedModel。

这里的“魔术”是原始TF预处理与VGG模型之间的联系。这是通过将TF预处理图形(以下代码中的input_tensor)的“输出”作为input_tensor传递给Keras VGG图形来完成的。input_tensor包含一批已经解码和调整大小的图片,就像VGG期望的那样。

import keras.backend as K
import tensorflow as tf
from keras.models import load_model, Sequential
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants, signature_constants
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def

MODEL_FILE = 'model.h5'
WEIGHTS_FILE = 'weights.h5'
EXPORT_PATH = 'YOUR/EXPORT/PATH'

channels = 3
height = 96
width = 96

def build_serving_inputs():

  def decode_and_resize(image_str_tensor):
     """Decodes jpeg string, resizes it and returns a uint8 tensor."""
     image = tf.image.decode_jpeg(image_str_tensor, channels=channels)
     image = tf.expand_dims(image, 0)
     image = tf.image.resize_bilinear(
         image, [height, width], align_corners=False)
     image = tf.squeeze(image, squeeze_dims=[0])
     image = tf.cast(image, dtype=tf.uint8)
     return image

  image_str_tensor = tf.placeholder(tf.string, shape=[None])
  key_input = tf.placeholder(tf.string, shape=[None]) 
  key_output = tf.identity(key_input)

  input_tensor = tf.map_fn(
      decode_and_resize, image_str_tensor, back_prop=False, dtype=tf.uint8)
  input_tensor = tf.image.convert_image_dtype(input_tensor, dtype=tf.float32) 

  return image_str_tensor, input_tensor, key_input, key_output

# reset session
K.clear_session()

with tf.Graph().as_default() as g, tf.Session(graph=g) as sess:
  K.set_session(sess)

  image_str_tensor, input_tensor, key_input, key_output = build_serving_inputs()

  # disable loading of learning nodes
  K.set_learning_phase(0)

  # Load model and save out the weights
  model = load_model(MODEL_FILE)
  model.save_weights(WEIGHTS_FILE)

  # Rebuild the VGG16 model with the weights
  new_model = keras.applications.vgg16.VGG16(
    include_top=True, weights=WEIGHTS_FILE, input_tensor=input_tensor,
    input_shape=[width, height, channels], pooling=None)

  # export saved model
  tf.saved_model.simple_save(
      sess,
      EXPORT_PATH,
      inputs={'image_bytes': image_str_tensor, 'key': key_input},
      outputs={'predictions': new_model.outputs[0], 'key': key_output}
  )

注意:我不知道此代码是否还可以正常工作(尚未测试);我担心它如何处理批次尺寸。 build_serving_inputs创建一个具有批量尺寸的张量,并将其传递给Keras。

答案 1 :(得分:0)

TensorFlow Keras(tf.keras)现在可以从Keras模型转换为TF Estimator tf.keras.estimator.model_to_estimator。 Estimator将带您到SavedModel,您可以将其与Cloud ML Engine一起使用进行预测。检出此post以了解此API的用法。