使用“ tf.data”有效地计算TFRecords中的示例数?

时间:2018-11-07 00:56:50

标签: tensorflow tensorflow-datasets

我正在编写一些简单的代码来计算TFRecord文件中的示例数。以前我使用的是这样的东西:

def _count_example(path):
    return sum(1 for _ in tf.python_io.tf_record_iterator(path))

def _count_total_examples(tfrecord_glob_pattern):
    example_count = 0
    paths = glob.glob(tfrecord_glob_pattern)

    with futures.ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        results = executor.map(_count_example, paths, chunksize=max(len(paths) // os.cpu_count(), 1))
        for result in results:
            example_count += result

    return example_count

我已将测试TFRecord输入复制到RAM磁盘上,以确保我们不受I / O的限制。我的测试TFRecord文件的635,424个图像文件的128个分片。上面的代码在13秒内完成。

我正在尝试将相同的代码更改为使用tf.data,以查看是否可以获得类似的结果。所以我正在做类似的事情:

def _count_total_examples(tfrecord_glob_pattern):
    dataset = tf.data.Dataset.list_files(tfrecord_glob_pattern, shuffle=False)
    dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=os.cpu_count())
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with tf.Session() as sess:
        example_count = 0
        for _ in trange(635424):
            sess.run(next_element)
            example_count += 1
    return example_count

此版本需要2分钟以上才能完成。我猜这是因为我打电话sess.run()的次数太多了,而且这开销很小。为了验证这一说法,我将上面的代码更改为:

def _count_total_examples(tfrecord_glob_pattern):
    dataset = tf.data.Dataset.list_files(tfrecord_glob_pattern, shuffle=False)
    dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=os.cpu_count())
    dataset = dataset.map(lambda _: 1, num_parallel_calls=os.cpu_count())
    dataset = dataset.batch(4096)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with tf.Session() as sess:
        example_count = 0
        for _ in trange(635424 // 4096):
            sess.run(next_element)
            example_count += 1
    return example_count

这次代码在30秒内完成-仍然不理想,但我想更好。基本上,根据我的上述观察,我有两个问题:

1)解决此问题的有效方法是什么?还是我应该继续使用tf.python_io.tf_record_iterator

2)是否可以在TensorFlow图中“循环整个数据集,然后返回最终结果”?我已经尝试过类似tf.contrib.data.Reducer的方法,但是它似乎仍然很慢(大约需要一分钟才能完成):

def _count_total_examples(tfrecord_glob_pattern):
    dataset = tf.data.Dataset.list_files(tfrecord_glob_pattern, shuffle=False)
    dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=os.cpu_count() * 2)
    reducer = tf.contrib.data.Reducer(
        init_func=lambda _: 0,
        reduce_func=lambda curr, _: curr + 1,
        finalize_func=lambda curr: curr
    )
    dataset = tf.contrib.data.reduce_dataset(dataset, reducer)
    with tf.Session() as sess:
        result = sess.run(dataset)
        example_count = result
    return example_count

0 个答案:

没有答案