tfds.load()之后如何在TensorFlow 2.0中应用数据增强

时间:2019-03-13 11:41:27

标签: python tensorflow tensorflow-datasets data-augmentation tensorflow2.0

我正在关注this guide

它显示了如何使用tfds.load()方法从新的TensorFlow数据集中下载数据集:

import tensorflow_datasets as tfds    
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)

(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs', split=list(splits),
    with_info=True, as_supervised=True)

接下来的步骤显示了如何使用map方法将函数应用于数据集中的每个项目:

def format_example(image, label):
    image = tf.cast(image, tf.float32)
    image = image / 255.0
    # Resize the image if required
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label

train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

然后访问我们可以使用的元素:

for features in ds_train.take(1):
  image, label = features["image"], features["label"]

OR

for example in tfds.as_numpy(train_ds):
  numpy_images, numpy_labels = example["image"], example["label"]

但是,该指南未提及任何有关数据增强的内容。我想使用类似于Keras的ImageDataGenerator类的实时数据增强。我尝试使用:

if np.random.rand() > 0.5:
    image = tf.image.flip_left_right(image)

format_example()中的其他类似增强功能,但是,如何验证它正在执行实时增强并且没有替换数据集中的原始图像?

我可以通过将batch_size=-1传递到tfds.load()然后使用tfds.as_numpy()来将完整的数据集转换为Numpy数组,但是这样会把所有不需要的图像加载到内存中。我应该能够使用train = train.prefetch(tf.data.experimental.AUTOTUNE)加载足够的数据以用于下一个训练循环。

1 个答案:

答案 0 :(得分:1)

您正在从错误的方向解决问题。

首先,例如使用tfds.loadcifar10下载数据(为简单起见,我们将使用默认的TRAINTEST拆分):

import tensorflow_datasets as tfds

dataloader = tfds.load("cifar10", as_supervised=True)
train, test = dataloader["train"], dataloader["test"]

(您可以使用自定义tfds.Split对象来创建验证数据集或其他see documentation

traintesttf.data.Dataset对象,因此您可以使用mapapplybatch和类似的功能。

下面是一个示例,我将(主要使用tf.image):

  • 将每个图片转换为tf.float64范围内的0-1(不要使用官方文档中的愚蠢代码段,这样可以确保正确的图片格式)
  • cache()的结果,因为可以在每个repeat
  • 之后重复使用这些结果
  • 随机翻转left_to_right每个图像
  • 随机更改图像的对比度
  • 随机播放数据和批处理
  • 重要提示:在数据集用完后,重复所有步骤。这意味着一个时期之后,以上所有转换都将再次应用(缓存的除外)。

以下是执行上述操作的代码(您可以将lambda更改为函子或函数):

train = train.map(
    lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label)
).cache().map(
    lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
    lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
    100
).batch(
    64
).repeat()

这种tf.data.Dataset可以直接传递给Keras的fitevaluatepredict方法。

验证它实际上就是这样

我看到您对我的解释高度怀疑,让我们来看一个例子:

1。获取一小部分数据

这是采用单个元素的一种方法,虽然这是不可读和不直观的,但是如果您使用Tensorflow做任何事情,都应该可以:

# Horrible API is horrible
element = tfds.load(
    # Take one percent of test and take 1 element from it
    "cifar10",
    as_supervised=True,
    split=tfds.Split.TEST.subsplit(tfds.percent[:1]),
).take(1)

2。重复数据并检查是否相同:

使用Tensorflow 2.0实际上可以做到,而无需愚蠢的解决方法(几乎):

element = element.repeat(2)
# You can iterate through tf.data.Dataset now, finally...
images = [image[0] for image in element]
print(f"Are the same: {tf.reduce_all(tf.equal(images[0], images[1]))}")

它毫不奇怪地返回:

Are the same: True

3。每次重复随机增强后,检查数据是否不同

下面的代码段repeat中的单个元素5次,并检查哪些相等和哪些不同。

element = (
    tfds.load(
        # Take one percent of test and take 1 element
        "cifar10",
        as_supervised=True,
        split=tfds.Split.TEST.subsplit(tfds.percent[:1]),
    )
    .take(1)
    .map(lambda image, label: (tf.image.random_flip_left_right(image), label))
    .repeat(5)
)

images = [image[0] for image in element]

for i in range(len(images)):
    for j in range(i, len(images)):
        print(
            f"{i} same as {j}: {tf.reduce_all(tf.equal(images[i], images[j]))}"
        )

输出(在我的情况下,每次运行都不同):

0 same as 0: True
0 same as 1: False
0 same as 2: True
0 same as 3: False
0 same as 4: False
1 same as 1: True
1 same as 2: False
1 same as 3: True
1 same as 4: True
2 same as 2: True
2 same as 3: False
2 same as 4: False
3 same as 3: True
3 same as 4: True
4 same as 4: True

您也可以将这些图像都投射到numpy上,并使用skimage.io.imshowmatplotlib.pyplot.imshow或其他替代方法自己查看图像。

另一个可视化的实时数据增强示例

This answer使用TensorboardMNIST提供了关于数据增强的更全面,更易读的视图,可能想检查一下(是的,无耻的插件,但我想很有用)。