将数据加载到Tensorflow进行实时推理的最有效方法是什么?

时间:2019-02-15 20:23:57

标签: python-3.x tensorflow gpu

设置到Tensorflow的数据输入管道(网络摄像头图像)时,要花费大量时间将数据从系统RAM加载到GPU内存。

我正在尝试通过对象检测网络提供恒定的图像流(1024x1024)。我目前正在AWS上使用V100进行推理。

第一次尝试是通过简单的feed dict操作。

# Get layers
img_input_tensor = sess.graph.get_tensor_by_name('import/input_image:0')
img_anchors_input_tensor = sess.graph.get_tensor_by_name('import/input_anchors:0')
img_meta_input_tensor = sess.graph.get_tensor_by_name('import/input_image_meta:0')
detections_input_tensor = sess.graph.get_tensor_by_name('import/output_detections:0')

detections = sess.run(detections_input_tensor,
                 feed_dict={img_input_tensor: molded_image, img_meta_input_tensor: image_meta, img_anchors_input_tensor: image_anchor})

每张图像的推理时间约为0.06毫秒。

但是,在阅读Tensorflow手册后,我注意到建议使用tf.data API来加载数据以进行推理。

# setup data input
data = tf.data.Dataset.from_tensors((img_input_tensor, img_meta_input_tensor, img_anchors_input_tensor, detections_input_tensor))
iterator = data.make_initializable_iterator()  # create the iterator
next_batch = iterator.get_next()

# load data
sess.run(iterator.initializer,
                 feed_dict={img_input_tensor: molded_image, img_meta_input_tensor: image_meta, img_anchors_input_tensor: image_anchor})

# inference
detections = sess.run([next_batch])[0][3]

这将推理时间加快到0.01ms,将加载数据所花费的时间花了0.1ms。此Iterator方法比“较慢的” feed_dict方法要长得多。我可以做些什么来加快加载过程吗?

1 个答案:

答案 0 :(得分:0)

Here是有关数据管道优化的出色指南。我个人发现.prefetch方法是增加输入管道的最简单方法。但是,本文提供了更高级的技术。

但是,如果您的输入数据不在tfrecords中,而是您自己输入,则必须以某种方式自己实现所描述的技术(缓冲,交错操作)。