Tensorflow估算器:每次都无需加载检查点即可进行预测

时间:2018-07-18 13:40:57

标签: python tensorflow tensorflow-estimator

我正在Tensorflow(1.8)和python3.6中使用estimator为我的Reinforcement学习项目建立神经网络。而且我注意到每次使用estimator.predict()时,tensorflow都会加载model_dir下的检查点。但是如果您必须对同一检查点多次使用此功能,则效率极低,例如在强化学习中,我可能需要根据当前状态预测下一个动作,并且只有在选择特定动作后才能实现下一个状态。因此,数千次调用此函数是很平常的事。

所以我的问题是,如何在不每次加载检查点(相同检查点)的情况下调用此函数。

谢谢。

1 个答案:

答案 0 :(得分:1)

好吧,我想我已经为自己的问题找到了一个很好的答案。解决此问题的一个好方法是通过生成器构造tf.dataset。 Link is here

Generator将使您的estimator.predict保持打开状态,这样您就无需继续加载检查点。唯一需要做的就是在必要时更改此fastpredict对象(在这种情况下为self.next_feature)中产生的对象。

但是,我需要提及的是,您的最终目标是使整个产品成为一种服务还是诸如此类。您可能需要类似tf.serving之类的东西。因此,我建议您直接采用这种方式。我在这个过程中浪费了很多时间。因此,我希望这个答案可以帮助您保存您的答案。