使用Keras在gcloud ml-engine上处理TB级数据的最佳方法

时间:2019-02-04 20:24:29

标签: tensorflow keras google-cloud-ml tensorflow-datasets tfrecord

我想在gcloud存储上训练约2TB图像数据的模型。我将图像数据另存为单独的tfrecords,并尝试使用此示例中的tensorflow数据api

https://medium.com/@moritzkrger/speeding-up-keras-with-tfrecord-datasets-5464f9836c36

但是看来keras的model.fit(...)不支持基于的tfrecord数据集的验证

https://github.com/keras-team/keras/pull/8388

是否有更好的方法来处理我缺少的ml-engine的keras的大量数据?

非常感谢!

1 个答案:

答案 0 :(得分:3)

如果您愿意使用tf.keras而不是实际的Keras,则可以使用TFRecordDataset API实例化tf.data并将其直接传递给model.fit()奖励:您可以直接从Google云端存储中进行流式传输,无需先下载数据

# Construct a TFRecordDataset
ds_train tf.data.TFRecordDataset('gs://') # path to TFRecords on GCS
ds_train = ds_train.shuffle(1000).batch(32)

model.fit(ds_train)

要包括验证数据,请使用验证TFRecords创建一个TFRecordDataset并将其传递给validation_data的{​​{1}}参数。注意:as of TensorFlow 1.9是可能的。

最后的注释:您需要指定model.fit()参数。我用来了解所有TFRecordfiles中示例总数的一种技巧是简单地遍历文件并计数:

steps_per_epoch

您可以用来计算import tensorflow as tf def n_records(record_list): """Get the total number of records in a collection of TFRecords. Since a TFRecord file is intended to act as a stream of data, this needs to be done naively by iterating over the file and counting. See https://stackoverflow.com/questions/40472139 Args: record_list (list): list of GCS paths to TFRecords files """ counter = 0 for f in record_list: counter +=\ sum(1 for _ in tf.python_io.tf_record_iterator(f)) return counter

steps_per_epoch