Tensorflow:如何查找tf.data.Dataset API对象的大小

时间:2018-06-19 01:17:13

标签: python tensorflow tensorflow-datasets

我理解Dataset API是一种迭代器,它不会将整个数据集加载到内存中,因此无法找到数据集的大小。我正在谈论存储在文本文件或tfRecord文件中的大型数据语料库。通常使用tf.data.TextLineDataset或类似的东西来读取这些文件。使用tf.data.Dataset.from_tensor_slices找到加载的数据集的大小是微不足道的。

我询问数据集大小的原因如下: 我们说我的数据集大小是1000个元素。批量大小= 50个元素。然后训练步骤/批次(假设1个纪元)= 20.在这20个步骤中,我想将我的学习率从0.1递减到0.01为

tf.train.exponential_decay(
    learning_rate = 0.1,
    global_step = global_step,
    decay_steps = 20,
    decay_rate = 0.1,
    staircase=False,
    name=None
)

在上面的代码中,我有"和"想设置decay_steps = number of steps/batches per epoch = num_elements/batch_size。只有在事先知道数据集中的元素数量时才能计算出这一点。

提前知道大小的另一个原因是使用tf.data.Dataset.take()tf.data.Dataset.skip()方法将数据拆分为训练集和测试集。

PS:我不是在寻找蛮力的方法,比如遍历整个数据集并更新计数器以计算元素数量或putting a very large batch size and then finding the size of the resultant dataset等。

3 个答案:

答案 0 :(得分:1)

我知道这个问题已有两年了,但是也许这个答案会有用。

如果您使用tf.data.TextLineDataset读取数据,那么获取样本数量的一种方法可能是计算所使用的所有文本文件中的行数。

请考虑以下示例:

import random
import string
import tensorflow as tf

filenames = ["data0.txt", "data1.txt", "data2.txt"]

# Generate synthetic data.
for filename in filenames:
    with open(filename, "w") as f:
        lines = [random.choice(string.ascii_letters) for _ in range(random.randint(10, 100))]
        print("\n".join(lines), file=f)

dataset = tf.data.TextLineDataset(filenames)

尝试使用len来获取长度会引发TypeError

len(dataset)

但是可以相对快速地计算文件中的行数。

# https://stackoverflow.com/q/845058/5666087
def get_n_lines(filepath):
    i = -1
    with open(filepath) as f:
        for i, _ in enumerate(f):
            pass
    return i + 1

n_lines = sum(get_n_lines(f) for f in filenames)

在上面,n_lines等于使用{p>遍历数据集时发现的元素数量。

for i, _ in enumerate(dataset):
    pass
n_lines == i + 1

答案 1 :(得分:0)

您是否可以手动指定数据集的大小?

我如何加载数据:

sample_id_hldr = tf.placeholder(dtype=tf.int64, shape=(None,), name="samples")

sample_ids = tf.Variable(sample_id_hldr, validate_shape=False, name="samples_cache")
num_samples = tf.size(sample_ids)

data = tf.data.Dataset.from_tensor_slices(sample_ids)
# "load" data by id:
# return (id, data) for each id
data = data.map(
    lambda id: (id, some_load_op(id))
)

在这里,您可以通过使用占位符初始化sample_ids一次来指定所有样品ID。
您的样品编号可能是文件路径或简单数字(np.arange(num_elems)

然后可以在num_samples中找到元素的数量。

答案 2 :(得分:0)

您可以使用轻松获得数据样本数量:

dataset.__len__()

您可以像这样获得每个元素:

for step, element in enumerate(dataset.as_numpy_iterator()):
...     print(step, element)

您还可以得到一个样品的形状:

dataset.element_spec

如果要使用特定元素,也可以使用分片方法。