tf.data.Dataset from tf.keras.preprocessing.image.ImageDataGenerator.flow_from_directory?

Time: 2019-02-09 12:40:03

Tags: python numpy tensorflow keras marshalling

How do I create a tf.data.Dataset from tf.keras.preprocessing.image.ImageDataGenerator.flow_from_directory?

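I'm considering tf.data.Dataset.from_generator, but given the documented return type, it's unclear how to work out its output_types keyword argument:

A DirectoryIterator yielding tuples of (x, y), where x is a numpy array containing a batch of images with shape (batch_size, *target_size, channels) and y is a numpy array of corresponding labels.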

2 Answers:

Answer 0: (score: 2)

Here is my solution. To show how it works, I used the cats/dogs dataset:

import os
import tensorflow as tf


_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
#'/Users/mustafamuratarat/.keras/datasets/cats_and_dogs_filtered/train'

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

gen = img_gen.flow_from_directory(train_dir, target_size=IMG_SIZE, batch_size=BATCH_SIZE)
# gen is a DirectoryIterator:
# <tensorflow.python.keras.preprocessing.image.DirectoryIterator at 0x7fb9fde3b250>

# gen.class_indices -> {'cats': 0, 'dogs': 1}
# gen.target_size   -> (160, 160)
# gen.batch_size    -> 32
# gen.num_classes   -> 2

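# Wrap the Keras iterator; each dataset element is one (images, labels) batch.
# class_mode defaults to 'categorical', so the labels are one-hot with 2 columns.
dataset = tf.data.Dataset.from_generator(
    lambda: gen,
    output_types=(tf.float32, tf.float32),
    output_shapes=([None, 160, 160, 3], [None, 2]),
)

# To inspect one batch:
# list(dataset.take(1).as_numpy_iterator())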

You can then pass the dataset object to any model.
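One caveat when feeding it to a model: as far as I know, the Keras iterator keeps yielding batches indefinitely, so a dataset wrapped around it has no natural end and each epoch needs an explicit step count. Below is a minimal sketch of that bookkeeping in plain Python. The names steps_per_epoch and endless_batches are hypothetical, and endless_batches is only a stand-in for the real iterator (it yields batch sizes rather than images).

```python
import itertools
import math


def steps_per_epoch(num_samples, batch_size):
    """Number of batches needed to see every sample once (last batch may be short)."""
    return math.ceil(num_samples / batch_size)


def endless_batches(num_samples, batch_size):
    """Hypothetical stand-in for a Keras iterator: cycles over batch sizes forever."""
    while True:
        for start in range(0, num_samples, batch_size):
            yield min(batch_size, num_samples - start)


# Cut exactly one epoch out of the endless stream.
steps = steps_per_epoch(num_samples=100, batch_size=32)
epoch = list(itertools.islice(endless_batches(100, 32), steps))
print(steps, epoch)
```

With the real objects this would presumably become model.fit(dataset, steps_per_epoch=steps_per_epoch(gen.samples, gen.batch_size)), where gen.samples is the image count found by flow_from_directory.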

Answer 1: (score: 1)

Both batch_x and batch_y in ImageDataGenerator are of type K.floatx(), so by default they must be tf.float32.
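If you would rather check this than take it on trust, you can pull one batch and read the dtypes and shapes off it. Here is a minimal sketch using only numpy. The helper describe_batch is a hypothetical name, and the zero arrays stand in for a real next(gen) call so that the snippet runs without any images on disk.

```python
import numpy as np


def describe_batch(batch):
    """Return (dtypes, shapes) for an (x, y) batch, leaving the batch axis as None."""
    dtypes = tuple(np.asarray(part).dtype.name for part in batch)
    shapes = tuple([None, *np.asarray(part).shape[1:]] for part in batch)
    return dtypes, shapes


# Stand-in for `next(gen)`; a real DirectoryIterator would read images from disk.
x = np.zeros((32, 160, 160, 3), dtype="float32")
y = np.zeros((32, 2), dtype="float32")

dtypes, shapes = describe_batch((x, y))
print(dtypes)
print(shapes)
```

The dtype names can be turned into TensorFlow dtypes with tf.as_dtype before passing them as output_types, and the shapes can be used as output_shapes.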

A similar question has already been discussed in "How to use Keras generator with tf.data API". Let me copy and paste the answer from there:

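import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def make_generator():
    # train_dataset_folder: path to the training image directory (defined elsewhere)
    train_datagen = ImageDataGenerator(rescale=1. / 255)
    train_generator = train_datagen.flow_from_directory(
        train_dataset_folder, target_size=(224, 224),
        class_mode='categorical', batch_size=32)
    return train_generator

train_dataset = tf.data.Dataset.from_generator(make_generator, (tf.float32, tf.float32))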

The author there ran into a separate problem with graph scope, but I don't think it is related to your question.

Or as a one-liner:

tf.data.Dataset.from_generator(lambda:
    ImageDataGenerator().flow_from_directory('folder_path'),(tf.float32, tf.float32))
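On newer TensorFlow releases (2.4 and later, as far as I know), from_generator also accepts an output_signature argument that states dtype and shape together, and output_types is marked as deprecated there. A minimal sketch of that variant follows. The fake_batches generator is a hypothetical stand-in for flow_from_directory so the snippet runs without a dataset on disk, and the 160x160 image size and 2 classes are assumptions carried over from the cats/dogs example.

```python
import numpy as np
import tensorflow as tf


def fake_batches():
    """Hypothetical stand-in for flow_from_directory: yields (images, one-hot labels)."""
    for _ in range(3):
        x = np.random.rand(4, 160, 160, 3).astype("float32")
        y = np.eye(2, dtype="float32")[np.random.randint(0, 2, size=4)]
        yield x, y


# Each TensorSpec describes one component of the yielded tuple;
# None leaves the batch dimension unspecified.
dataset = tf.data.Dataset.from_generator(
    fake_batches,
    output_signature=(
        tf.TensorSpec(shape=(None, 160, 160, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
    ),
)

for images, labels in dataset.take(1):
    print(images.shape, labels.shape)
```

Swapping fake_batches for a function that returns the flow_from_directory iterator should give the same element structure, provided target_size and the number of classes match the declared shapes.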