使用scikit learning与预先计算的余弦相似度矩阵进行层次聚类会产生错误

时间:2019-05-14 18:00:50

标签: python scikit-learn hierarchical-clustering cosine-similarity distance-matrix

我们要在层次聚类中使用余弦相似度,并且已经计算出了余弦相似度。 在sklearn.cluster.AgglomerativeClustering文档中说:

  

需要距离矩阵(而不是相似度矩阵)作为输入   适合方法。

因此,我们将余弦相似度转换为距离

distance = 1 - similarity

我们的python代码最后在fit()方法处产生错误。 (因为它很大,所以我没有在代码中写X的实际值。)X只是一个余弦相似度矩阵,其值如上所述转换为距离。注意对角线,全为0。)这是代码:

import pandas as pd
import numpy as np 
from sklearn.cluster import AgglomerativeClustering

X = np.array([0,0.3,0.4],[0.3,0,0.7],[0.4,0.7,0])

cluster = AgglomerativeClustering(affinity='precomputed')  
cluster.fit(X)

错误是:

runfile('/Users/stackoverflowuser/Desktop/4.2/Pr/untitled0.py', wdir='/Users/stackoverflowuser/Desktop/4.2/Pr')
Traceback (most recent call last):

  File "<ipython-input-1-b8b98765b168>", line 1, in <module>
    runfile('/Users/stackoverflowuser/Desktop/4.2/Pr/untitled0.py', wdir='/Users/stackoverflowuser/Desktop/4.2/Pr')

  File "/anaconda2/lib/python2.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 704, in runfile
    execfile(filename, namespace)

  File "/anaconda2/lib/python2.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 100, in execfile
    builtins.execfile(filename, *where)

  File "/Users/stackoverflowuser/Desktop/4.2/Pr/untitled0.py", line 84, in <module>
    cluster.fit(X)

  File "/anaconda2/lib/python2.7/site-packages/sklearn/cluster/hierarchical.py", line 795, in fit
    (self.affinity, ))

ValueError: precomputed was provided as affinity. Ward can only work with euclidean distances.

有什么我可以提供的吗?已经谢谢了。

2 个答案:

答案 0 :(得分:0)

根据sklearn的文档:

  

如果链接为“病房”,则仅接受“欧几里得”。如果是“预先计算”,   需要距离矩阵(而不是相似矩阵)作为输入   适合方法。

因此,您需要将链接更改为完整,平均或单个链接之一。

答案来自: https://datascience.stackexchange.com/questions/51970/hierarchical-clustering-with-precomputed-cosine-similarity-matrix-using-scikit-l/

答案 1 :(得分:0)

您可以使用

n_clusters =所需集群的数量。

linkage =链接可以是"complete""average""single"以表示余弦亲和力。

AgglomerativeClustering(n_clusters=n_clusters,affinity="cosine",linkage="average")