使用tf.data.Dataset.from_generator()并行批处理

时间:2018-06-20 23:08:17

标签: python tensorflow parallel-processing

我有一种计算昂贵的方法来模拟编码为生成器的数据,该数据要输入到Tensorflow的神经网络模型中。我想通过并行化对生成器的调用来快速生成一批数据。我当前的数据输入管道如下图所示:

def data_iterator():
     # data generation procedure to be parallelized
     pass

dataset = tf.data.Dataset.from_generator(data_iterator, 
                                        (tf.float32,tf.float32), 
                                        (tf.TensorShape([HEIGHT, None, 1]), 
                                         tf.TensorShape([2])))

dataset = dataset.padded_batch(BATCH_SIZE, 
                               padded_shapes=(tf.TensorShape([HEIGHT, None, 1]), 
                                              tf.TensorShape([2])))
iterator = dataset.make_one_shot_iterator()
x_image, y_ = iterator.get_next()

不幸的是,我发现使用from_generator()方法创建一批数据无法轻松并行化,而无法指定我使用多少线程从生成器中提取数据来创建批处理。我在此线程(https://stackoverflow.com/a/47089278)中尝试了一种解决方案,该线程使用包装函数并并行映射,发现速度根本没有增加。以前,我使用队列输入数据,如下所示:

class CustomRunner(object):
    """
    This class manages the the background threads needed to fill
        a queue full of data.
    """
    def __init__(self):
        self.dataX = tf.placeholder(dtype=tf.float32, shape=[HEIGHT,WIDTH,1])
        self.dataY = tf.placeholder(dtype=tf.float32, shape=[2])
        # The actual queue of data. The queue contains a vector for
        # the features, and a scalar label.
        self.queue = tf.FIFOQueue(shapes=[[HEIGHT,WIDTH,1],[2]],
                                           dtypes=[tf.float32, tf.float32],
                                           capacity=QUEUE_CAPACITY)
        self.enqueue_op = self.queue.enqueue([self.dataX, self.dataY])

    def get_inputs(self,batch_size):
        """
        Return's tensors containing a batch of images and labels
        """
        images_batch, labels_batch = self.queue.dequeue_many(batch_size)
        return images_batch, labels_batch

    def thread_main(self, sess):
        """
        Function run on alternate thread. Basically, keep adding data to the queue.
        """
        for dataX, dataY in data_iterator():
            sess.run(self.enqueue_op, feed_dict={self.dataX:dataX,  self.dataY:dataY})

    def start_threads(self, sess, n_threads=1):
        """ Start background threads to feed queue """
        threads = []
        for n in range(n_threads):
            t = threading.Thread(target=self.thread_main, args=(sess,))
            t.daemon = True # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=24, 
                 inter_op_parallelism_threads=24))

我不得不更改此设置,因为我无法使用可变宽度的字典。有人知道这是否有好的解决方案吗?

0 个答案:

没有答案