用于预测的Tensorflow导出估计器

时间:2018-05-30 20:31:06

标签: python tensorflow mnist

我想知道如何导出估算器,然后从MNIST教程Tensorflow's page导入它进行预测。 谢谢!

1 个答案:

答案 0 :(得分:1)

Estimatormodel_dir个args,用于保存模型。因此,在预测期间,我们使用Estimator并调用predict方法重新创建图表并加载检查点。

对于MNIST示例,预测代码为:

tf.reset_default_graph()

# An input-function to predict the class of new data.
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": eval_data},
    num_epochs=1,
    shuffle=False)

mnist_classifier = tf.estimator.Estimator(
      model_fn=cnn_model_fn, model_dir="/tmp/mnist_convnet_model")

#Prediction call
predictions = mnist_classifier.predict(input_fn=predict_input_fn)

pred_class = np.array([p['classes'] for p in predictions]).squeeze()
print(pred_class)

# Output
# [7 2 1 ... 4 5 6]