Tensorflow数据集过滤器,然后有效地重复

时间:2020-02-03 20:29:09

标签: tensorflow tensorflow2.0 tensorflow-datasets

我有一个特殊情况。我有一个庞大的数组_size = 1000000。我里面有四个具有特定值1的元素,其余所有元素都是0。我应用filter以便获得新的dataset,当对它进行迭代时会得到以下四个元素。

import numpy as np
import tensorflow as tf
import datetime

# huge array
_size = 1000000
index = np.arange(_size)
data = np.zeros(_size, dtype=np.int)

# mutate four elements to 1
np.random.seed(1234)
for i in range(4):
    data[np.random.randint(0, _size)] = 1

# create dataset
dataset = tf.data.Dataset.from_tensor_slices({"index": index, "data": data})

# apply filter
print("\n\nDataset after filtering\n")
dataset = dataset.filter(lambda x: x["data"] == 1)

# loop over
_start = datetime.datetime.now()
for d in dataset.as_numpy_iterator():
    print(d, f"- time taken "
             f"{(datetime.datetime.now() - _start).total_seconds()}s")

这将导致以下输出:

Dataset after filtering

{'index': 165158, 'data': 1} - time taken 3.401963s
{'index': 451283, 'data': 1} - time taken 9.146955s
{'index': 486191, 'data': 1} - time taken 9.843954s
{'index': 908341, 'data': 1} - time taken 18.632945s

如您所见,filter操作花费了大约20秒才能获取所有四个元素。现在,我想应用重复,并希望通过某种预取机制可以使第二个重复变得聪明,这样就无需再次执行搜索。有可能吗?

请参见以下代码:

# now repeat
print("\n\nDataset after filtering and repeat for two times\n")
dataset = dataset.repeat(2)

# loop over
_start = datetime.datetime.now()
for d in dataset.as_numpy_iterator():
    print(d, f"- time taken "
             f"{(datetime.datetime.now() - _start).total_seconds()}s")

其输出为:

Dataset after filtering and repeat for two times

{'index': 165158, 'data': 1} - time taken 3.316989s
{'index': 451283, 'data': 1} - time taken 9.274983s
{'index': 486191, 'data': 1} - time taken 10.044986s
{'index': 908341, 'data': 1} - time taken 19.286973s
{'index': 165158, 'data': 1} - time taken 24.989968s
{'index': 451283, 'data': 1} - time taken 31.962961s
{'index': 486191, 'data': 1} - time taken 32.836961s
{'index': 908341, 'data': 1} - time taken 43.239948s

预期的前四个元素大约需要20秒。但是接下来的四个元素再次花费相同的时间,即总计约40秒。这确实意味着将再次扫描整个阵列。我可以通过使用一些缓存或预取来节省第二次重复的时间吗?

请注意,我们开始时使用的数据集不是无限的,并且结束了,那么为什么第二次重复进行过滤需要花费时间?

0 个答案:

没有答案