我用TensorFlow的估算器训练了我的模型。似乎export_savedmodel
应该用来制作.pb
文件,但我真的不知道如何构建serving_input_receiver_fn
。有人有什么想法吗?
欢迎使用示例代码。
额外问题:
当我想重新加载模型时,.pb
是我需要的唯一文件吗? Variable
不必要?
与使用adam优化工具的.pb
相比,.ckpt
减少了模型文件的大小?
答案 0 :(得分:6)
您可以使用freeze_graph.py
从.pb
+ .ckpt
生成.pbtxt
如果您使用的是tf.estimator.Estimator
,那么您会在model_dir
python freeze_graph.py \
--input_graph=graph.pbtxt \
--input_checkpoint=model.ckpt-308 \
--output_graph=output_graph.pb
--output_node_names=<output_node>
- 当我想重新加载模型时,.pb是我需要的唯一文件吗?变量不必要?
醇>
是,您还必须知道模型的输入节点和输出节点名称。然后使用import_graph_def
加载.pb文件并使用get_operation_by_name
获取输入和输出操作
- 与使用adam优化器的.ckpt相比,.pb会减少多少模型文件大小?
醇>
.pb文件不是压缩的.ckpt文件,因此没有“压缩率”。
但是,有一种to optimize your .pb file推理方式,这种优化可能会减少文件大小,因为它会删除仅培训操作的图形部分(请参阅完整说明here)。 / p>
[comment]如何获取输入和输出节点名称?
您可以使用op name
参数设置输入和输出节点名称。
要列出.pbtxt
文件中的节点名称,请使用以下脚本。
import tensorflow as tf
from google.protobuf import text_format
with open('graph.pbtxt') as f:
graph_def = text_format.Parse(f.read(), tf.GraphDef())
print [n.name for n in graph_def.node]
[评论]我发现有一个tf.estimator.Estimator.export_savedmodel(),是直接在.pb中存储模型的函数吗?而我正在努力参与其中的参数serve_input_receiver_fn。有任何想法吗?
export_savedmodel()
生成SavedModel
,这是TensorFlow模型的通用序列化格式。它应包含适合TensorFlow Serving APIs的所有内容
serving_input_receiver_fn()
是为了生成SavedModel
而必须提供的所需内容的一部分,它通过向图表添加占位符来确定模型的输入签名。
来自文档
此功能具有以下用途:
- 将占位符添加到服务系统将提供的图表中 推理请求。
- 添加转换所需的任何其他操作 数据从输入格式到预期的特征张量 模型。
如果您以序列化tf.Examples
(这是一种典型模式)的形式收到推理请求,那么您可以使用doc中提供的示例。
feature_spec = {'foo': tf.FixedLenFeature(...),
'bar': tf.VarLenFeature(...)}
def serving_input_receiver_fn():
"""An input receiver that expects a serialized tf.Example."""
serialized_tf_example = tf.placeholder(dtype=tf.string,
shape=[default_batch_size],
name='input_example_tensor')
receiver_tensors = {'examples': serialized_tf_example}
features = tf.parse_example(serialized_tf_example, feature_spec)
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
[comment]是否有想法在'.pb'中列出节点名称?
这取决于它是如何生成的。
如果它是SavedModel
使用:
import tensorflow as tf
with tf.Session() as sess:
meta_graph_def = tf.saved_model.loader.load(
sess,
[tf.saved_model.tag_constants.SERVING],
'./saved_models/1519232535')
print [n.name for n in meta_graph_def.graph_def.node]
如果是MetaGraph
,请使用:
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
with gfile.FastGFile('model.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
print [n.name for n in graph_def.node]