具有归一化混淆矩阵的Matplotlib图的Colorbar不会更新值

时间:2017-07-22 07:26:01

标签: python matplotlib machine-learning colorbar

我正在尝试用KNearest Neighbors解决一个多类机器学习问题,并且正在使用Matplotlib.pyplot的imshow绘制一个混淆矩阵,用于预测我数据中所有10个类。有些类在数据中出现的次数远多于其他类,高达3000次,其他类可能只有50次,因此我将其标准化为仅显示百分比。图表旁边有一个颜色条,如果没有标准化,则范围从1到3000,这是有道理的。然而,在规范化之后,范围一直保持到3000.我正在使用Scikit在他们的网站here提供的绘图功能。是否有一些明显缺失的东西,或是否有额外的步骤来减少色条值范围?

代码

virdis = plt.cm.viridis
blues = plt.cm.Blues
autumn = plt.cm.autumn

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)

    bounds=[0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1]
    plt.colorbar(boundaries=bounds)

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    cm = np.around(cm, decimals=3)

    thresh = cm.max() / 2.

    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if i == 9 and j == 9 else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')

knn = KNeighborsClassifier()
knn.fit(X_train, y_train)

knn_score = knn.score(X_test, y_test)
knn_fold_score = model_selection.cross_val_score(knn, X_test, y_test, cv=10).mean()
predictions = knn.predict(X_test)

c_matrix = confusion_matrix(y_test, predictions)

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(c_matrix, classes=country_names, normalize=True,
                      title='Normalized confusion matrix')

plt.show()

enter image description here

1 个答案:

答案 0 :(得分:2)

正如您所理解的那样,色条及其范围保持不变并且应该与情节保持一致,即plt.imshow。 Scikit Learn示例和您的示例都会在执行或决定是否进行规范化之前绘制矩阵。因此,这两个图及其相关的颜色条看起来完全相同。如果在绘图之前处理标准化,即移动以下块:

if normalize:
    cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    print("Normalized confusion matrix")
else:
    print('Confusion matrix, without normalization')

cm = np.around(cm, decimals=3)

plt.imshow(cm, interpolation='nearest', cmap=cmap)前面,标准化绘图的颜色条将在0到1的范围内。再次,只是为了提醒你,(颜色)绘图本身也会改变。我认为仅将颜色条的文本标签更改为0到1并不是一个好主意,而不更改颜色条本身及其相关图。