我使用tf.estimator
和tf.data.TFRecordDataset
训练了一个cnn模型,它在model_fn
函数中定义了一个模型并在input_fn
函数中输入。同时使用one-shot iterator一次获取一个批处理示例。
现在我在一个目录中训练了模型文件(ckpt,meta,index)。我想要做的是根据训练的模型预测图像的标签,而无需再次进行培训和评估。图像可以是numpy数组,但不可能是TFRecords文件(在跟踪时使用)。
在整天尝试后,我找不到有效的解决方案。我只能得到权重和偏差的值,不知道如何使我的预测图像和模型兼容。
仅供参考,我的培训代码为here。
类似的问题是Prediction from model saved with tf.estimator.Estimator
in Tensorflow
,但没有接受的答案,我的模型输入正在使用数据集api。
所以真的需要帮助。感谢。
答案 0 :(得分:1)
我已经回答了类似的问题here。
要使用自定义输入进行预测,您需要使用估算器的内置predict
方法:
estimator = tf.estimator.Estimator(model_fn, ...)
predict_input_fn = ... # define this using tf.data
predict_results = estimator.predict(predict_input_fn)
for idx, prediction in enumerate(predict_results):
print(idx)
for key in prediction:
print("...{}: {}".format(key, prediction[key]))