TensorFlow Dataset 3x slower than queue pipeline?

Asked: 2018-04-06 14:13:51

Tags: python tensorflow-datasets

I'm struggling with the speed of the TensorFlow Dataset API. I used the following code to benchmark the Dataset pipeline against the queue pipeline:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import time

import tensorflow as tf


def disk_image_batch_dataset(img_paths, batch_size, shuffle=True, buffer_batch=128, repeat=-1):

    def parse_func(path):
        img = tf.read_file(path)
        img = tf.image.decode_png(img)
        return img

    dataset = tf.data.Dataset.from_tensor_slices(img_paths).map(parse_func)

    dataset = dataset.batch(batch_size)
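    # note: shuffle() after batch() shuffles whole batches, not individual images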

    if shuffle:
        dataset = dataset.shuffle(buffer_batch)
    else:
        dataset = dataset.prefetch(buffer_batch)

    dataset = dataset.repeat(repeat)

    # make iterator
    iterator = dataset.make_one_shot_iterator()

    return iterator.get_next()


def disk_image_batch_queue(img_paths, batch_size, shuffle=True, num_threads=4, min_after_dequeue=100):

    _, img = tf.WholeFileReader().read(
        tf.train.string_input_producer(img_paths, shuffle=shuffle, capacity=len(img_paths)))
    img = tf.image.decode_png(img)
    img.set_shape([218, 178, 3])

    # batch the data
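    # shuffle_batch's num_threads workers each run the read+decode subgraph
    # above, so PNG decoding happens on several threads in parallel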
    if shuffle:
        capacity = min_after_dequeue + (num_threads + 1) * batch_size
        img_batch = tf.train.shuffle_batch([img],
                                           batch_size=batch_size,
                                           capacity=capacity,
                                           min_after_dequeue=min_after_dequeue,
                                           num_threads=num_threads)
    else:
        img_batch = tf.train.batch([img], batch_size=batch_size)

    return img_batch


paths = glob.glob('img_align_celeba/*.jpg')

with tf.Session() as sess:
    with tf.device('/cpu:0'):
        batch = disk_image_batch_dataset(paths, 128, shuffle=True, buffer_batch=128, repeat=-1)
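        # the first timed round includes one-time warm-up (filling buffers), hence slower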
        for _ in range(10):
            start = time.time()
            for _ in range(100):
                sess.run(batch)
            elapse = time.time() - start
            print('Dataset Average: %f ms' % (elapse / 100.0 * 1000))

        batch = disk_image_batch_queue(paths, 128, shuffle=True, num_threads=4, min_after_dequeue=100)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for _ in range(10):
            start = time.time()
            for _ in range(100):
                sess.run(batch)
            elapse = time.time() - start
            print('Queue Average: %f ms' % (elapse / 100.0 * 1000))
        coord.request_stop()
        coord.join(threads)

The output is:

Dataset Average: 98.434892 ms
Dataset Average: 38.566129 ms
Dataset Average: 39.068711 ms
Dataset Average: 39.103138 ms
Dataset Average: 39.217770 ms
Dataset Average: 39.106920 ms
Dataset Average: 38.595331 ms
Dataset Average: 38.467741 ms
Dataset Average: 40.517910 ms
Dataset Average: 40.987079 ms
Queue Average: 44.756029 ms
Queue Average: 14.821191 ms
Queue Average: 14.946148 ms
Queue Average: 14.817519 ms
Queue Average: 14.691849 ms
Queue Average: 15.771849 ms
Queue Average: 16.030011 ms
Queue Average: 14.827731 ms
Queue Average: 12.955391 ms
Queue Average: 12.969120 ms

How can I improve the Dataset pipeline's performance?

1 Answer:

Answer 0 (score: 0)

I found the solution: we can use multiple threads in the map function.

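A minimal sketch of that change, assuming TensorFlow 1.4+ where map() accepts a num_parallel_calls argument; the thread count of 4 mirrors the queue pipeline's num_threads, and applying prefetch() unconditionally (my choice, not part of the original question) keeps decoded batches buffered ahead of the training step:

def disk_image_batch_dataset(img_paths, batch_size, shuffle=True, buffer_batch=128, repeat=-1, num_threads=4):

    def parse_func(path):
        img = tf.read_file(path)
        img = tf.image.decode_png(img)
        return img

    dataset = tf.data.Dataset.from_tensor_slices(img_paths)
    # run parse_func on num_threads threads so PNG decoding happens in parallel
    dataset = dataset.map(parse_func, num_parallel_calls=num_threads)

    dataset = dataset.batch(batch_size)

    if shuffle:
        dataset = dataset.shuffle(buffer_batch)

    # buffer decoded batches regardless of shuffling
    dataset = dataset.prefetch(buffer_batch)

    dataset = dataset.repeat(repeat)

    iterator = dataset.make_one_shot_iterator()

    return iterator.get_next()

With the decode work spread across threads, the Dataset timings should come much closer to the queue pipeline's, since the single-threaded map was the likely bottleneck.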