多类数据的混淆矩阵

时间:2021-04-06 11:42:29

标签: keras confusion-matrix

我已经将以下模型拟合到我的 7 个类数据中,并且我想为我的模型创建一个混淆矩阵:

history1 = model1.fit(data_generator.flow(train_x, to_categorical(train_y),batch_size=BATCH_SIZE), 
                    steps_per_epoch=len(train_x) / BATCH_SIZE,
                    validation_data=data_generator.flow(val_x, to_categorical(val_y),batch_size=BATCH_SIZE),
                    validation_steps=len(val_x) / BATCH_SIZE,epochs=NUM_EPOCHS)

当我这样做时,预测训练集的结果:

y_train_pred = model1.predict(train_x)
cm_train = confusion_matrix(train_y, y_train_pred)

它给了我这个错误:

Classification metrics can't handle a mix of unknown and multiclass targets

你能指导我怎么做吗?

1 个答案:

答案 0 :(得分:0)

您似乎正在使用 sklearn.metrics.confusion_matrix 尝试使用 from sklearn.metrics import multilabel_confusion_matrix 相反。