tf.data.Dataset:如何获取数据集大小(一个元素的元素数量)?

时间:2018-06-07 09:03:23

标签: python-3.x tensorflow tensorflow-datasets

我们说我已经用这种方式定义了一个数据集:

filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))

如何获取数据集内的元素数量(因此,构成纪元的单个元素的数量)?

我知道tf.data.Dataset已经知道数据集的维度,因为repeat()方法允许重复输入管道达到指定数量的纪元。因此,它必须是获取此信息的一种方式。

17 个答案:

答案 0 :(得分:5)

您可以将其用于TF2中的TFRecords:

ds = tf.data.TFRecordDataset(dataset_filenames)
ds_size = sum(1 for _ in ds)

答案 1 :(得分:4)

tf.data.Dataset.list_files创建一个名为MatchingFiles:0的张量(如果适用,使用适当的前缀)。

你可以评估

tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]

获取文件数。

当然,这仅适用于简单的情况,特别是如果每​​张图像只有一个样本(或已知数量的样本)。

在更复杂的情况下,例如当您不知道每个文件中的样本数量时,您只能观察到一个时期结束时的样本数量。

为此,您可以观看Dataset计算的时期数。 repeat()创建一个名为_count的成员,用于计算时期数。通过在迭代期间观察它,您可以发现它何时发生变化并从那里计算数据集大小。

这个计数器可能埋没在连续调用成员函数时创建的Dataset层次结构中,所以我们必须像这样挖掘它。

d = my_dataset
# RepeatDataset seems not to be exposed -- this is a possible workaround 
RepeatDataset = type(tf.data.Dataset().repeat())
try:
  while not isinstance(d, RepeatDataset):
    d = d._input_dataset
except AttributeError:
  warnings.warn('no epoch counter found')
  epoch_counter = None
else:
  epoch_counter = d._count

请注意,使用此技术时,数据集大小的计算并不精确,因为epoch_counter递增的批处理通常会混合来自两个连续历元的样本。所以这个计算精确到你的批次长度。

答案 2 :(得分:3)

不幸的是,我不相信TF中有这样的功能。使用TF 2.0并渴望执行,您可以遍历数据集:

num_elements = 0
for element in dataset:
    num_elements += 1

这是我想出的最有效的存储方式

确实感觉这是应该在很久以前添加的功能。手指交叉,他们在以后的版本中增加了长度功能。

答案 3 :(得分:3)

在这里看看:https://github.com/tensorflow/tensorflow/issues/26966

该功能不适用于TFRecord数据集,但适用于其他类型。

TL; DR:

  

num_elements = tf.data.experimental.cardinality(dataset).numpy()

答案 4 :(得分:2)

在TF2.0中,我这样做

for num, _ in enumerate(dataset):
    pass

print(f'Number of elements: {num}')

答案 5 :(得分:1)

我看到了很多获取样本数量的方法,但实际上您可以在 keras 中轻松完成:

len(dataset) * BATCH_SIZE

答案 6 :(得分:1)

您可以在 tensorflow 2.4.0 中使用 len(filename_dataset)

答案 7 :(得分:1)

从TensorFlow(> = 2.3)开始,可以使用:

!important

请注意,在应用 print(dataset.cardinality().numpy()) 操作时,此操作可以返回-2。

答案 8 :(得分:1)

对于某些数据集(例如COCO),基数函数不返回大小。快速计算数据集大小的一种方法是使用map reduce,例如:

ds.map(lambda x: 1, num_parallel_calls=tf.data.experimental.AUTOTUNE).reduce(tf.constant(0), lambda x,_: x+1)

答案 9 :(得分:1)

len(list(dataset))在渴望模式下工作,尽管显然这不是一个好的通用解决方案。

答案 10 :(得分:0)

晚了一点,但是对于存储在TFRecord数据集中的大型数据集,我使用了这个(TF 1.15)

hypernova({
  getComponent (name, { returnMeta }) {
    returnMeta.src = 'http://localhost:3000/public/client.js'
  }
})

答案 11 :(得分:0)

对于张量流数据集,您可以使用_, info = tfds.load(with_info=True)。然后,您可以致电info.splits['train'].num_examples。但是即使在这种情况下,如果您定义自己的拆分也无法正常工作。

因此您可以对文件进行计数或遍历数据集(如其他答案中所述):

num_training_examples = 0
num_validation_examples = 0

for example in training_set:
    num_training_examples += 1

for example in validation_set:
    num_validation_examples += 1

答案 12 :(得分:0)

这对我有用:

lengt_dataset = dataset.reduce(0, lambda x,_: x+1).numpy()

迭代数据集并增加var x,它作为数据集的长度返回。

答案 13 :(得分:0)

假设您要在oxford-iiit-pet数据集中找到训练分组的数量:

ds, info = tfds.load('oxford_iiit_pet', split='train', shuffle_files=True, as_supervised=True, with_info=True)

print(info.splits['train'].num_examples)

答案 14 :(得分:0)

以下代码可在TF2中使用:

var indexPath:[IndexPath] = []
for section in 0..<self.tableView.numberOfSections {
    for row in 0..<self.tableView.numberOfRows(inSection: section) {
        guard let cell = self.tableView.cellForRow(
                at: IndexPath(row: row, section: section)) as? MyCellType else {
            return
        }
        if myCheck { // do your check here
            indexPath.append(IndexPath(row: row, section: section))
        }
    }
}
if let first = indexPath.first {
    self.tableView.scrollToRow(at: first, at: .middle, animated: true)
}

答案 15 :(得分:0)

在 version=2.5.0 中,您只需调用 print(dataset.cardinality()) 即可查看数据集的长度和类型。

答案 16 :(得分:0)

我很惊讶这个问题没有明确的解决方案,因为这是一个如此简单的特性。当我通过 TQDM 迭代数据集时,我发现 TQDM 找到了数据大小。这是如何工作的?

for x in tqdm(ds['train']):
  //Something

-> 1%|          | 15643/1281167 [00:16<07:06, 2964.90it/s]v
t=tqdm(ds['train'])
t.total
-> 1281167