如何从生成器创建固定长度的tf.Dataset?

时间:2019-06-08 23:32:49

标签: tensorflow tensorflow-datasets

我有一个生成器,可以生成无限量的数据(随机图像裁剪)。我想基于10,000个第一个数据点创建一个tf.Dataset并将其缓存以使用它们来训练模型?

当前,我有一个生成器,该生成器需要1-2秒来创建每个数据点,这是主要的性能阻止器。我必须等待一分钟才能生成一批64张图像(preprocessing()函数非常昂贵,因此我想重用结果)。

ds = tf.Dataset.from_generator()方法允许我们创建这样的无限数据集。相反,我想使用生成器的N个优先输出来创建有限数据集,并将其缓存为:

ds = ds.cache()


另一种解决方案是继续生成新数据,并在呈现生成器时使用缓存的数据点。

1 个答案:

答案 0 :(得分:1)

您可以将Dataset.cache函数与Dataset.take函数配合使用。

如果所有内容都适合内存,就像做这样的事情一样简单:

def generate_example():
  i = 0
  while(True):
    print ('yielding value {}'.format(i))
    yield tf.random.uniform((64,64,3))
    i +=1

ds = tf.data.Dataset.from_generator(generate_example, tf.float32)

first_n_datapoints = ds.take(n).cache()

现在请注意,如果我将n设置为3,请执行以下简单操作:

for i in first_n_datapoints.repeat():
  print ('')
  print (i.shape)

然后我看到输出确认已缓存前三个值(对于生成的前三个值,我只看到yielding value {i}输出一次:

yielding value 0
(64,64,3)
yielding value 1
(64,64,3)
yielding value 2
(64,64,3)
(64,64,3)
(64,64,3)
(64,64,3)
...

如果一切都无法容纳在内存中,那么我们可以将文件路径传递给缓存函数,在缓存中它将生成的张量缓存到磁盘。

此处有更多信息:https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache