我正在编写一些简单的代码来计算TFRecord文件中的示例数。以前我使用的是这样的东西:
def _count_example(path):
return sum(1 for _ in tf.python_io.tf_record_iterator(path))
def _count_total_examples(tfrecord_glob_pattern):
example_count = 0
paths = glob.glob(tfrecord_glob_pattern)
with futures.ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
results = executor.map(_count_example, paths, chunksize=max(len(paths) // os.cpu_count(), 1))
for result in results:
example_count += result
return example_count
我已将测试TFRecord输入复制到RAM磁盘上,以确保我们不受I / O的限制。我的测试TFRecord文件的635,424个图像文件的128个分片。上面的代码在13秒内完成。
我正在尝试将相同的代码更改为使用tf.data
,以查看是否可以获得类似的结果。所以我正在做类似的事情:
def _count_total_examples(tfrecord_glob_pattern):
dataset = tf.data.Dataset.list_files(tfrecord_glob_pattern, shuffle=False)
dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=os.cpu_count())
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
example_count = 0
for _ in trange(635424):
sess.run(next_element)
example_count += 1
return example_count
此版本需要2分钟以上才能完成。我猜这是因为我打电话sess.run()
的次数太多了,而且这开销很小。为了验证这一说法,我将上面的代码更改为:
def _count_total_examples(tfrecord_glob_pattern):
dataset = tf.data.Dataset.list_files(tfrecord_glob_pattern, shuffle=False)
dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=os.cpu_count())
dataset = dataset.map(lambda _: 1, num_parallel_calls=os.cpu_count())
dataset = dataset.batch(4096)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
example_count = 0
for _ in trange(635424 // 4096):
sess.run(next_element)
example_count += 1
return example_count
这次代码在30秒内完成-仍然不理想,但我想更好。基本上,根据我的上述观察,我有两个问题:
1)解决此问题的有效方法是什么?还是我应该继续使用tf.python_io.tf_record_iterator
?
2)是否可以在TensorFlow图中“循环整个数据集,然后返回最终结果”?我已经尝试过类似tf.contrib.data.Reducer
的方法,但是它似乎仍然很慢(大约需要一分钟才能完成):
def _count_total_examples(tfrecord_glob_pattern):
dataset = tf.data.Dataset.list_files(tfrecord_glob_pattern, shuffle=False)
dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=os.cpu_count() * 2)
reducer = tf.contrib.data.Reducer(
init_func=lambda _: 0,
reduce_func=lambda curr, _: curr + 1,
finalize_func=lambda curr: curr
)
dataset = tf.contrib.data.reduce_dataset(dataset, reducer)
with tf.Session() as sess:
result = sess.run(dataset)
example_count = result
return example_count