在TensorFlow 2.0中,如何查看数据集中的元素数量?

时间:2019-05-29 23:45:38

标签: python tensorflow dataset tensorflow-datasets tensorflow2.0

当我加载数据集时,我想知道是否有任何快速方法可以找到该数据集中的样品或批次数量。我知道如果我用with_info=True加载数据集,我可以看到例如total_num_examples=6000,,但是如果我拆分数据集,此信息将不可用。

目前,我按以下方式计算样本数量,但想知道是否有更好的解决方案:

train_subsplit_1, train_subsplit_2, train_subsplit_3 = tfds.Split.TRAIN.subsplit(3)

cifar10_trainsub3 = tfds.load("cifar10", split=train_subsplit_3)

cifar10_trainsub3 = cifar10_trainsub3.batch(1000)

n = 0
for i, batch in enumerate(cifar10_trainsub3.take(-1)):
    print(i, n, batch['image'].shape)
    n += len(batch['image'])

print(i, n)

1 个答案:

答案 0 :(得分:1)

如果有可能知道长度,则可以使用:

tf.data.experimental.cardinality(dataset)

但是问题是TF数据集固有地延迟加载。因此,我们可能不知道数据集的大小。确实,完全有可能使数据集表示无限的数据集!

如果数据集足够小,您也可以对其进行迭代以获取长度。我之前使用了以下丑陋的小结构,但这取决于数据集足够小,我们可以高兴地将其加载到内存中,并且实际上这并不是您上面的for循环的改进!

dataset_length = [i for i,_ in enumerate(dataset)][-1] + 1