检查模型输入时出错:找到:Keras中的<keras.preprocessing.image.directoryiterator ... object =“”>

时间:2016-12-22 23:44:54

标签: python keras

我在迭代器中有批量数据,在这里,使用原生的Keras并没有什么花哨的东西:

batches = gen.flow_from_directory(path, target_size=(224,224), class_mode=class_mode, shuffle=shuffle, batch_size=batch_size)

看起来很好:

print batches:

  

keras.preprocessing.image.DirectoryIterator对象位于0x7f107c004210 \

但是现在我已经编译好了,我已准备好fit

model.compile(optimizer=Adam(1e-5), loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(batches, val_batches, nb_epoch=1)

但我一直在接受:

Exception: Error when checking model input: data should be a Numpy array, or list/dict of Numpy arrays. Found: <keras.preprocessing.image.DirectoryIterator object at 0x7f107c004210>...

No Keras不喜欢我正在使用迭代器?为什么我不能使用迭代器?我认为这就是重点 - 不要占用你所有的记忆,而是使用某种批处理迭代器。

1 个答案:

答案 0 :(得分:1)

fit方法期望其输入为Numpy数组或Numpy数组列表。您应该使用fit_generator代替,它将生成器作为参数。

model.fit_generator(generator=batches, 
                    validation_data=val_batches, 
                    nb_epoch=1)