如何使用sklearn更改混淆矩阵中单元格的颜色?

时间:2019-07-15 15:50:49

标签: python-3.x confusion-matrix

这是我生成混淆矩阵的代码段。我想知道如何使用sklearn更改混淆矩阵热图中不在对角线上的单元格的颜色。

    # Accumulate a confusion matrix over the test split, then draw it as an annotated heatmap.
nb_classes = 15
confusion_matrix = torch.zeros(nb_classes, nb_classes)

# Inference only: no_grad() switches off gradient tracking for the forward passes.
with torch.no_grad():
    for i, (inputs, target, classes, im_path) in enumerate(dataLoaders['test']):

        inputs = inputs.to(device)
        target = target.to(device)

        outputs = model(inputs)
        # The index of the largest output along dim 1 is used as the predicted class.
        _, preds = torch.max(outputs, 1)

        # Row index is the target label, column index is the predicted label.
        for t, p in zip(target.view(-1), preds.view(-1)):
            confusion_matrix[t.long(), p.long()] += 1

# Single source of truth for the class count so the matrix and the ticks agree.
num_classes = nb_classes
class_names = ['A2CH', 'A3CH', 'A4CH_LV', 'A4CH_RV', 'A5CH', 'Apical_MV_LA_IAS',
               'OTHER', 'PLAX_TV', 'PLAX_full', 'PLAX_valves', 'PSAX_AV', 'PSAX_LV',
               'Subcostal_IVC', 'Subcostal_heart', 'Suprasternal']

plt.figure()
plt.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Blues)
# Add the colorbar before tight_layout() so the layout pass accounts for it.
plt.colorbar()

# Label both axes with the class names (x labels rotated so they do not overlap).
tick_marks = numpy.arange(num_classes)
classNames = class_names
plt.xticks(tick_marks, classNames, rotation=90)
plt.yticks(tick_marks, classNames)

# Cells above half of the maximum count get white text for contrast on dark cells.
# NOTE(review): zero cells also get white text, which presumably hides them against
# the light end of the Blues colormap - confirm this is intended.
thresh = confusion_matrix.max() / 2.
for i in range(confusion_matrix.shape[0]):
    for j in range(confusion_matrix.shape[1]):
        plt.text(j, i, format(confusion_matrix[i, j]),
                 ha="center", va="center",
                 color="white" if confusion_matrix[i, j] == 0 or confusion_matrix[i, j] > thresh else "black")
plt.tight_layout()
# A module-level `return` is a SyntaxError, so the figure is simply shown here.
plt.show()

enter image description here

2 个答案:

答案 0 :(得分:1)

使用热图绘制混淆矩阵

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt

# Example 11x11 confusion matrix, one row and one column per class label A..K.
array = [[33,2,0,0,0,0,0,0,0,1,3],
    [3,31,0,0,0,0,0,0,0,0,0],
    [0,4,41,0,0,0,0,0,0,0,1],
    [0,1,0,30,0,6,0,0,0,0,1],
    [0,0,0,0,38,10,0,0,0,0,0],
    [0,0,0,3,1,39,0,0,0,0,4],
    [0,2,2,0,4,1,31,0,0,0,2],
    [0,1,0,0,0,0,0,36,0,2,0],
    [0,0,0,0,0,0,1,5,37,5,1],
    [3,0,0,0,0,0,0,0,0,39,0],
    # NOTE(review): this row needs 11 entries like the others; with only 10, pandas
    # pads the last cell with NaN. The trailing 38 is a reconstruction of the
    # missing diagonal value - confirm against the original data.
    [0,0,0,0,0,0,0,0,0,0,38]]
labels = list("ABCDEFGHIJK")
df_cm = pd.DataFrame(array, index=labels, columns=labels)
plt.figure(figsize=(10, 7))
# cmap selects the colormap used to shade the cells; annot=True writes the counts.
sn.heatmap(df_cm, annot=True, cmap="OrRd")
# Needed to display the figure when run as a plain script.
plt.show()

heatmap接受一个额外的参数cmap,用于更改矩阵的配色。以下是cmap的一些可选值。

cmap = [Accent, Accent_r, Blues, Blues_r, BrBG, BrBG_r, BuGn, BuGn_r, BuPu, 
BuPu_r, CMRmap, CMRmap_r, Dark2, Dark2_r, GnBu, GnBu_r, Greens, Greens_r, 
Greys, Greys_r, OrRd, OrRd_r, Oranges, Oranges_r, PRGn, PRGn_r, Paired, Paired_r, 
Pastel1, Pastel1_r, Pastel2, Pastel2_r, PiYG, PiYG_r, PuBu, PuBuGn, PuBuGn_r, 
PuBu_r, PuOr, PuOr_r, PuRd, PuRd_r, Purples, Purples_r, RdBu, RdBu_r, RdGy, RdGy_r, 
RdPu, RdPu_r, RdYlBu, RdYlBu_r, RdYlGn, RdYlGn_r, Reds, Reds_r, Set1, Set1_r, 
Set2, Set2_r, Set3, Set3_r, Spectral, Spectral_r, Wistia, Wistia_r, YlGn, YlGnBu, 
YlGnBu_r, YlGn_r, YlOrBr, YlOrBr_r, YlOrRd, YlOrRd_r, afmhot, afmhot_r, autumn,
autumn_r, binary, binary_r, bone, bone_r, brg, brg_r, bwr, bwr_r, cividis, 
cividis_r, cool, cool_r, coolwarm, coolwarm_r, copper, copper_r, cubehelix, 
cubehelix_r, flag, flag_r, gist_earth, gist_earth_r, gist_gray, gist_gray_r,
gist_heat, gist_heat_r, gist_ncar, gist_ncar_r, gist_rainbow, gist_rainbow_r, 
gist_stern, gist_stern_r, gist_yarg, gist_yarg_r, gnuplot, gnuplot2, gnuplot2_r, 
gnuplot_r, gray, gray_r, hot, hot_r, hsv, hsv_r, icefire, icefire_r, inferno, 
inferno_r, jet, jet_r, magma, magma_r, mako, mako_r, nipy_spectral, nipy_spectral_r,
ocean, ocean_r, pink, pink_r, plasma, plasma_r, prism, prism_r, rainbow, rainbow_r, 
rocket, rocket_r, seismic, seismic_r, spring, spring_r, summer, summer_r, tab10, 
tab10_r, tab20, tab20_r, tab20b, tab20b_r, tab20c, tab20c_r, terrain, terrain_r, 
viridis, viridis_r, vlag, vlag_r, winter, winter_r]

cmap = "OrRd"

cmap = "Greens_r"  cmap = "OrRd_r"

答案 1 :(得分:0)

def plot_confusion_matrix(y_true, y_pred, classes,
                      normalize=False,
                      title=None,
                      cmap=plt.cm.Blues):

您可以将cmap=plt.cm.Blues中的名称更改为所需的颜色图,例如Greens、Reds、Oranges等。不要忘记在每个颜色名称的末尾加上s。此外,每种颜色图都有两种形式(后缀_r表示颜色顺序反转的同一颜色图)。以绿色为例:1. Greens:对角线上的单元格显示为绿色。2. Greens_r:对角线以外的单元格显示为绿色。

希望对您有帮助。