如何使用Estimator和Dataset API训练的保存模型进行预测?

时间:2017-12-12 02:29:27

标签: python tensorflow machine-learning deep-learning tensorflow-estimator

我使用tf.estimatortf.data.TFRecordDataset训练了一个cnn模型,它在model_fn函数中定义了一个模型并在input_fn函数中输入。同时使用one-shot iterator一次获取一个批处理示例。

现在我在一个目录中训练了模型文件(ckpt,meta,index)。我想要做的是根据训练的模型预测图像的标签,而无需再次进行培训和评估。图像可以是numpy数组,但不可能是TFRecords文件(在跟踪时使用)。

在整天尝试后,我找不到有效的解决方案。我只能得到权重和偏差的值,不知道如何使我的预测图像和模型兼容。

仅供参考,我的培训代码为here

类似的问题是Prediction from model saved with tf.estimator.Estimator in Tensorflow ,但没有接受的答案,我的模型输入正在使用数据集api。

所以真的需要帮助。感谢。

1 个答案:

答案 0 :(得分:1)

我已经回答了类似的问题here

要使用自定义输入进行预测,您需要使用估算器的内置predict方法:

estimator = tf.estimator.Estimator(model_fn, ...)

predict_input_fn = ...  # define this using tf.data

predict_results = estimator.predict(predict_input_fn)
for idx, prediction in enumerate(predict_results):
    print(idx)
    for key in prediction:
        print("...{}: {}".format(key, prediction[key]))