在Google Cloud Platform中进行Keras ML培训的理想方式是读取存储在存储桶中的数据批次中的数据?

时间:2019-02-17 18:47:46

标签: python tensorflow keras google-cloud-platform

这是我第一次尝试在云中训练模型,而我正在努力处理所有小的内部事务。 我将训练数据存储在Google Cloud平台的存储桶中, 遵循gs://test/train 数据集约为100k。 当前,数据根据其标签分布在单独的文件夹中。

我不知道访问数据的理想方法。 通常在Keras中,我将ImageDataGeneratorflow_from_directory配合使用,它会自动创建一个生成器,可以将其输入到模型中。

是否存在用于Google Cloud Platform的Python功能?

如果不是,通过生成器访问数据的理想方法是什么,那么我可以将其馈送到 Keras model.fit_generator

谢谢。

1 个答案:

答案 0 :(得分:2)

ImageDataGenerator.flow_from_directory()当前不允许您直接从GCS存储桶中流式传输数据。我认为您有两种选择:

1 /将数据从GCS复制到用于运行脚本的VM本地磁盘。我想您是通过ML Engine还是在Compute Engine实例上执行此操作。无论哪种方式,都可以使用gsutilpython cloud storage API在训练脚本的开头复制数据。这里有一个缺点:这会在脚本开始时花费一些时间,尤其是在数据集很大时。

2/2 /使用tf.keras时,可以在tf.data数据集上训练模型。这里的好处是TensorFlow的io实用程序允许您直接从GCS存储桶中读取。如果要将数据转换为TFRecords,则可以实例化Dataset对象,而无需先将数据下载到本地磁盘:

# Construct a TFRecordDataset
ds_train tf.data.TFRecordDataset('gs://') # path to TFRecords on GCS
ds_train = ds_train.shuffle(1000).batch(32)

# Fit a tf.keras model
model.fit(ds_train)

有关TFRecord选项的更多信息,请参见this question。对于直接使用Dataset.from_tensor_slices从GCS上的图像实例化的Dataset对象,这也可以正常工作,这样您就不必先将数据存储为TFRecords格式:

def load_and_preprocess_image(path):
"""Read an image GCS path and process it into an image tensor

Args:
    path (tensor): string tensor, pointer to GCS or local image path

Returns:
    tensor: processed image tensor
"""

    image = tf.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    return image

image_paths = ['gs://my-bucket/img1.png',
               'gs://my-bucket/img2/png'...]
path_ds = tf.data.Dataset.from_tensor_slices(image_paths)
image_ds = path_ds.map(load_and_preprocess_image)
label_ds = tf.data.Dataset.from_tensor_slices(labels) # can be a list of labels    
model.fit(tf.data.Dataset.zip((images_ds, labels_ds)))

有关更多示例,请参见tutorials on the TF website

3 /最后,还应该可以编写自己的python生成器或改编ImageDataGenerator的源代码,以便使用TensorFlow io函数读取图像。同样,这些在gs://路径下也可以正常工作:

import tensorflow as tf
tf.enable_eager_execution()
path = 'gs://path/to/my/image.png'
tf.image.decode_png(tf.io.read_file(path)) # this works

另请参见this related question。这可能比上面列出的选项要慢。