更改混乱矩阵图框中的文本

时间:2019-12-17 20:23:36

标签: python matplotlib scikit-learn

我正在尝试使用scikitlearn 0.22中提供的plot_confusion_matrix函数。但是,我遇到一个问题,因为我的其中一个框看不到框中的文本值,因为所有文本颜色似乎都设置为与该框相同的值。无论我选择哪个cmap,都是如此。顺便说一句,该值也是最低的。 example they provide中不会发生这种情况。如何更改方框文本的颜色,以便可以清楚地看到所有值?除非必须,否则我不想像其他许多解决方案中所建议的那样使用seaborn

下面是一个可复制的示例。

import numpy as np
import matplotlib.pyplot as plt

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import plot_confusion_matrix

np.random.seed(3851)

# import some data to play with
bc = datasets.load_breast_cancer()
X = bc.data
y = bc.target
class_names = bc.target_names

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
np.random.shuffle(y_test)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.0001).fit(X_train, y_train)

np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
titles_options = [("Confusion matrix, without normalization", None),
                  ("Normalized confusion matrix", 'true')]
for title, normalize in titles_options:
    disp = plot_confusion_matrix(classifier, X_test, y_test,
                                 cmap=plt.cm.Blues,
                                 normalize=normalize)
    disp.ax_.set_title(title)

    print(title)
    print(disp.confusion_matrix)

plt.show()

Confusion matrix

1 个答案:

答案 0 :(得分:1)

我相信您遇到了一个错误。我已提交the issue on github,并提供了可能的解决方法。

基本上,根据颜色值是高于还是低于阈值(该阈值应位于cmap范围的中间)来选择文本的颜色。但是我相信阈值的计算方式存在问题,因此归一化示例中的所有值最终都将低于阈值并以较浅的颜色绘制。

如果要临时修复,可以在第96行的scikit-learn .../site-packages/sklearn/metrics/_plot/confusion_matrix.py安装中修改一个文件,

thresh = cm.min()+(cm.max() - cm.min()) / 2.而不是thresh = (cm.max() - cm.min()) / 2.

===================================================================
--- metrics/_plot/confusion_matrix.py   (date 1576701552905)
+++ metrics/_plot/confusion_matrix.py   (date 1576701552905)
@@ -93,7 +93,7 @@
                 values_format = '.2g'

             # print text with appropriate color depending on background
-            thresh = (cm.max() - cm.min()) / 2.
+            thresh = cm.min()+(cm.max() - cm.min()) / 2.
             for i, j in product(range(n_classes), range(n_classes)):
                 color = cmap_max if cm[i, j] < thresh else cmap_min
                 self.text_[i, j] = ax.text(j, i,