批量大小恒定的tf.data.Dataset

时间:2018-11-19 21:17:33

标签: tensorflow tensorflow-datasets

我有一个包含19个元素且批处理大小为10的数据集。我将数据集设置为连续迭代相同的元素,但是我注意到最后一批仅包含4个元素而不是5个,然后从5个开始,5、5、4等。

如何强制迭代器用来自下一次迭代的元素填充较短的批次,以便所有批次具有相同的大小?

P.S。只是了解一下,这不是训练模型时的明显行为吗?

2 个答案:

答案 0 :(得分:1)

要具有此行为,应在.repeat()batch()之前调用padded_batch()方法。所以:

file_names = [...]
def my_map_func(record):
    ....
dataset = tf.data.TFRecordDataset(file_names)\
    .map(map_func=my_map_func)\
    .repeat()\  # here!
    .batch(5)

答案 1 :(得分:0)

为了进一步介绍repeatbatch的用法,我会这样说。实际上,如果您想要固定的batch_size,则放置.repeat()的位置并不重要。如果在drop_remainder=True中设置了集.batch(),则第一维将不会得到None,无论.repeat()的位置如何,批次大小都是固定的。重复和批处理非常直观地使用,例如,让我们定义4个数据集,以更改批处理和重复的位置:

import tensorflow as tf
dataset = tf.data.Dataset.range(3)

dataset1 = dataset.batch(2,drop_remainder=True)
dataset1 = dataset1.repeat()

dataset2 = dataset.repeat()
dataset2 = dataset2.batch(2,drop_remainder=True)

dataset3 = dataset.repeat()
dataset3 = dataset3.batch(2)

dataset4 = dataset.batch(2)
dataset4 = dataset4.repeat()

您将得到以下结果:

数据集1 :请注意,在任何批次中2都不存在

  • 形状:(2,):请注意,批次大小为2,而不是无
  • 第1批:[0,1]
  • 第1批:[0,1]

数据集2 :这是您想要实现的目标。注意第二批中现在有2个了

  • 形状:(2,):同样,批量大小也不是None,repeat的位置也不同
  • 第1批:[0,1]
  • 第1批:[2,0]

数据集3

  • shape:(?,):因为您没有使用drop_remainder=True,所以得到None(无),但是您将始终获得固定大小为2的批次。重复一次
  • 第1批:[0,1]
  • 第2批:[2,0]
  • 第3批:[1,2]

数据集4

  • shape(?,):与Dataset3相同
  • 第1批:[0,1]
  • 批次2:[2]您得到的批次大小“不完整”
  • 第3批:[0,1]