我有一个Keras模型,我想将其转换为Tensorflow protobuf(例如saved_model.pb
)。
此模型来自vgg-19网络上的传输学习,其中头部被切断并使用完全连接的+ softmax层进行训练,而vgg-19网络的其余部分被冻结
我可以在Keras中加载模型,然后使用keras.backend.get_session()
在tensorflow中运行模型,生成正确的预测:
frame = preprocess(cv2.imread("path/to/img.jpg")
keras_model = keras.models.load_model("path/to/keras/model.h5")
keras_prediction = keras_model.predict(frame)
print(keras_prediction)
with keras.backend.get_session() as sess:
tvars = tf.trainable_variables()
output = sess.graph.get_tensor_by_name('Softmax:0')
input_tensor = sess.graph.get_tensor_by_name('input_1:0')
tf_prediction = sess.run(output, {input_tensor: frame})
print(tf_prediction) # this matches keras_prediction exactly
如果我不包含第tvars = tf.trainable_variables()
行,则tf_prediction
变量完全错误,并且根本不匹配keras_prediction
的输出。事实上,输出中的所有值(具有4个概率值的单个数组)完全相同(~0.25,全部加1)。这让我怀疑如果先没有调用tf.trainable_variables()
,头部的权重刚刚初始化为0,这在检查模型变量后得到了证实。在任何情况下,调用tf.trainable_variables()
都会导致张量流预测正确。
问题在于,当我尝试保存此模型时,来自tf.trainable_variables()
的变量实际上并未保存到.pb
文件中:
with keras.backend.get_session() as sess:
tvars = tf.trainable_variables()
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), ['Softmax'])
graph_io.write_graph(constant_graph, './', 'saved_model.pb', as_text=False)
我要问的是,如何将Keras模型保存为tf.training_variables()
完整的Tensorflow protobuf?
非常感谢!
答案 0 :(得分:4)
因此,冻结图中变量(转换为常量)的方法应该有效,但不是必需的,而且比其他方法更棘手。 (更多内容见下文)。如果您希望图形由于某种原因而冻结(例如导出到移动设备),我需要更多细节来帮助调试,因为我不确定Keras在幕后使用您的图表做了什么隐含的事情。但是,如果你想稍后保存并加载图形,我可以解释如何做到这一点,(虽然不能保证Keras正在做的事情不会搞砸了......,很高兴帮助调试它。)
所以这里有两种格式。一个是GraphDef
,用于检查点,因为它不包含有关输入和输出的元数据。另一个是MetaGraphDef
,其中包含元数据和图表def,元数据可用于预测和运行ModelServer
(来自张量流/服务)。
在任何一种情况下,您都需要做的不仅仅是调用graph_io.write_graph
,因为变量通常存储在graphdef之外。
这两个用例都有包装库。 tf.train.Saver
主要用于保存和恢复检查点。
但是,由于您需要预测,我建议使用tf.saved_model.builder.SavedModelBuilder
来构建SavedModel二进制文件。我为此提供了一些锅炉板:
from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY as DEFAULT_SIG_DEF
builder = tf.saved_model.builder.SavedModelBuilder('./mymodel')
with keras.backend.get_session() as sess:
output = sess.graph.get_tensor_by_name('Softmax:0')
input_tensor = sess.graph.get_tensor_by_name('input_1:0')
sig_def = tf.saved_model.signature_def_utils.predict_signature_def(
{'input': input_tensor},
{'output': output}
)
builder.add_meta_graph_and_variables(
sess, tf.saved_model.tag_constants.SERVING,
signature_def_map={
DEFAULT_SIG_DEF: sig_def
}
)
builder.save()
运行此代码后,您应该有一个mymodel/saved_model.pb
文件以及一个目录mymodel/variables/
,其中包含与变量值对应的protobufs。
然后再次加载模型,只需使用tf.saved_model.loader
:
# Does Keras give you the ability to start with a fresh graph?
# If not you'll need to do this in a separate program to avoid
# conflicts with the old default graph
with tf.Session(graph=tf.Graph()):
meta_graph_def = tf.saved_model.loader.load(
sess,
tf.saved_model.tag_constants.SERVING,
'./mymodel'
)
# From this point variables and graph structure are restored
sig_def = meta_graph_def.signature_def[DEFAULT_SIG_DEF]
print(sess.run(sig_def.outputs['output'], feed_dict={sig_def.inputs['input']: frame}))
显然,通过tensorflow / serve或Cloud ML Engine,这个代码可以提供更有效的预测,但这应该有效。 Keras可能正在做一些会干扰这个过程的事情,如果是这样我们也想听听它(我想确保Keras用户也能冻结图形,所以如果你想给我一个带有完整代码的要点或者其他什么,我可以找到一个能够帮助我调试Keras的人。)
编辑:你可以在这里找到一个端到端的例子:https://github.com/GoogleCloudPlatform/cloudml-samples/blob/master/census/keras/trainer/model.py#L85