NumPy阵列的下三角区域中n个最大值的指数

时间:2017-11-14 15:09:04

标签: python arrays numpy distance

我有一个numpy余弦相似矩阵。我想找到n个最大值的索引,但是不要在对角线上排除1.0,而只排除其中的下三角区域。

similarities = [[ 1.          0.18898224  0.16903085]
 [ 0.18898224  1.          0.67082039]
 [ 0.16903085  0.67082039  1.        ]]

在这种情况下,如果我想要两个最高值,我希望它返回[1, 0][2, 1]

我尝试过使用argpartition,但这不会返回我正在寻找的内容

n_select = 1
most_similar = (-similarities).argpartition(n_select, axis=None)[:n_select]

如何获得除对角线1以外的n个最高值,并且还排除上三角形元素?

2 个答案:

答案 0 :(得分:2)

方法#1

使用np.tril_indices -

的一种方法
def n_largest_indices_tril(a, n=2):
    m = a.shape[0]
    r,c = np.tril_indices(m,-1)
    idx = a[r,c].argpartition(-n)[-n:]
    return zip(r[idx], c[idx])

示例运行 -

In [39]: a
Out[39]: 
array([[ 1.  ,  0.4 ,  0.59,  0.15,  0.29],
       [ 0.4 ,  1.  ,  0.03,  0.57,  0.57],
       [ 0.59,  0.03,  1.  ,  0.9 ,  0.52],
       [ 0.15,  0.57,  0.9 ,  1.  ,  0.37],
       [ 0.29,  0.57,  0.52,  0.37,  1.  ]])

In [40]: n_largest_indices_tril(a, n=2)
Out[40]: [(2, 0), (3, 2)]

In [41]: n_largest_indices_tril(a, n=3)
Out[41]: [(4, 1), (2, 0), (3, 2)]

方法#2

为了提高性能,我们可能希望避免生成所有较低的三角形索引,而是使用掩码,为我们提供第二种方法来解决我们的情况,如下所示 -

def n_largest_indices_tril_v2(a, n=2):
    m = a.shape[0]
    r = np.arange(m)
    mask = r[:,None] > r
    idx = a[mask].argpartition(-n)[-n:]

    clens = np.arange(m).cumsum()    
    grp_start = clens[:-1]
    grp_stop = clens[1:]-1    

    rows = np.searchsorted(grp_stop, idx)+1    
    cols  = idx - grp_start[rows-1]
    return zip(rows, cols)

运行时测试

In [143]: # Setup symmetric array 
     ...: N = 1000
     ...: a = np.random.rand(N,N)*0.9
     ...: np.fill_diagonal(a,1)
     ...: m = a.shape[0]
     ...: r,c = np.tril_indices(m,-1)
     ...: a[r,c] = a[c,r]

In [144]: %timeit n_largest_indices_tril(a, n=2)
100 loops, best of 3: 12.5 ms per loop

In [145]: %timeit n_largest_indices_tril_v2(a, n=2)
100 loops, best of 3: 7.85 ms per loop

适用于n最小指数

要获得最小的n,只需使用ndarray.argpartition(n)[:n]代替这两种方法。

答案 1 :(得分:0)

请记住,方阵的对角元素具有唯一属性:i + j = n,其中n是矩阵维。 然后,您可以找到数组中n +个(对角线元素)最大元素,然后迭代它们并排除元组(i,j),其中i + j = n。 希望它有所帮助!