我有一个生成器,可以生成无限量的数据(随机图像裁剪)。我想基于10,000个第一个数据点创建一个tf.Dataset
并将其缓存以使用它们来训练模型?
当前,我有一个生成器,该生成器需要1-2秒来创建每个数据点,这是主要的性能阻止器。我必须等待一分钟才能生成一批64张图像(preprocessing()
函数非常昂贵,因此我想重用结果)。
ds = tf.Dataset.from_generator()
方法允许我们创建这样的无限数据集。相反,我想使用生成器的N个优先输出来创建有限数据集,并将其缓存为:
ds = ds.cache()
。
另一种解决方案是继续生成新数据,并在呈现生成器时使用缓存的数据点。
答案 0 :(得分:1)
您可以将Dataset.cache
函数与Dataset.take
函数配合使用。
如果所有内容都适合内存,就像做这样的事情一样简单:
def generate_example():
i = 0
while(True):
print ('yielding value {}'.format(i))
yield tf.random.uniform((64,64,3))
i +=1
ds = tf.data.Dataset.from_generator(generate_example, tf.float32)
first_n_datapoints = ds.take(n).cache()
现在请注意,如果我将n
设置为3,请执行以下简单操作:
for i in first_n_datapoints.repeat():
print ('')
print (i.shape)
然后我看到输出确认已缓存前三个值(对于生成的前三个值,我只看到yielding value {i}
输出一次:
yielding value 0
(64,64,3)
yielding value 1
(64,64,3)
yielding value 2
(64,64,3)
(64,64,3)
(64,64,3)
(64,64,3)
...
如果一切都无法容纳在内存中,那么我们可以将文件路径传递给缓存函数,在缓存中它将生成的张量缓存到磁盘。
此处有更多信息:https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache