Keras图像数据生成器:fit_generator()参数错误

时间:2017-10-10 11:53:21

标签: python image keras

对于来自Keras的cifar10数据集的项目,我在图像增强方面遇到了问题。简而言之,我遵循了Keras文档中的描述:https://keras.io/preprocessing/image/。但是,我无法运行我的代码,因为fit_generator函数发生了参数错误(见下文):

TypeErrorTraceback (most recent call last)
/home/nct01/nct01002/new_MAI-DL/lab_1/code_lab1.3.py in <module>()
     86 #trying the dataaugmentation fit function
     87 history = nn.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
---> 88                     steps_per_epoch=len(x_train) / 32, epochs=30, validation_sp
lit=0.15)
     89 
     90 #Evaluate the model with test set

以下是我用于准备数据的代码:

#Load the CIFAR dataset, already provided by Keras
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

#Check sizes of dataset
print 'Number of train examples', x_train.shape[0]
print 'Size of train examples', x_train.shape[1:]

#Normalize data
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train = x_train / 255
x_test = x_test / 255


#import library for data augmentation
from keras.preprocessing.image import ImageDataGenerator

#prepare image augmentation
datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)

# apply datagen to training data
datagen.fit(x_train)

以下是初始化培训的代码:

#trying the dataaugmentation fit function
history = nn.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
                    steps_per_epoch=len(x_train) / 32, epochs=30, validation_split=0.15)

P.s。:对于该项目,我需要使用Keras版本1.1.1

0 个答案:

没有答案