我有一个相当大的矩阵M
。我试图找到前5个最近距离及其指数。
M = csr_matrix(M)
dst = pairwise_distances(M,Y=None,metric='euclidean')
dst
变成了一个巨大的矩阵,我正在尝试高效排序或使用scipy或sklearn找到最近的5个距离。
这是我想要做的一个例子:
X = np.array([[2, 3, 5], [2, 3, 6], [2, 3, 8], [2, 3, 3], [2, 3, 4]])
然后我将dst
计算为:
[[ 0. 1. 3. 2. 1.]
[ 1. 0. 2. 3. 2.]
[ 3. 2. 0. 5. 4.]
[ 2. 3. 5. 0. 1.]
[ 1. 2. 4. 1. 0.]]
因此,第0行到它自己的距离为0.
,第0行到第1行的距离为1.
,...第2行到第3行的距离为5.
, 等等。我想找到这些最接近的5个距离并将它们放在一个带有相应行的列表中,可能就像 [distance,row,row] 。我不想要任何对角元素或重复元素,所以我采用上三角矩阵如下:
[[ inf 1. 3. 2. 1.]
[ nan inf 2. 3. 2.]
[ nan nan inf 5. 4.]
[ nan nan nan inf 1.]
[ nan nan nan nan inf]]
现在,前5个距离最小的距离是:
[1, 0, 1], [1, 0, 4], [1, 3, 4], [2, 1, 2], [2, 0, 3], [2, 1, 4]
如您所见,有三个元素具有距离2
,三个元素具有距离1
。从这些我想随机选择一个距离为2
的元素,因为我只想要顶部的 f 元素,其中 f = 5。
这只是一个样本,因为这个矩阵可能非常大。除了使用基本的排序函数之外,还有一种有效的方法吗?我无法找到任何sklearn或scipy来帮助我。
答案 0 :(得分:1)
这是针对您的问题的完全矢量化解决方案:
import numpy as np
from scipy.spatial.distance import pdist
def smallest(M, f):
# compute the condensed distance matrix
dst = pdist(M, 'euclidean')
# indices of the upper triangular matrix
rows, cols = np.triu_indices(M.shape[0], k=1)
# indices of the f smallest distances
idx = np.argsort(dst)[:f]
# gather results in the specified format: distance, row, column
return np.vstack((dst[idx], rows[idx], cols[idx])).T
请注意,np.argsort(dst)[:f]
会生成按升序排序的压缩距离矩阵f
的最小dst
个元素的索引。
以下演示重现了您的玩具示例的结果,并显示了函数smallest
如何处理相当大的整数矩阵:
In [59]: X = np.array([[2, 3, 5], [2, 3, 6], [2, 3, 8], [2, 3, 3], [2, 3, 4]])
In [60]: smallest(X, 5)
Out[60]:
array([[ 1., 0., 1.],
[ 1., 0., 4.],
[ 1., 3., 4.],
[ 2., 0., 3.],
[ 2., 1., 2.]])
In [61]: large_X = np.random.randint(100, size=(10000, 2000))
In [62]: large_X
Out[62]:
array([[ 8, 78, 97, ..., 23, 93, 90],
[42, 2, 21, ..., 68, 45, 62],
[28, 45, 30, ..., 0, 75, 48],
...,
[26, 88, 78, ..., 0, 88, 43],
[91, 53, 94, ..., 85, 44, 37],
[39, 8, 10, ..., 46, 15, 67]])
In [63]: %time smallest(large_X, 5)
Wall time: 1min 32s
Out[63]:
array([[ 1676.12529365, 4815. , 5863. ],
[ 1692.97253374, 1628. , 2950. ],
[ 1693.558384 , 5742. , 8240. ],
[ 1695.86408654, 2140. , 6969. ],
[ 1696.68853948, 5477. , 6641. ]])