建议使用张量流数据集作为输入管道,可以按如下方式设置:
# Specify dataset
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# Suffle
dataset = dataset.shuffle(buffer_size=1e5)
# Specify batch size
dataset = dataset.batch(128)
# Create an iterator
iterator = dataset.make_one_shot_iterator()
# Get next batch
next_batch = iterator.get_next()
我应该能够获得批量大小(从数据集本身或从它创建的迭代器,即iterator
和next_batch
)。也许有人想知道数据集或其迭代器中有多少批次。或者已经调用了多少批次以及迭代器中有多少批次?人们可能也想要一次获得特定元素,甚至整个数据集。
我无法在tensorflow文档中找到任何内容。这可能吗?如果没有,是否有人知道这是否已被请求作为tensorflow GitHub上的问题?
答案 0 :(得分:1)
您只是自己指定了批量大小
dataset.batch(128)
。您的批次中有128
个示例。
答案 1 :(得分:1)
至少在 TF2 中,数据集的类型是静态定义的,可通过 tf.data.Dataset.element_spec
访问。
这是一个有点复杂的返回类型,因为它具有与您的数据集匹配的元组嵌套。
>>> tf.data.Dataset.from_tensor_slices([[[1]],[[2]]]).element_spec.shape
TensorShape([1, 1])
如果您的数据组织为元组[图像,标签],那么您将获得一个 TensorSpecs 元组。如果您确定返回类型的嵌套,则可以对其进行索引。例如
>>> image = tf.data.Dataset.from_tensor_slices([[1],[2],[3],[4]]).batch(2, drop_remainder=True)
>>> label = tf.data.Dataset.from_tensor_slices([[1],[2],[3],[4]]).batch(2, drop_remainder=True)
>>> train = tf.data.Dataset.zip((image, label))
>>> train.element_spec[0].shape[0]
2
答案 2 :(得分:0)
试试这个
import tensorflow as tf
import numpy as np
features=np.array([[3.0, 0.0], [1.0, 2.0], [0.0, 0.0]], dtype="float32")
labels=np.array([[0], [0], [1]], dtype="float32")
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
batch_size = 2
dataset = dataset.batch(batch_size)
iterator = dataset.make_initializable_iterator()
batch_data = iterator.get_next()
with tf.Session() as sess:
sess.run(iterator.initializer)
print(np.shape(sess.run(batch_data)[0])[0])

答案 3 :(得分:0)
在TF2中,tf.data.Dataset
是iterables,因此您只需执行以下操作即可获得批处理:
batch = next(iter(dataset))
然后计算批次大小是微不足道的,因为它变成了first dimension的大小:
batch_size = batch.shape[0]
因此,完整的示例如下:
# Specify dataset
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# Suffle
dataset = dataset.shuffle(buffer_size=1e5)
# Specify batch size
dataset = dataset.batch(128)
# Calculate and print batch size
batch_size = next(iter(dataset)).shape[0]
print('Batch size:', batch_size) # prints 128
或者,如果需要将其用作功能:
def calculate_batch_size(dataset):
return next(iter(dataset)).shape[0]
请注意,数据集上的iterating需要急切执行。此外,此解决方案假定您的数据集已批处理,如果不是这种情况,则可能会出错。如果在批处理之后对数据集执行其他更改其元素形状的操作,则也可能会遇到错误。