TensorFlow DataSet shuffle - 仅从第二个时期开始的数据混洗

时间:2018-04-18 04:19:23

标签: tensorflow tensorflow-datasets

我使用tensorflow DataSet作为输入数据管道。我想知道如何在第一个时期没有数据改组的情况下进行训练,并开始从第二个时期开始洗牌数据。

图形通常在迭代训练开始之前构建,在训练期间,如何更改DataSet混洗行为似乎不是直截了当的,因为它看起来像是在改变图形。

任何想法?

感谢, 哈利

1 个答案:

答案 0 :(得分:0)

Dataset.shuffle()buffer_size参数可以是计算tf.Tensor,因此您可以使用以下代码使用Dataset.range(NUM_EPOCHS).flat_map(...)将一系列纪元数转换为({1}}混乱或其他方式)per_epoch_dataset

的元素
NUM_EPOCHS = ...  # The total number of epochs.
BUFFER_SIZE = ...  # The shuffle buffer size to use from the second epoch on.

per_epoch_dataset = ...  # A `Dataset` representing the elements of a single epoch.

def shuffle_after_first_epoch(epoch):
  # Set `epoch_buffer_size` to 1 (i.e. no shuffling) in the 0th epoch,
  # and `BUFFER_SIZE` thereafter.
  epoch_buffer_size = tf.cond(tf.equal(epoch, 0),
                              lambda: tf.constant(1, tf.int64),
                              lambda: tf.constant(BUFFER_SIZE, tf.int64))
  return per_epoch_dataset.shuffle(epoch_buffer_size)

dataset = tf.data.Dataset.range(NUM_EPOCHS).flat_map(shuffle_after_first_epoch)