I have a computationally expensive way of simulating data, encoded as a generator, which feeds a TensorFlow neural-network model. I would like to generate batches faster by parallelizing the calls to the generator. My current input pipeline looks like this:
import tensorflow as tf

def data_iterator():
    # data generation procedure to be parallelized
    pass

dataset = tf.data.Dataset.from_generator(data_iterator,
                                         (tf.float32, tf.float32),
                                         (tf.TensorShape([HEIGHT, None, 1]),
                                          tf.TensorShape([2])))
dataset = dataset.padded_batch(BATCH_SIZE,
                               padded_shapes=(tf.TensorShape([HEIGHT, None, 1]),
                                              tf.TensorShape([2])))
iterator = dataset.make_one_shot_iterator()
x_image, y_ = iterator.get_next()
Unfortunately, I have found that batching data with the from_generator() approach cannot easily be parallelized: there is no way to specify how many threads pull from the generator to build a batch. I tried the solution from this thread (https://stackoverflow.com/a/47089278), which wraps the generator and maps over it in parallel, and found no speedup at all. Previously, I fed data in through a queue, like this:
import threading

class CustomRunner(object):
    """
    This class manages the background threads needed to fill
    a queue full of data.
    """
    def __init__(self):
        self.dataX = tf.placeholder(dtype=tf.float32, shape=[HEIGHT, WIDTH, 1])
        self.dataY = tf.placeholder(dtype=tf.float32, shape=[2])
        # The actual queue of data. Each element is an image of
        # features and a length-2 label.
        self.queue = tf.FIFOQueue(shapes=[[HEIGHT, WIDTH, 1], [2]],
                                  dtypes=[tf.float32, tf.float32],
                                  capacity=QUEUE_CAPACITY)
        self.enqueue_op = self.queue.enqueue([self.dataX, self.dataY])

    def get_inputs(self, batch_size):
        """
        Returns tensors containing a batch of images and labels.
        """
        images_batch, labels_batch = self.queue.dequeue_many(batch_size)
        return images_batch, labels_batch

    def thread_main(self, sess):
        """
        Function run on an alternate thread. Basically, keep adding data to the queue.
        """
        for dataX, dataY in data_iterator():
            sess.run(self.enqueue_op, feed_dict={self.dataX: dataX, self.dataY: dataY})

    def start_threads(self, sess, n_threads=1):
        """Start background threads to feed the queue."""
        threads = []
        for n in range(n_threads):
            t = threading.Thread(target=self.thread_main, args=(sess,))
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=24,
                                        inter_op_parallelism_threads=24))
I had to move away from this setup because I could not handle variable-width data with it. Does anyone know of a good solution to this?