从Keras ImageDataGenerator获取火车测试数据

时间:2020-11-11 11:29:01

标签: python tensorflow keras

我正在按照以下方式使用Keras的ImageDataGenerator

datagen = ImageDataGenerator(samplewise_center=True,
    samplewise_std_normalization=True, validation_split=0.30
                               )

然后使用.flow语句按以下方式获得训练和测试拆分。

train_iterator = datagen.flow(x, y, subset='training')
test_iterator = datagen.flow(x, y, subset='validation')

此处x代表形状为(588, 120, 120, 1)的图像,y代表multiclass输出(588, 4)

(588, 120, 120, 1)形状输入数据中,共有588个样本,每个样本的形状为(120, 120, 1)。输出具有4个类。

然后,我使用以下代码训练和测试我的CNN。

history =model.fit_generator(train_iterator,
                             
                              epochs=10,
                              validation_data=test_iterator,
                              callbacks=callbacks_list) 


pred_test = model.predict(test_iterator, steps=len(test_iterator), verbose=0)

我的问题是: 如何访问test_iterator用于预测的测试数据(x和y)。

2 个答案:

答案 0 :(得分:1)

flow()返回一个生成(x,y)元组的迭代器,您可以使用test_iterator.next()访问元素。

答案 1 :(得分:1)

仅供参考。在ImageDataGenerator中,您设置了samplewise_center = True, samplewise_std_normalization =真实。如果这是您要的内容,则必须首先使生成器适合累积输入数据的统计信息,所以首先这样做

datagen.fit(x)

文档为here.