我想获取我的tf.data.Dataset的长度。 (数据大小/批处理大小)
在Pytorch中,我可以通过简单的代码获得此信息:
length = len(data_loader)
但是,它在tensorflow 2.0中不起作用。
我怎么得到这个?
答案 0 :(得分:2)
在TensorFlow 2.0中,您创建了一个tf.data.Dataset
对象,它是Python的可迭代对象。
在循环遍历所有元素之前,您不会事先知道数据集中有多少个元素。
因此,假设您以这种方式创建了数据集:
batch_size = 12
dataset = tf.data.Dataset.from_tensor_slices(something).batch(batch_size)
您可以通过这种方式获得批次总数:
number_of_batches = len([_ for _ in iter(dataset)])