如何获取tf.data.Dataset的长度(data_size / batch_size)?

时间:2019-10-13 16:08:07

标签: tensorflow tensorflow2.0

我想获取我的tf.data.Dataset的长度。 (数据大小/批处理大小)

在Pytorch中,我可以通过简单的代码获得此信息:

length = len(data_loader)

但是,它在tensorflow 2.0中不起作用。

我怎么得到这个?

1 个答案:

答案 0 :(得分:2)

在TensorFlow 2.0中,您创建了一个tf.data.Dataset对象,它是Python的可迭代对象。

在循环遍历所有元素之前,您不会事先知道数据集中有多少个元素。

因此,假设您以这种方式创建了数据集:

batch_size = 12
dataset = tf.data.Dataset.from_tensor_slices(something).batch(batch_size)

您可以通过这种方式获得批次总数:

number_of_batches = len([_ for _ in iter(dataset)])