用于导出模型的完整代码:(我已经对它进行了训练,现在可以从权重文件中加载)
def cnn_layers(inputs):
conv_base= keras.applications.mobilenetv2.MobileNetV2(input_shape=(224,224,3), input_tensor=inputs, include_top=False, weights='imagenet')
for layer in conv_base.layers[:-200]:
layer.trainable = False
last_layer = conv_base.output
x = GlobalAveragePooling2D()(last_layer)
x= keras.layers.GaussianNoise(0.3)(x)
x = Dense(1024,name='fc-1')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.advanced_activations.LeakyReLU(0.3)(x)
x = Dropout(0.4)(x)
x = Dense(512,name='fc-2')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.advanced_activations.LeakyReLU(0.3)(x)
x = Dropout(0.3)(x)
out = Dense(10, activation='softmax',name='output_layer')(x)
return out
model_input = layers.Input(shape=(224,224,3))
model_output = cnn_layers(model_input)
test_model = keras.models.Model(inputs=model_input, outputs=model_output)
weight_path = os.path.join(tempfile.gettempdir(), 'saved_wt.h5')
test_model.load_weights(weight_path)
export_path='export'
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import utils
from tensorflow.python.saved_model import tag_constants, signature_constants
from tensorflow.python.saved_model.signature_def_utils_impl import build_signature_def, predict_signature_def
from tensorflow.contrib.session_bundle import exporter
builder = saved_model_builder.SavedModelBuilder(export_path)
signature = predict_signature_def(inputs={'image': test_model.input},
outputs={'prediction': test_model.output})
with K.get_session() as sess:
builder.add_meta_graph_and_variables(sess=sess,
tags=[tag_constants.SERVING],
signature_def_map={'predict': signature})
builder.save()
并且(dir 1
的输出具有saved_model.pb
和models
dir):
python /tensorflow/python/tools/saved_model_cli.py show --dir /1 --all
是
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['predict']:
The given SavedModel SignatureDef contains the following input(s):
inputs['image'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 224, 224, 3)
name: input_1:0
The given SavedModel SignatureDef contains the following output(s):
outputs['prediction'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 107)
name: output_layer/Softmax:0
Method name is: tensorflow/serving/predict
接受b64字符串:
该代码是为(224, 224, 3)
numpy数组编写的。因此,我对以上代码进行的修改是:
_bytes
传递时,应将b64
添加到输入中。因此, predict_signature_def(inputs={'image':......
更改为
predict_signature_def(inputs={'image_bytes':.....
type(test_model.input)
是:(224, 224, 3)
和dtype: DT_FLOAT
。所以, signature = predict_signature_def(inputs={'image': test_model.input},.....
更改为(reference)
temp = tf.placeholder(shape=[None], dtype=tf.string)
signature = predict_signature_def(inputs={'image_bytes': temp},.....
修改:
使用请求发送的代码为:(如评论中所述)
encoded_image = None
with open('/1.jpg', "rb") as image_file:
encoded_image = base64.b64encode(image_file.read())
object_for_api = {"signature_name": "predict",
"instances": [
{
"image_bytes":{"b64":encoded_image}
#"b64":encoded_image (or this way since "image" is not needed)
}]
}
p=requests.post(url='http://localhost:8501/v1/models/mnist:predict', json=json.dumps(object_for_api),headers=headers)
print(p)
我遇到<Response [400]>
错误。我认为我的发送方式没有错误。在导出模型的代码中需要进行某些更改,特别是
temp = tf.placeholder(shape=[None], dtype=tf.string)
。
答案 0 :(得分:0)
查看您提供的文档是获取图像并将其发送到API。如果对图像进行编码,则可以以文本格式轻松传输图像,其中base64几乎是标准格式。因此,我们要做的是在正确的位置创建一个图像为base64的json对象,然后将此json对象发送到REST API中。 python具有请求库,这使得以JSON格式发送python字典非常容易。
因此,拍摄图像,对其进行编码,将其放入字典中,然后使用请求将其发送出去:
import requests
import base64
encoded_image = None
with open("image.png", "rb") as image_file:
encoded_image = base64.b64encode(image_file.read())
object_for_api = {"signature_name": "predict",
"instances": [
{
"image": {"b64": encoded_image}
}]
}
requests.post(url='http://localhost:8501/v1/models/mnist:predict', json=object_for_api)
您还可以将numpy数组编码为JSON,但API文档似乎并没有在寻找它。
答案 1 :(得分:0)
两个注意事项:
tf.saved_model.simple_save
model_to_estimator
很方便。X-Trace-Id
的输出显示输入和输出的外部尺寸均为saved_model_cli
),但发送浮点数的JSON数组效率很低最后一点,修改代码以进行图像解码服务器端通常更容易,因此您要通过网络发送base64编码的JPG或PNG,而不是浮点数组。这是Keras的一个示例(我打算用更简单的代码更新该答案)。