TensorFlow Python-是否可以将tensorflow_datasets数据集插入ImageGenerator?

时间:2019-10-26 17:44:18

标签: python python-3.x tensorflow tensorflow-datasets tensorflow2.0

我已经尝试了很长时间在tensorflow.keras.preprocessing.image.ImageGenerator函数中使用数据增强,但是我看到的每个示例都在包含文件的目录中传递。我的目标是使用tensorflow_datasets导入MNIST,然后将其传递给数据增强功能,但我一直无法找到方法。

如果目录更容易使用,并且任何人都可以找到简单的方法,并愿意向我成功说明,我愿意使用它。

请参见下面的代码

谢谢你,麦克斯

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def main():
    data, info = tfds.load("mnist", with_info=True)
    train_data, test_data = data['train'], data['test']

    image_gen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        fill_mode='nearest')

    #
    # What Do I Do Here??
    #

    # train_data_gen = image_gen.flow(data)


if __name__ == "__main__":
    main()


1 个答案:

答案 0 :(得分:0)

今天上午,我花了很大一部分时间试图弄清,如果使用这两种方法中的任何一种,则是处理TF2中数据加载 AND 增强的“正确”方法。 "TensorFlow 2.0 removes redundant APIs"这么多吧?我目前的理解是,这两种数据加载方法是独立的,并且打算单独使用 (尽管如果TF中的某人能够进入这里,那将是很好的选择)。

首先,ImageDataGenerator使用实时数据增强生成一批张量图像数据。我已经看到一些使用tfds的解决方案来将数据集读入ImageDataGenerator的numpy数组中,但是如果这不是反模式,我会感到震惊。我的理解是,如果您使用ImageDataGenerator,则应该同时使用它来加载和预处理数据。

我已选择采用官方(也许?!)tensorflow_datasets路线。我没有利用ImageDataGenerator中的内置增强功能,而是使用tfds.load加载数据集,然后结合使用缓存和map calls进行预处理。

首先,我们使用S3 API加载数据:

train_size = 45000
val_size = 5000
(train, val, test), info = tfds.load(
    "cifar10:3.*.*",
    as_supervised=True,
    split=[
        "train[0:{}]".format(train_size),
        "train[{}:{}]".format(train_size, train_size + val_size),
        "test",
    ],
    with_info=True
)

然后,使用一系列帮助器和tf.image函数,我们可以进行预处理:

def _pad_image(
    image: tf.Tensor, label: tf.Tensor, padding: int = 4
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Pads and image and returns a given supervised training pair."""
    image = tf.pad(
        image, [[padding, padding], [padding, padding], [0, 0]], mode="CONSTANT"
    )
    return image, label


def _crop_image(
    image: tf.Tensor, label: tf.Tensor, out_size: List[int] = None
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Randomly crops an image and returns a given supervised training pair."""
    if out_size is None:
        out_size = [32, 32, 3]
    image = tf.image.random_crop(image, out_size)
    return image, label


def _random_flip(
    image: tf.Tensor, label: tf.Tensor, prob=0.5
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Randomly flips an image and returns a given supervised training pair."""
    if tf.random.uniform(()) > 1 - prob:
        image = tf.image.flip_left_right(image)
    return image, label


processed_train = (
    train.map(_pad_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .map(_crop_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(_random_flip, num_parallel_calls=tf.data.experimental.AUTOTUNE)
)

由于我们有一个tf.data.Dataset,因此我们可以使用标准API进行批处理,重复等。

相关问题