我最近开始使用Tensorflow,并尝试使用tf.estimator.Estimator对象。我想做一个自然而然的先验操作:训练完我的分类器(即tf.estimator.Estimator的实例(使用train
方法)之后,我想将其保存在文件中(无论扩展名是什么) ),然后稍后重新加载以预测一些新数据的标签。由于官方文档建议使用Estimator API,因此我认为应该实施和记录同样重要的内容。
我在其他页面上看到执行此操作的方法是export_savedmodel
(请参阅the official documentation),但我根本不理解文档。没有说明如何使用此方法。参数serving_input_fn
是什么?在Creating Custom Estimators教程或我阅读的任何教程中,我都从未遇到过。通过做一些谷歌搜索,我发现大约一年前,估计器是使用另一个类(tf.contrib.learn.Estimator
)定义的,看起来像tf.estimator.Estimator正在重用一些以前的API。但是我在文档中找不到关于它的明确解释。
有人可以给我一个玩具的例子吗?或向我解释如何定义/查找此serving_input_fn
?
然后如何再次加载经过训练的分类器?
谢谢您的帮助!
编辑:我发现并不一定需要使用export_savemodel来保存模型。它实际上是自动完成的。然后,如果稍后我们定义一个具有相同model_dir参数的新估算器,它也会自动还原先前的估算器,如here所述。
答案 0 :(得分:0)
如您所知,估计器在训练过程中会自动为您保存恢复模型。如果要将模型部署到字段(例如,为Tensorflow服务提供最佳模型),export_savemodel可能会有用。
这是一个简单的例子:
est.export_savedmodel(export_dir_base=FLAGS.export_dir, serving_input_receiver_fn=serving_input_fn)
def serving_input_fn():
inputs = {'features': tf.placeholder(tf.float32, [None, 128, 128, 3])}
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
基本上serving_input_fn负责用占位符替换数据集管道。在部署中,您可以将数据作为该模型的输入提供给该占位符,以进行推理或预测。