设置到Tensorflow的数据输入管道(网络摄像头图像)时,要花费大量时间将数据从系统RAM加载到GPU内存。
我正在尝试通过对象检测网络提供恒定的图像流(1024x1024)。我目前正在AWS上使用V100进行推理。
第一次尝试是通过简单的feed dict操作。
# Get layers
img_input_tensor = sess.graph.get_tensor_by_name('import/input_image:0')
img_anchors_input_tensor = sess.graph.get_tensor_by_name('import/input_anchors:0')
img_meta_input_tensor = sess.graph.get_tensor_by_name('import/input_image_meta:0')
detections_input_tensor = sess.graph.get_tensor_by_name('import/output_detections:0')
detections = sess.run(detections_input_tensor,
feed_dict={img_input_tensor: molded_image, img_meta_input_tensor: image_meta, img_anchors_input_tensor: image_anchor})
每张图像的推理时间约为0.06毫秒。
但是,在阅读Tensorflow手册后,我注意到建议使用tf.data
API来加载数据以进行推理。
# setup data input
data = tf.data.Dataset.from_tensors((img_input_tensor, img_meta_input_tensor, img_anchors_input_tensor, detections_input_tensor))
iterator = data.make_initializable_iterator() # create the iterator
next_batch = iterator.get_next()
# load data
sess.run(iterator.initializer,
feed_dict={img_input_tensor: molded_image, img_meta_input_tensor: image_meta, img_anchors_input_tensor: image_anchor})
# inference
detections = sess.run([next_batch])[0][3]
这将推理时间加快到0.01ms,将加载数据所花费的时间花了0.1ms。此Iterator
方法比“较慢的” feed_dict
方法要长得多。我可以做些什么来加快加载过程吗?
答案 0 :(得分:0)
Here是有关数据管道优化的出色指南。我个人发现.prefetch
方法是增加输入管道的最简单方法。但是,本文提供了更高级的技术。
但是,如果您的输入数据不在tfrecords中,而是您自己输入,则必须以某种方式自己实现所描述的技术(缓冲,交错操作)。