我将 TensorFlow 用于强化学习项目,在该项目中,我创建了一组模型,用于预测大型输入数据集的输出。在评估每个模型后,计算分数,并将分数最低的模型从总体中移除,并替换为最佳网络的调整参数版本。
我可以通过在群体的个体成员上运行数据集来使用批量预测,但我无法避免在群体的每个成员上运行 model.predict
的循环:
for model in population:
outputs = model.predict(inputs, batch_size=BATCH_SIZE)
由于 model.predict
是 tf.function
,因此每次运行时都会收到以下警告:
WARNING:tensorflow:5 out of the last 15 calls to <function Model.make_predict_function
<locals>.predict_function at 0x7f99d5e41730> triggered tf.function retracing. Tracing is
expensive and the excessive number of tracings could be due to (1) creating @tf.function
repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects
instead of tensors. For (1), please define your @tf.function outside of the loop. For (2),
@tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can
avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guid
/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for
more details.
我对 TensorFlow 的经验不是那么丰富,但据我所知,tf.function
依赖于将函数跟踪到图形中,以便更快地运行。此外,通过在每个模型上多次运行 model.predict
,我在回溯每个模型时使用了大量计算能力。
有没有办法在多个模型上使用批量训练?如果我每次调用该函数时都收到此警告的垃圾邮件,我肯定会觉得我做错了什么。目前我已经禁用了 TensorFlow 警告日志记录,但我仍然对这种情况感到不安。