ImageDataGenerator-使用model.fit而不是model.fit_generator进行训练

时间:2020-03-12 03:35:48

标签: python-3.x keras conv-neural-network data-augmentation

我是使用Keras的ImageDataGenerator的初学者,我不小心使用了model.fit而不是model.fit_generator。

train_gen = gen_Image_data()
test_gen = ImageDataGenerator()
train_samples = train_gen.flow(X,y, batch_size=64)
test_samples = test_gen.flow(X_val, y_val, batch_size=64)
history = model.fit(train_samples, steps_per_epoch = np.ceil(len(X)/64),
                  validation_data=(test_samples),
                  validation_steps=np.ceil(len(X_val)/64),
                  epochs=300, verbose=1, callbacks=[es])

那是一个明显的错误,我是否必须使用fit_generator重新训练一切?

感谢您的帮助

更新,我忘记了gen_Image_data()

的代码
def gen_Image_data():
   gen = ImageDataGenerator(
         width_shift_range=0.1,
         horizontal_flip=True)
   return gen

1 个答案:

答案 0 :(得分:1)

您无需费心重新训练模型,因为model.fit方法也支持生成器,并且model.fit_generator也包含在model.fit方法中!