我正在使用{在tf-nightly-gpu==1.13.0.dev20190116
上的Keras模型中:
with tf.distribute.MirroredStrategy().scope():
model = tf.keras.Model(...)
和具有以下内容的数据集:
dataset = (tf.data.Dataset
.list_files(...)
.map(load_example)
.cache()
.shuffle(100)
.repeat())
然后进行培训
model.fit(dataset, epochs=10, steps_per_epoch=1000)
效果很好,因为它会自动在单机多GPU设置上分割我的微型批处理。太酷了!
但是,我的随机播放缓冲区在每个时期都重新填充。有没有一种方法可以使改组缓冲区保持在某个时期?我尝试直接使用迭代器和张量调用model.fit
,但是tf.distribute不支持(还好吗?),而是引发了异常。
TL; DR:如何确保在各个时期维护我的tf.data随机播放缓冲区?