有没有办法用多个模型进行批量预测?

时间:2021-02-04 19:28:21

标签: python tensorflow machine-learning tf.keras

我将 TensorFlow 用于强化学习项目,在该项目中,我创建了一组模型,用于预测大型输入数据集的输出。在评估每个模型后,计算分数,并将分数最低的模型从总体中移除,并替换为最佳网络的调整参数版本。

我可以通过在群体的个体成员上运行数据集来使用批量预测,但我无法避免在群体的每个成员上运行 model.predict 的循环:

for model in population:
    outputs = model.predict(inputs, batch_size=BATCH_SIZE)

由于 model.predicttf.function,因此每次运行时都会收到以下警告:

WARNING:tensorflow:5 out of the last 15 calls to <function Model.make_predict_function
<locals>.predict_function at 0x7f99d5e41730> triggered tf.function retracing. Tracing is
expensive and the excessive number of tracings could be due to (1) creating @tf.function
repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects
instead of tensors. For (1), please define your @tf.function outside of the loop. For (2),
@tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can
avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guid
/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for 
more details.

我对 TensorFlow 的经验不是那么丰富,但据我所知,tf.function 依赖于将函数跟踪到图形中,以便更快地运行。此外,通过在每个模型上多次运行 model.predict,我在回溯每个模型时使用了大量计算能力。

有没有办法在多个模型上使用批量训练?如果我每次调用该函数时都收到此警告的垃圾邮件,我肯定会觉得我做错了什么。目前我已经禁用了 TensorFlow 警告日志记录,但我仍然对这种情况感到不安。

0 个答案:

没有答案