对于小数据集,Estimator.train()和.predict()太慢

时间:2019-08-27 08:01:25

标签: python tensorflow-estimator

我正在尝试实现一个DQN,该DQN在同一模型上对Estimator.train()和其后的Estimator.predict()进行多次调用,每个示例都有少量示例。但是每次调用至少要花费几百毫秒到一秒的时间,这与1-20等小数字的示例数无关。

我认为这些延迟是由于重建图表并在每次调用时保存检查点而引起的。有没有办法在内存中保留相同的图形和参数,以进行快速的火车预测迭代,或者以其他方式加快速度?

1 个答案:

答案 0 :(得分:0)

转换为tf.keras.Model而不是Estimator,并使用tf.keras.Model.fit()代替Estimator.train()fit()不具有train()的固定延迟。 Keras predict()也不是。