我想知道在使用Dataset
时如何强制使用包含固定数量样本的批次。
例如,
import numpy as np
import tensorflow as tf
dataset = tf.data.Dataset.range(101).batch(10)
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
sess = tf.InteractiveSession()
try:
while True:
print(batch.eval().shape)
except tf.errors.OutOfRangeError:
pass
在这个玩具示例中,数据共有101个样本,我要求批量10个样本。迭代时,最后一批的大小为1,这是我想要避免的。
在前一个(基于队列的)API中,tf.train.batch
具有allow_smaller_final_batch
参数,默认情况下设置为False
。我想用Dataset
重现此行为。
我想我可以使用Dataset.filter
:
dataset = tf.data.Dataset.range(101).batch(10)
.filter(lambda x: tf.equal(tf.shape(x)[0], 10))
但肯定应该有一些内置的方法来做到这一点?
答案 0 :(得分:1)
您可以使用tf.contrib.data.batch_and_drop_remainder(batch_size)
执行此操作:
dataset = tf.data.Dataset.range(101).apply(
tf.contrib.data.batch_and_drop_remainder(10))
答案 1 :(得分:1)
对于tensorflow>=2.0.0
,您可以将drop_remainder
的方法batch
的{{1}}参数用作:
tf.data.Dataset
dataset = tf.data.Dataset.batch(BATCH_SIZE, drop_remainder=True)
参数设置是否删除最后一批(如果其少于drop_remainder
个元素)。默认值为False。
我希望这对2019年以后的读者有帮助