使用tf.distribute时,如何避免在每个tf.keras纪元上重新填充tf.data随机缓冲区?

时间:2019-01-16 12:31:01

标签: python tensorflow keras

我正在使用{在tf-nightly-gpu==1.13.0.dev20190116上的Keras模型中:

with tf.distribute.MirroredStrategy().scope():
    model = tf.keras.Model(...)

和具有以下内容的数据集:

dataset = (tf.data.Dataset
    .list_files(...)
    .map(load_example)
    .cache()
    .shuffle(100)
    .repeat())

然后进行培训

model.fit(dataset, epochs=10, steps_per_epoch=1000)

效果很好,因为它会自动在单机多GPU设置上分割我的微型批处理。太酷了!

但是,我的随机播放缓冲区在每个时期都重新填充。有没有一种方法可以使改组缓冲区保持在某个时期?我尝试直接使用迭代器和张量调用model.fit,但是tf.distribute不支持(还好吗?),而是引发了异常。

TL; DR:如何确保在各个时期维护我的tf.data随机播放缓冲区?

1 个答案:

答案 0 :(得分:0)

shuffle有一个参数reshuffle_each_iteration,可以将其设置为False,以便仅在第1个时期发生混洗,并且在以后的时期中保持状态。