有没有一种方法可以找到tf.data.Dataset的批处理大小

时间:2020-03-23 18:45:46

标签: tensorflow

我了解您可以为数据集分配批处理大小并返回一个新的数据集对象。给定数据集对象时,是否有API可以查询批量大小?

我试图在以下位置找到电话

https://www.tensorflow.org/api_docs/python/tf/data/Dataset

3 个答案:

答案 0 :(得分:1)

我不知道您是否可以将其作为属性获取,但是您可以只遍历数据集一次并打印形状:

# create a simple tf.data.Dataset with batchsize 3
import tensorflow as tf 
f = tf.data.Dataset.range(10).batch(3) # Dataset with batch_size 3

# iterating once
for one_batch in f:
    print('batch size:', one_batch.shape[0])
    break

如果您知道数据集也具有目标/标签,则必须进行以下迭代:

# iterating once
for one_batch_x, one_batch_y in f:
    print('batch size:', one_batch_x.shape[0])
    break

在两种情况下,它将打印:

batch size:  3

答案 1 :(得分:1)

当您调用.batch(32)方法时,它将返回一个tensorflow.python.data.ops.dataset_ops.BatchDataset对象。如Tensorflow Documentation中所述,此类对象具有称为._batch_size的私有属性,其中包含一个batch_size张量。

在tensorflow 2.X中,您只需要调用此张量的.numpy()方法即可将其转换为numpy.int64类型。 在tensorflow 1.X中,您需要校准.eval()方法。

答案 2 :(得分:0)

在Tensorflow 1. *中,通过batch_size访问dataset._dataset._batch_size

import tensorflow as tf
import numpy as np
print(tf.__version__) # 1.14.0

dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)

with tf.compat.v1.Session() as sess:
    batch_size = sess.run(dataset._dataset._batch_size)
    print(batch_size) # 10

在Tensorflow 2中,您可以通过dataset._batch_size访问:

import tensorflow as tf
import numpy as np
print(tf.__version__) # 2.0.1

dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)

batch_size = dataset._batch_size.numpy()

print(batch_size) # 10