如何使用export_savedmodel保存和还原tf.estimator.Estimator模型?

时间:2018-07-13 18:17:41

标签: tensorflow-estimator

我最近开始使用Tensorflow,并尝试使用tf.estimator.Estimator对象。我想做一个自然而然的先验操作:训练完我的分类器(即tf.estimator.Estimator的实例(使用train方法)之后,我想将其保存在文件中(无论扩展名是什么) ),然后稍后重新加载以预测一些新数据的标签。由于官方文档建议使用Estimator API,因此我认为应该实施和记录同样重要的内容。

我在其他页面上看到执行此操作的方法是export_savedmodel(请参阅the official documentation),但我根本不理解文档。没有说明如何使用此方法。参数serving_input_fn是什么?在Creating Custom Estimators教程或我阅读的任何教程中,我都从未遇到过。通过做一些谷歌搜索,我发现大约一年前,估计器是使用另一个类(tf.contrib.learn.Estimator)定义的,看起来像tf.estimator.Estimator正在重用一些以前的API。但是我在文档中找不到关于它的明确解释。

有人可以给我一个玩具的例子吗?或向我解释如何定义/查找此serving_input_fn

然后如何再次加载经过训练的分类器?

谢谢您的帮助!

编辑:我发现并不一定需要使用export_savemodel来保存模型。它实际上是自动完成的。然后,如果稍后我们定义一个具有相同model_dir参数的新估算器,它也会自动还原先前的估算器,如here所述。

1 个答案:

答案 0 :(得分:0)

如您所知,估计器在训练过程中会自动为您保存恢复模型。如果要将模型部署到字段(例如,为Tensorflow服务提供最佳模型),export_savemodel可能会有用。

这是一个简单的例子:

est.export_savedmodel(export_dir_base=FLAGS.export_dir, serving_input_receiver_fn=serving_input_fn)

def serving_input_fn(): inputs = {'features': tf.placeholder(tf.float32, [None, 128, 128, 3])} return tf.estimator.export.ServingInputReceiver(inputs, inputs)

基本上serving_input_fn负责用占位符替换数据集管道。在部署中,您可以将数据作为该模型的输入提供给该占位符,以进行推理或预测。