Tensorflow加速预测

时间:2018-04-22 10:38:49

标签: python tensorflow machine-learning

我的预测性能有问题。我所做的是在Python循环中反复调用test_predictions op并将其所有返回值放入列表中。代码如下所示:

predictions = []
for _ in trange(args.num_batches):
    predictions.extend(sess.run(model.test_predictions))

当我查看性能统计超过2/3的时间时,我的GPU卡空闲,可能是因为Python和TF代码之间不断切换。我不能让批量更大,因为它不适合记忆。我可以实施更好的解决方案吗?

1 个答案:

答案 0 :(得分:0)

没有"在Python和TF代码之间切换"。如果GPU空闲很多,这意味着您获取数据(图像?)来运行预测需要很长时间,而GPU必须等待很多数据才能到达。

尝试实施预取。

或者,如果您有足够的内存,只需一次读取所有图像并以此方式提供网络。