使用数据集固定大小批次(可能丢弃最后一批)

时间:2017-11-20 14:50:01

标签: python tensorflow

我想知道在使用Dataset时如何强制使用包含固定数量样本的批次。

例如,

import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.range(101).batch(10)
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

sess = tf.InteractiveSession()
try:
  while True:
    print(batch.eval().shape)
except tf.errors.OutOfRangeError:
  pass

在这个玩具示例中,数据共有101个样本,我要求批量10个样本。迭代时,最后一批的大小为1,这是我想要避免的。

在前一个(基于队列的)API中,tf.train.batch具有allow_smaller_final_batch参数,默认情况下设置为False。我想用Dataset重现此行为。

我想我可以使用Dataset.filter

dataset = tf.data.Dataset.range(101).batch(10)
  .filter(lambda x: tf.equal(tf.shape(x)[0], 10))

但肯定应该有一些内置的方法来做到这一点?

2 个答案:

答案 0 :(得分:1)

您可以使用tf.contrib.data.batch_and_drop_remainder(batch_size)执行此操作:

dataset = tf.data.Dataset.range(101).apply(
    tf.contrib.data.batch_and_drop_remainder(10))

答案 1 :(得分:1)

对于tensorflow>=2.0.0,您可以将drop_remainder的方法batch的{​​{1}}参数用作:

tf.data.Dataset

dataset = tf.data.Dataset.batch(BATCH_SIZE, drop_remainder=True) 参数设置是否删除最后一批(如果其少于drop_remainder个元素)。默认值为False。

我希望这对2019年以后的读者有帮助