我使用tensorflow DataSet作为输入数据管道。我想知道如何在第一个时期没有数据改组的情况下进行训练,并开始从第二个时期开始洗牌数据。
图形通常在迭代训练开始之前构建,在训练期间,如何更改DataSet混洗行为似乎不是直截了当的,因为它看起来像是在改变图形。
任何想法?
感谢, 哈利
答案 0 :(得分:0)
Dataset.shuffle()
的buffer_size
参数可以是计算tf.Tensor
,因此您可以使用以下代码使用Dataset.range(NUM_EPOCHS).flat_map(...)
将一系列纪元数转换为({1}}混乱或其他方式)per_epoch_dataset
:
NUM_EPOCHS = ... # The total number of epochs.
BUFFER_SIZE = ... # The shuffle buffer size to use from the second epoch on.
per_epoch_dataset = ... # A `Dataset` representing the elements of a single epoch.
def shuffle_after_first_epoch(epoch):
# Set `epoch_buffer_size` to 1 (i.e. no shuffling) in the 0th epoch,
# and `BUFFER_SIZE` thereafter.
epoch_buffer_size = tf.cond(tf.equal(epoch, 0),
lambda: tf.constant(1, tf.int64),
lambda: tf.constant(BUFFER_SIZE, tf.int64))
return per_epoch_dataset.shuffle(epoch_buffer_size)
dataset = tf.data.Dataset.range(NUM_EPOCHS).flat_map(shuffle_after_first_epoch)