How does TensorFlow's DirectoryIterator work?

Date: 2019-10-21 15:35:34

Tags: tensorflow keras tensorflow-datasets

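I'm used to calling something like model.fit(train_data, train_labels, epochs=10), where I first use glob to read folders full of images into RAM. Instead, I would like the images to be read directly from the HDD while training runs. I found: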

https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/DirectoryIterator

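The problem is that I don't understand how it works. I searched the internet for more help beyond the linked documentation, but found nothing useful. My DirectoryIterator has the labels and the directory; I just don't know how to feed the DirectoryIterator into the model.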

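The code below shows what I have done so far. I also tried using a TensorFlow session and passing the DirectoryIterator in through feed_dict. The code is messy because I was just experimenting. In it, I try to train on the DirectoryIterator with fit_generator.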

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras
import cv2 as ocv
import glob
import matplotlib.pyplot as plt
from tensorflow import image

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Plot inline
%matplotlib inline

# Load a colour image; imread flag 1 = colour (BGR), 0 = grayscale, -1 = unchanged (keeps alpha)
img = ocv.imread('C:/Users/ew/Documents/Python Scripts/Noodles/my.png',1)
RGB_im = ocv.cvtColor(img, ocv.COLOR_BGR2RGB)
img.shape
plt.imshow(RGB_im)

cv_img = []
for img in glob.glob("C:\\Users\\EW\\pictures\\Noodles\\Banana\\*.jpg"):
    cv_img.append(img)
    #n= ocv.imread(img)
    #cv_img.append(n)
print(cv_img[1])

image_data_generator = keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
                                       samplewise_center=False,
                                       featurewise_std_normalization=False,
                                       samplewise_std_normalization=False,
                                       zca_whitening=False, zca_epsilon=1e-06,
                                       rotation_range=0,
                                       width_shift_range=0.0,
                                       height_shift_range=0.0,
                                       brightness_range=None,
                                       shear_range=0.0,
                                       zoom_range=0.0,
                                       channel_shift_range=0.0,
                                       fill_mode='nearest',
                                       cval=0.0,
                                       horizontal_flip=False,
                                       vertical_flip=False,
                                       rescale=None,
                                       preprocessing_function=None,
                                       data_format='channels_last',
                                       validation_split=0.3,
#                                       interpolation_order=1,
                                       dtype='float32')

noodle_data = directory = "C:\\Users\\EW\\pictures\\Noodles\\"
image_set = keras.preprocessing.image.DirectoryIterator(directory,
    image_data_generator,
    target_size=(256, 256),
    color_mode='rgb',
    classes=None,
    class_mode='categorical',
    batch_size=32,
    shuffle=True,
    seed=None,
    data_format=None,
    save_to_dir=None,
    save_prefix='',
    save_format='png',
    follow_links=False,
    subset=None,
    interpolation='nearest',
    dtype=None)

model = Sequential()

#add model layers
model.add(Dense(10, activation='relu', input_shape=(256,256)))
model.add(Dense(10, activation='relu'))
model.add(Dense(1))

model.fit_generator(noodle_data , steps_per_epoch=16, validation_data=val_it, validation_steps=8)
Running it fails with (traceback excerpt):
---> 12 model.fit_generator(prawn_data , steps_per_epoch=16, validation_data=prawn_data, validation_steps=8)
AttributeError: 'str' object has no attribute 'shape'
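Background on how a DirectoryIterator sees a folder: when classes=None it treats every immediate subfolder of directory (here Banana and its siblings under Noodles) as one class, sorts the class names alphanumerically to assign label indices (exposed as class_indices), and ignores files that are not images. If the ImageDataGenerator has a validation_split, subset='training' or subset='validation' selects part of each class's files. The following is a simplified, standard-library-only sketch of that indexing and batching step, as I understand it; it is not Keras's actual code, and index_directory and epoch_batches are hypothetical helper names.

```python
import os
import random

# File extensions that Keras's DirectoryIterator accepts by default
# (compared case-insensitively).
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".ppm", ".tif", ".tiff"}


def index_directory(directory, validation_split=0.0, subset=None):
    """Return (filepaths, labels, class_indices) for a one-folder-per-class tree."""
    # Every immediate subfolder is a class; sorted order defines the label index.
    class_names = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    class_indices = {name: i for i, name in enumerate(class_names)}

    filepaths, labels = [], []
    for name in class_names:
        # Collect image files anywhere below the class folder; skip other files.
        files = sorted(
            os.path.join(root, fname)
            for root, _, fnames in os.walk(os.path.join(directory, name))
            for fname in fnames
            if os.path.splitext(fname)[1].lower() in IMAGE_EXTENSIONS
        )
        # validation_split is applied per class: the first fraction of the
        # sorted files is "validation", the rest is "training".
        cut = int(validation_split * len(files))
        if subset == "validation":
            files = files[:cut]
        elif subset == "training":
            files = files[cut:]
        filepaths.extend(files)
        labels.extend([class_indices[name]] * len(files))
    return filepaths, labels, class_indices


def epoch_batches(filepaths, labels, num_classes, batch_size=32, shuffle=True, seed=None):
    """Yield one epoch of (batch_paths, one_hot_labels), like class_mode='categorical'."""
    order = list(range(len(filepaths)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        chunk = order[start:start + batch_size]
        # The real iterator would load and resize each image here; this sketch
        # keeps the paths so it runs without image files or TensorFlow.
        batch_paths = [filepaths[i] for i in chunk]
        batch_labels = [
            [1.0 if j == labels[i] else 0.0 for j in range(num_classes)]
            for i in chunk
        ]
        yield batch_paths, batch_labels
```

The real DirectoryIterator additionally reads each file, resizes it to target_size, applies the ImageDataGenerator's transformations, returns NumPy arrays instead of lists, and reshuffles between epochs. That is why it only needs file paths in memory, not the images themselves.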

1 Answer:

Answer 0 (score: 0)

I think you should write it like this:

image_data_generator = keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
                                   ...,
                                   dtype='float32')

directory = "C:\\Users\\EW\\pictures\\Noodles\\"

image_set = image_data_generator.flow_from_directory(directory,
                                                     target_size=(256, 256),
                                                     color_mode='rgb',
                                                     ...,
                                                     dtype=None)

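That is, you create an ImageDataGenerator instance named image_data_generator and call its flow_from_directory() method to read from the directory; flow_from_directory() builds and returns the DirectoryIterator for you. You should not pass image_data_generator as an argument to flow_from_directory. The returned iterator (image_set), not the directory string, is what you pass to fit_generator; passing the string noodle_data is most likely what raises AttributeError: 'str' object has no attribute 'shape'.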
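A fuller sketch of this approach, assuming TensorFlow 2.1 or later (where model.fit accepts the iterator directly) and a Noodles folder with one subfolder per class. It uses subset='training' and subset='validation' to turn a validation_split=0.3 on the generator into the two iterators that the question's val_it was meant to be. The helper names make_iterators, build_model and train, the rescale factor and the layer sizes are illustrative choices, not part of the original question or answer.

```python
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMAGE_SIZE = (256, 256)


def make_iterators(directory, batch_size=32, seed=42):
    # One generator holds the preprocessing settings and the validation split.
    datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.3)
    # flow_from_directory creates the DirectoryIterator; no need to build it by hand.
    train_it = datagen.flow_from_directory(
        directory, target_size=IMAGE_SIZE, color_mode="rgb",
        class_mode="categorical", batch_size=batch_size,
        shuffle=True, seed=seed, subset="training")
    val_it = datagen.flow_from_directory(
        directory, target_size=IMAGE_SIZE, color_mode="rgb",
        class_mode="categorical", batch_size=batch_size,
        shuffle=False, subset="validation")
    return train_it, val_it


def build_model(num_classes):
    # Batches arrive as (batch, 256, 256, 3) arrays, so the input is 3-D,
    # not the (256, 256) used with Dense layers in the question.
    model = models.Sequential([
        layers.Input(shape=IMAGE_SIZE + (3,)),
        layers.Conv2D(16, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(64, activation="relu"),
        layers.Dropout(0.2),
        # One softmax output per class to match class_mode="categorical".
        layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
    return model


def train(directory, epochs=10):
    train_it, val_it = make_iterators(directory)
    model = build_model(num_classes=train_it.num_classes)
    # Pass the iterators themselves, not the directory string.
    model.fit(train_it, epochs=epochs, validation_data=val_it)
    return model
```

For example, train("C:\\Users\\EW\\pictures\\Noodles\\") would train on the folder from the question. On older versions the equivalent call is model.fit_generator(train_it, steps_per_epoch=len(train_it), validation_data=val_it, validation_steps=len(val_it), epochs=epochs). In recent TensorFlow releases ImageDataGenerator is deprecated, and tf.keras.utils.image_dataset_from_directory is the suggested replacement for the same folder layout.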