TensorFlow中迭代器之间的复用

时间:2018-11-12 23:21:20

标签: python tensorflow

功能上我需要的:我的数据集按块分区,每个块都位于一个二进制文件中。 我有一种算法,该算法也可对块进行操作以降低计算复杂度,然后在访问所有块后将结果合并在一起。重要的是,要有一个单个的小批量数据源自一个块,并确切地知道哪个块能够将某些参数传递到特定于该特定块的图形中。下一次迭代,当再次从块0开始时,应使用所有块的下一个小批量。块可以具有不等长,并且应该永远重复。

我当前的解决方案:当前,我为每个块(即每个文件)创建一个tf.Iterator,并使用tf.data.FixedLengthRecordDataset创建一个

// for every file:
ds = tf.dataFixedLengthRecordDataset(...)
ds = ds.repeat()
ds = ds.batch(...)
ds = ds.map(...)
ds = ds.prefetch(buffer_size=1)

it = ds.make_one_shot_iterator()

然后,我有一个“主”迭代器,它在文件级迭代器之间进行多路复用。这是通过以下方式完成的:

itr_handle = tf.placeholder(tf.string, shape=())
master_itr = tf.data.Iterator.from_string_handle(itr_handle, output_types)
master_next = master_itr.get_next()

因此,每次执行图形时,我都会将要用于此执行的相应迭代器的字符串句柄传递给占位符。这样,每个文件级迭代器仍具有自己的状态;因此,当要求相同的块文件进行下一个迷你批处理时,它会有效地返回下一个迷你批处理,而不是重新打开文件,而只是简单地再次返回第一个迷你批处理。

问题:创建文件级迭代器的过程缓慢。每个文件创建一个迭代器至少需要200毫秒。我使用的数据集可以轻松包含多达100个块文件,这导致TensorFlow / Python停在那里,使这些Iterator对象和图形节点停留20秒钟,而实际上未处理任何数据。

问题:

  1. 例如仅使用一个迭代器来解决此问题的另一种方法?
  2. 否则,如何加快Iterator的创建速度?

0 个答案:

没有答案