我了解您可以为数据集分配批处理大小并返回一个新的数据集对象。给定数据集对象时,是否有API可以查询批量大小?
我试图在以下位置找到电话
答案 0 :(得分:1)
我不知道您是否可以将其作为属性获取,但是您可以只遍历数据集一次并打印形状:
# create a simple tf.data.Dataset with batchsize 3
import tensorflow as tf
f = tf.data.Dataset.range(10).batch(3) # Dataset with batch_size 3
# iterating once
for one_batch in f:
print('batch size:', one_batch.shape[0])
break
如果您知道数据集也具有目标/标签,则必须进行以下迭代:
# iterating once
for one_batch_x, one_batch_y in f:
print('batch size:', one_batch_x.shape[0])
break
在两种情况下,它将打印:
batch size: 3
答案 1 :(得分:1)
当您调用.batch(32)
方法时,它将返回一个tensorflow.python.data.ops.dataset_ops.BatchDataset
对象。如Tensorflow Documentation中所述,此类对象具有称为._batch_size
的私有属性,其中包含一个batch_size张量。
在tensorflow 2.X中,您只需要调用此张量的.numpy()
方法即可将其转换为numpy.int64
类型。
在tensorflow 1.X中,您需要校准.eval()
方法。
答案 2 :(得分:0)
在Tensorflow 1. *中,通过batch_size
访问dataset._dataset._batch_size
:
import tensorflow as tf
import numpy as np
print(tf.__version__) # 1.14.0
dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)
with tf.compat.v1.Session() as sess:
batch_size = sess.run(dataset._dataset._batch_size)
print(batch_size) # 10
在Tensorflow 2中,您可以通过dataset._batch_size
访问:
import tensorflow as tf
import numpy as np
print(tf.__version__) # 2.0.1
dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)
batch_size = dataset._batch_size.numpy()
print(batch_size) # 10