如何正确使用tf.data在Tensorflow v2中训练Keras模型?

时间:2019-05-12 03:58:04

标签: python tensorflow tensorflow-serving tensorflow-datasets tensorflow-estimator

我不清楚如何在tensorflow v2中最佳使用tf.data训练Keras模型:

我采用了以下方法:

dataset = tf.data.TFRecordDataset("filename.tfrecords").shuffle(total_size)
validation_size = int(validation_size_split_ratio * total_size)

validation_dataset = dataset.take(validation_size)
validation_dataset = validation_dataset.map(preprocess_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
validation_dataset = validation_dataset.shuffle(validation_size)
# validation_dataset = validation_dataset.cache()
# validation_dataset = validation_dataset.repeat()
validation_dataset = validation_dataset.batch(batch_size)
validation_dataset = validation_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

train_dataset = dataset.skip(validation_size)
train_dataset = train_dataset.map(preprocess_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(total_size - validation_size)
# train_dataset = train_dataset.repeat()
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

其中数据集被定义为对所有数据进行一次完整遍历,然后停止(即到达末尾)。请注意,所有重复操作均已被注释掉。该模型将进行如下训练:

training_history = model.fit(train_dataset, validation_data=validation_dataset, epochs=epochs)

在这种情况下,每个时期对应于整个数据集对象的新遍历。另外,我们可以取消注释上面的“ repeat”语句,然后像这样训练模型:

training_history = model.fit(train_dataset, steps_per_epoch=total_size // batch_size, epochs=epochs, validation_data=validation_dataset, validation_steps=500)

哪种方法是最佳方法,为什么?我们是否应该像第二种方法一样使用重复操作。在实践中,我发现第一种方法的培训时间更长……但是,什么时候应该使用“重复”操作。当Keras完成训练阶段然后运行验证集时,似乎还会有很长的停顿。我可以从字面上看到GPU利用率从95%下降到0 ...对此有任何建议吗?谢谢!

使用上述第一种方法,在每个时期之后,我还会看到此消息:

Filling up shuffle buffer (this may take a while)

在每个时期之后都需要重新填充随机缓冲吗?

我也看到了,我不明白:

Encountered a stop event that was not preceded by a start event.

特别是,此消息后,我看到GPU使用率降至0

0 个答案:

没有答案