恢复模型tf.estimator.DNNClassifier

时间:2018-06-04 10:09:18

标签: tensorflow

model = tf.estimator.DNNClassifier(feature_columns=feat_cols, hidden_units=[1024, 512, 256])
model.train(input_fn=input_func,steps=5000)

这创建检查点

我复出第2天;现在我从检查点需要我的模型;如何恢复?

sess=tf.Session()
saver = tf.train.import_meta_graph(file_path + "/" + "model.ckpt-1000.meta")
saver.restore(sess,tf.train.latest_checkpoint(file_path))
model = ????? -- how do I get my model back?

2 个答案:

答案 0 :(得分:0)

我也被困在同一个问题上一整天,关于如何从保存的检查点和图表中进行推理。几乎所有建议都是使用.meta文件加载图形,并将其冻结为某个.pb文件。我尝试了这些方法,他们在示例中工作,但由于某种原因我的情况失败了。而且,对TF来说是新手,也无法理解它。

所以,我只是进行了一些打击和试验(经过4-5小时的网络嗅探)以与训练期间相同的方式加载模型,令我惊讶的是它起作用了。 您需要培训脚本/估算工具或模型函数来加载模型。

test_data = tf.estimator.inputs.numpy_input_fn(<input_feed_dict/x=np.array(something)>, num_epochs=1, shuffle=False)
model = tf.estimator.Estimator(model_fn=<Your_Estimator_Function>, model_dir="<model_dir>")

predictions = model.predict(input_fn=test_data)

另一方面,在经历了这么多链接后,我可以说没有适用于高级张量流API的文档(特别是初学者)。所有保存和恢复Tensorflow模型教程都使用了低级Tensorflow API,它支持所有这些功能以及更多功能(优雅的会话/图表/摘要等)

在保存答案后看到评论,我认为这也是类似的事情,所以不确定是否保留我的答案。仍然保持它

答案 1 :(得分:0)

不确定我为什么挣扎。答案很简单。阅读关于检查点的文章:https://www.tensorflow.org/get_started/checkpoints

重新加载模型的非常简单的代码:

model_load = tf.estimator.DNNClassifier(feature_columns=feat_cols, hidden_units=[10, 10, 10, 10], model_dir="C:/Users/AI101~1/AppData/Local/Temp/tmpm2ndcvf_")