scikit-learn中聚类的混淆矩阵

时间:2017-12-08 06:25:48

标签: python scikit-learn cluster-analysis confusion-matrix scikits

我有一组带有已知标签的数据。我想尝试群集,看看我是否可以获得已知标签给出的相同群集。为了测量准确度,我需要得到类似混淆矩阵的东西。

我知道我可以轻松地为分类问题的测试集获得混淆矩阵。我已经像this那样试过了。

但是,它不能用于聚类,因为它希望列和行具有相同的标签集,这对于分类问题是有意义的。但对于聚类问题,我的期望是这样的。

  

行 - 实际标签

     

列 - 新的群集名称(即群集1,群集2等)

有办法做到这一点吗?

修改:以下是更多详情。

sklearn.metrics.confusion_matrix中,它希望y_testy_pred具有相同的值,并labels作为这些值的标签。

这就是为什么它给出了一个像这样的行和列具有相同标签的矩阵。

enter image description here

但在我的情况下(KMeans Clustering),实际值是字符串,估计值是数字(即簇号)

因此,如果我致电confusion_matrix(y_true, y_pred),则会出现以下错误。

ValueError: Mix of label input types (string and number)

这是真正的问题。对于分类问题,这是有道理的。但是对于群集问题,这种限制不应该存在,因为真实的标签名称和新的群集名称不需要相同。

有了这个,我理解我试图使用一个应该用于分类问题的工具来解决聚类问题。所以,我的问题是,有没有办法可以为may集群数据获得这样的矩阵。

希望问题现在更加清晰。如果不是,请告诉我。

2 个答案:

答案 0 :(得分:1)

您可以轻松计算成对交叉矩阵。

但是,如果sklearn库已针对分类用例进行了优化,则可能需要自己执行此操作。

答案 1 :(得分:0)

我自己写了一段代码。

# Compute confusion matrix
def confusion_matrix(act_labels, pred_labels):
    uniqueLabels = list(set(act_labels))
    clusters = list(set(pred_labels))
    cm = [[0 for i in range(len(clusters))] for i in range(len(uniqueLabels))]
    for i, act_label in enumerate(uniqueLabels):
        for j, pred_label in enumerate(pred_labels):
            if act_labels[j] == act_label:
                cm[i][pred_label] = cm[i][pred_label] + 1
    return cm

# Example
labels=['a','b','c',
        'a','b','c',
        'a','b','c',
        'a','b','c']
pred=[  1,1,2,
        0,1,2,
        1,1,1,
        0,1,2]
cnf_matrix = confusion_matrix(labels, pred)
print('\n'.join([''.join(['{:4}'.format(item) for item in row])
      for row in cnf_matrix]))

修改 (Dayyyuumm)刚刚发现我可以使用Pandas Crosstab轻松完成此任务: - /。

labels=['a','b','c',
        'a','b','c',
        'a','b','c',
        'a','b','c']
pred=[  1,1,2,
        0,1,2,
        1,1,1,
        0,1,2]   

# Create a DataFrame with labels and varieties as columns: df
df = pd.DataFrame({'Labels': labels, 'Clusters': pred})

# Create crosstab: ct
ct = pd.crosstab(df['Labels'], df['Clusters'])

# Display ct
print(ct)