Keras model: feed TensorFlow base64-encoded data instead of numpy for ML Engine prediction

Posted: 2017-11-23 01:32:04

Tags: tensorflow keras tensorflow-serving google-cloud-ml

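I'm trying to convert a Keras model so that I can use it for prediction on Google Cloud ML Engine. I have a pre-trained classifier that takes a numpy array as input. The data I normally pass to model.predict, which works fine, is called input_data.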

I convert it to base64 and dump it to a JSON file with the following lines:

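import base64
import json

data = {}
# b64encode returns bytes on Python 3; decode to str so json.dump can serialize it.
# tobytes() is the non-deprecated name for ndarray.tostring().
data['image_bytes'] = [{'b64': base64.b64encode(input_data.tobytes()).decode('utf-8')}]

with open('weights/keras/example.json', 'w') as outfile:
    json.dump(data, outfile)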
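As a sanity check on the encoding step, the sketch below round-trips one array through the same base64/JSON steps and decodes it again with `np.frombuffer`. It assumes `input_data` is a `float32` array of shape `(473, 473, 3)` (inferred from the placeholder shape in the error message; the real dtype and shape may differ) and uses random data as a stand-in. The point it illustrates: the raw bytes carry neither dtype nor shape, so whatever decodes them has to supply both.

```python
import base64
import json

import numpy as np

# Stand-in for the real input; assumes float32 with shape (473, 473, 3).
input_data = np.random.rand(473, 473, 3).astype(np.float32)

# Encode as in the question: raw bytes -> base64 -> JSON.
encoded = base64.b64encode(input_data.tobytes()).decode('utf-8')
payload = json.dumps({'image_bytes': [{'b64': encoded}]})

# Decode: JSON -> base64 string -> raw bytes -> flat float32 vector.
b64_string = json.loads(payload)['image_bytes'][0]['b64']
flat = np.frombuffer(base64.b64decode(b64_string), dtype=np.float32)
print(flat.shape)  # (671187,): only the element count survives, not the shape

# The original image shape has to be restored explicitly.
restored = flat.reshape(473, 473, 3)
print(np.array_equal(restored, input_data))  # True
```

The same holds on the TensorFlow side: `tf.decode_raw` also returns a flat vector per string, so a reshape to the model's input shape is still needed after it.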

Next, I try to export a TensorFlow SavedModel from the existing Keras model:

from keras.models import model_from_json
import tensorflow as tf
from keras import backend as K
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import utils
from tensorflow.python.saved_model import tag_constants, signature_constants
from tensorflow.python.saved_model.signature_def_utils_impl import build_signature_def, predict_signature_def
init = tf.global_variables_initializer()
with tf.Session() as sess:
    K.set_session(sess)
    sess.run(init)
    print("Keras model & weights loading...")
    K.set_learning_phase(0)
    with open(json_file_path, 'r') as file_handle:
        model = model_from_json(file_handle.read())

    model.load_weights(weight_file_path)

    builder = saved_model_builder.SavedModelBuilder(export_path)

    raw_byte_strings = tf.placeholder(dtype=tf.string, shape=[None], name='source')
    decode = lambda raw_byte_str: tf.decode_raw(raw_byte_str, tf.float32)
    input_images = tf.map_fn(decode, raw_byte_strings)
    print(input_images)
    signature = predict_signature_def(inputs={'image_bytes': input_images},
                                    outputs={'output': model.output})

    builder.add_meta_graph_and_variables(
        sess=sess,
        tags=[tag_constants.SERVING],
        signature_def_map={
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
        }
    )
    builder.save()

When I test it locally, I get the following error:

ERROR:root:Exception during running the graph: You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,473,473,3]
     [[Node: input_1 = Placeholder[dtype=DT_FLOAT, shape=[?,473,473,3], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
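The error says the graph still needs a value for `input_1`, which is the Keras model's own input placeholder. One plausible reading (not confirmed here) is that `model.output` is still computed from `input_1`: the decoded `input_images` tensor appears in the signature but is never passed through the model, so there is no path from the request data to the output. Wiring it in would mean reshaping the decoded tensor to `[?, 473, 473, 3]`, calling the model on it (Keras models can be called on tensors, e.g. `model(reshaped)`), and using that result as the signature output with `raw_byte_strings` as the signature input. The NumPy sketch below mirrors only the data side of that pipeline, under the same `float32` assumption as above; `decode_batch` is a hypothetical helper, not part of TensorFlow or ML Engine.

```python
import numpy as np

# Assumed input geometry, taken from the placeholder shape in the error message.
HEIGHT, WIDTH, CHANNELS = 473, 473, 3


def decode_batch(raw_byte_strings):
    """Hypothetical NumPy analogue of the serving-side decode step.

    Each element is the raw float32 bytes of one image, i.e. what the graph
    would see after the base64 layer has been removed (assumption). Mirrors
    decode_raw followed by the reshape the exported graph would need.
    """
    flat = np.stack([np.frombuffer(s, dtype=np.float32) for s in raw_byte_strings])
    # flat has shape (N, HEIGHT * WIDTH * CHANNELS); restore the image axes.
    return flat.reshape(-1, HEIGHT, WIDTH, CHANNELS)


images = [np.random.rand(HEIGHT, WIDTH, CHANNELS).astype(np.float32) for _ in range(2)]
requests = [img.tobytes() for img in images]
batch = decode_batch(requests)
print(batch.shape, batch.dtype)  # (2, 473, 473, 3) float32
```

The resulting shape matches the `[?,473,473,3]` that `input_1` expects, which is the tensor the model would need to consume in place of its original placeholder.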

Any help?

0 answers