我使用tf.estimator.train_and_evaluate()
来训练我的自定义估算器。我的数据集按8:1:1划分,用于训练,评估和测试。在培训结束时,我想恢复最佳模型,并使用带有测试数据的tf.estimator.Estimator.evaluate()
评估模型。目前,最佳模型是使用tf.estimator.BestExporter
导出的。
虽然tf.estimator.Estimator.evaluate()
接受checkpoint_path
并还原变量,但我找不到任何简单的方法来使用tf.estimator.BestExporter
生成的导出模型。我当然可以在训练过程中保留所有检查点,并亲自寻找最佳模型,但这似乎不是最佳选择。
有人可以告诉我一个简单的解决方法吗?也许可以将保存的模型转换为检查点吗?
答案 0 :(得分:4)
也许您可以尝试tf.estimator.WarmStartSettings:https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/estimator/WarmStartSettings
它可以将重量加载到pb文件中并继续训练,这在我的项目中起作用。
您可以按以下方式设置热启动:
ws = WarmStartSettings(ckpt_to_initialize_from="/[model_dir]/export/best-exporter/[timestamp]/variables/variables")
然后一切都会好的
答案 1 :(得分:1)
基于@SumNeuron的Github问题tf.contrib.estimator.SavedModelEstimator
的解决方案,是从保存的模型加载到Estimator
的方法。
以下对我有用:
estimator = tf.contrib.estimator.SavedModelEstimator(saved_model_dir)
prediction_results = estimator.predict(input_fn)
令人困惑的是,这基本上是完全没有记录的。
答案 2 :(得分:0)
希望其他人会找到一种更清洁的方法。
tf.estimator.BestExporter
导出最佳模型,如下所示:
<your_estimator.model_dir>
+--export
+--best_exporter
+--xxxxxxxxxx(timestamp)
+--saved_model.pb
+--variables
+--variables.data-00000-of-00001
+--variables.index
另一方面,在your_estimator.model_dir
中,检查点存储在三个文件中。
model.ckpt-xxxx.data-00000-of-00001
model.ckpt-xxxx.index
model.ckpt-xxxx.meta
首先,我使用了tf.estimator.Estimator.evaluate(..., checkpoint_path='<your_estimator.model_dir>/export/best_exporter/<xxxxxxxxxx>/variables/variables')
,但这没用。
在复制your_estimator.model_dir
中的一个图元文件并将其重命名为“ variables.meta”之后,评估似乎可以正常工作。
答案 3 :(得分:0)
我对Estimator
API还是陌生的,但我想我知道您在寻找什么,尽管这同样令人讨厌。
从这个colab开始,这是一个玩具习惯Estimator
,上面有一些铃铛和口哨声:
from tensorflow.contrib import predictor
predict_fn = predictor.from_saved_model(<model_dir>)
predict_fn(pred_features) # pred_features corresponds to your input features
,此估算器都使用BestExporter
exporter = tf.estimator.BestExporter(
name="best_exporter",
serving_input_receiver_fn=serving_input_receiver_fn,
exports_to_keep=5
) # this will keep the 5 best checkpoints
以及在训练后仅导出模型:
est.export_savedmodel('./here', serving_input_receiver_fn)
如果让您感到困惑的是Estimator
API没有“正确”的方式来加载SavedModel
,我已经在GitHub上创建了issue。
但是,如果您尝试将其加载到其他设备上,请参阅我的其他问题:
解决设备放置问题,还有其他GitHub问题
简而言之,目前,如果您使用Estimator
导出程序进行导出,则在您训练Estimator
时要使用的设备就是必须可用的设备。如果您将model_fn
设置为clear_devices
,并在{{1}}中手动导出Estimator
,那么您应该一切顺利。目前,在导出模型之后,似乎没有办法更改此设置。