从keras生成器获取真实标签

时间:2018-12-27 19:35:10

标签: python scikit-learn keras

我想使用sklearn.metrics.confusion_matrix(y_true, y_pred)为keras模型创建混淆矩阵。

训练模型后,我可以使用predict_generator(generator)获得测试数据集的预测,这给了我y_pred。如何从数据生成器获取相应的真实标签y_true

2 个答案:

答案 0 :(得分:1)

generator.classes将为您提供稀疏格式的观测值。您可能需要密集的格式(即单热编码格式)。您可以通过以下方式得到它:

import pandas as pd
pd.get_dummies(pd.Series(generator.classes)).to_dense()

不过请注意:在生成预测并获取观察到的类之前,必须将生成器的shuffle属性设置为False,否则您的预测和观察将不会对齐!

答案 1 :(得分:0)

在创建数据生成器(您自己的或内置的ImageDataGenerator)之后,使用受过训练的模型进行预测:

true_labels = data_generator.classes
predictions = model.predict_generator(data_generator)

sklearn的混淆矩阵需要一维标签数组,因此您必须使用np.argmax()

转换预测
y_true = true_labels
y_pred = np.array([np.argmax(x) for x in predictions])

然后,您可以直接在confusion_matrix函数中使用这些变量

cm = sklearn.metrics.confusion_matrix(y_true, y_pred)

您可以使用此处的示例plot_confusion_matrix()函数对其进行绘制:

https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

enter image description here