是否可以从保存的模型中恢复张量流估计量?

时间:2018-11-07 05:40:44

标签: tensorflow tensorflow-estimator

我使用tf.estimator.train_and_evaluate()来训练我的自定义估算器。我的数据集按8:1:1划分,用于训练,评估和测试。在培训结束时,我想恢复最佳模型,并使用带有测试数据的tf.estimator.Estimator.evaluate()评估模型。目前,最佳模型是使用tf.estimator.BestExporter导出的。

虽然tf.estimator.Estimator.evaluate()接受checkpoint_path并还原变量,但我找不到任何简单的方法来使用tf.estimator.BestExporter生成的导出模型。我当然可以在训练过程中保留所有检查点,并亲自寻找最佳模型,但这似乎不是最佳选择。

有人可以告诉我一个简单的解决方法吗?也许可以将保存的模型转换为检查点吗?

4 个答案:

答案 0 :(得分:4)

也许您可以尝试tf.estimator.WarmStartSettings:https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/estimator/WarmStartSettings

它可以将重量加载到pb文件中并继续训练,这在我的项目中起作用。

您可以按以下方式设置热启动:

ws = WarmStartSettings(ckpt_to_initialize_from="/[model_dir]/export/best-exporter/[timestamp]/variables/variables")

然后一切都会好的

答案 1 :(得分:1)

基于@SumNeuron的Github问题tf.contrib.estimator.SavedModelEstimator的解决方案,是从保存的模型加载到Estimator的方法。

以下对我有用:

estimator = tf.contrib.estimator.SavedModelEstimator(saved_model_dir)
prediction_results = estimator.predict(input_fn)

令人困惑的是,这基本上是完全没有记录的。

答案 2 :(得分:0)

希望其他人会找到一种更清洁的方法。

tf.estimator.BestExporter导出最佳模型,如下所示:

<your_estimator.model_dir>
+--export
   +--best_exporter
      +--xxxxxxxxxx(timestamp)
         +--saved_model.pb
         +--variables
            +--variables.data-00000-of-00001
            +--variables.index

另一方面,在your_estimator.model_dir中,检查点存储在三个文件中。

model.ckpt-xxxx.data-00000-of-00001
model.ckpt-xxxx.index
model.ckpt-xxxx.meta

首先,我使用了tf.estimator.Estimator.evaluate(..., checkpoint_path='<your_estimator.model_dir>/export/best_exporter/<xxxxxxxxxx>/variables/variables'),但这没用。

在复制your_estimator.model_dir中的一个图元文件并将其重命名为“ variables.meta”之后,评估似乎可以正常工作。

答案 3 :(得分:0)

我对Estimator API还是陌生的,但我想我知道您在寻找什么,尽管这同样令人讨厌。

从这个colab开始,这是一个玩具习惯Estimator,上面有一些铃铛和口哨声:

from tensorflow.contrib import predictor
predict_fn = predictor.from_saved_model(<model_dir>)
predict_fn(pred_features) # pred_features corresponds to your input features

,此估算器都使用BestExporter

exporter = tf.estimator.BestExporter(
    name="best_exporter",
    serving_input_receiver_fn=serving_input_receiver_fn,
    exports_to_keep=5
) # this will keep the 5 best checkpoints

以及在训练后仅导出模型:

est.export_savedmodel('./here', serving_input_receiver_fn)

如果让您感到困惑的是Estimator API没有“正确”的方式来加载SavedModel,我已经在GitHub上创建了issue

但是,如果您尝试将其加载到其他设备上,请参阅我的其他问题:

解决设备放置问题,还有其他GitHub问题

简而言之,目前,如果您使用Estimator导出程序进行导出,则在您训练Estimator时要使用的设备就是必须可用的设备。如果您将model_fn设置为clear_devices,并在{{1}}中手动导出Estimator,那么您应该一切顺利。目前,在导出模型之后,似乎没有办法更改此设置。