使用ImageDataGenerator训练的Keras模型进行预测的正确方法

时间:2017-09-23 09:10:18

标签: machine-learning keras

我已经训练了一个模型,通过在Keras中使用ImageDataGenerator来应用一些图像增强,如下所示:

train_datagen = ImageDataGenerator(
        rotation_range=60, 
        width_shift_range=0.1,
        height_shift_range=0.1, 
        horizontal_flip=True)
train_datagen.fit(x_train)

history = model.fit_generator(
    train_datagen.flow(x_train, y_train, batch_size=7),
    steps_per_epoch=600,
    epochs=epochs,
    callbacks=callbacks_list
)

我应该如何使用此模型进行预测?使用model.predict()如下所示?

predictions = model.predict(x_test)

或者我应该使用model.predict_generator()x_test x_test未标记的情况下应用ImageDataGenerator?

如果我使用predict_generator():怎么做?

两种方法有什么区别?

1 个答案:

答案 0 :(得分:4)

predict_generator()是一种便利功能,可以更轻松地加载图像并应用与训练样本相同的预处理。我建议使用它而不是model.predict

在你的情况下,只需:

test_gen = ImageDataGenerator()
predictions = model.predict_generator(test_gen.flow(# ... your params here ... #))