在生成分片的tfrecords时创建循环分片

时间:2019-08-23 19:15:32

标签: tensorflow computer-vision image-segmentation sharding tfrecord

我是tensorflow的新手,我正在研究tensorflow 1.14中的图像分割问题。当我尝试生成一个大的tfrecord文件时,我有一个庞大的数据集,并且生成tfrecords的过程非常缓慢。因此,我想创建tfrecords的“ n”个碎片。我找不到在线进行的方法。假设我有600张图片和600张蒙版。我想生成6个tfrecords碎片,每个以循环方式包含100张图像和100个蒙版。我想要的高级/ pseudo代码如下-

sharded_tf_record_writer:
create n TFRecordWriter
----> for each_item in n TFRecordWriter
      -----> write_example in round-robin fashion

我确实在线搜索,但是找不到相关答案。我不想使用Apache Beam进行分片。感谢您为实现这一目标而提出的任何想法/帮助/指导。

2 个答案:

答案 0 :(得分:0)

我曾在张量流数据集中的一个问题中问过同样的问题,而用户-Conchylicultor对此做出了回应-

  

写入由_TFRecordWriter完成。 Tfds将自动计算所需的分片数量,并在各个分片之间分发示例,但是每个分片都是按顺序编写的。   您无法控制分片的数量,它也会自动计算。

     

但是,由于示例没有在并行处理中进行处理,因此示例在分片之间分布的事实并不能使编写速度更快。如果要并行处理,则必须使用Apache Beam,它甚至可以扩展到庞大的数据集

张量流/数据集问题的链接是-https://github.com/tensorflow/datasets/issues/676

这可能会有所帮助。

答案 1 :(得分:0)

由于您在 tensorflow 中使用对象检测,因此官方 Tensorflow models 存储库中有一些不错的代码可以满足您的需求。请注意,此代码适用于 Tensorflow2(不确定它是否适用于 TF1)

请参阅从 coco 注释编写分片 tfrecords 的 example。这个想法是你在退出堆栈中打开一个 TFRecordWriter 列表(使用 contextlib2.ExitStack()),当每个线程完成写入时,它会自动关闭 TFRecords。

效用函数 open_sharded_output_tfrecords 函数创建这个 TFRecordWriter 列表

import contextlib2
import tensorflow as tf
with contextlib2.ExitStack() as tf_record_close_stack, tf.gfile.GFile(
    annotations_file, 'r'
) as fid:
    output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
        tf_record_close_stack, output_path, num_shards
    )

接下来,您可以使用 ProcessPoolExecutor 以循环方式并行将 tfrecords 写入每个分片(在此示例中为 4 个工作线程)

from concurrent.futures.process import ProcesPoolExecutor
with ProcessPoolExecutor(4) as executor:
    for idx, image in enumerate(images):
        futures = []
        future = executor.submit(
            _write_tf_record,
            image,
            idx,
            num_shards,
            output_tfrecords,
        )
        futures.append(future)
    for future in futures:
        future.result()

其中 _write_tf_record 可能如下所示:

def _write_tf_record(image, idx, num_shards, output_tfrecords)
    tf_example = create_tf_example(image)
    shard_idx = idx % num_shards
    output_tfrecords[shard_idx].write(tf_example.SerializeToString())

只要确保你有比多进程工作者更多的分片,否则两个不同的进程可能会访问同一个写入器。