从按列排序的方阵中获取较低的对角线索引

时间:2018-04-01 13:17:29

标签: python numpy

我正在尝试在matplotlib.pyplot.subplot的下对角线创建成对图。因此,我需要来自方阵的下对角线的指数。由于我的情节顺序,我需要按列进行排序。例如,假设我有以下矩阵4x4:

[ 1,  2,  3,  4]
[ 5,  6,  8,  7]
[ 8,  9, 10, 11]
[12, 13, 14, 15]

我需要按以下顺序排列他们的索引:5,8,12,9,13,14。如何在几行代码中实现这一点?我将分享我的解决方案,但我觉得我可以以更优雅的方式实现这一目标。

我的解决方案

>>> import numpy as np
>>> n = 4 # Matrix order
>>> a = np.arange(1,n*n+1).reshape(n,n)
>>> a
array([[ 1,  2,  3,  4],
       [ 5,  6,  7,  8],
       [ 9, 10, 11, 12],
       [13, 14, 15, 16]])
>>> index = np.triu_indices(n, 1)
>>> a.T[index]
array([ 5,  9, 13, 10, 14, 15])

上下文

接下来我要做的是:

>>> subplot_idx = a.T[index]
>>> for idx in subplot_idx:
...     plt.subplot(n, n, idx)
...     # plot something

2 个答案:

答案 0 :(得分:2)

这个怎么样?

n = 4
1 + np.ravel_multi_index(np.triu_indices(n, 1)[::-1], (n, n))
# array([ 5,  9, 13, 10, 14, 15])

与您的解决方案类似,但不需要设置完整的方阵。

答案 1 :(得分:2)

更便宜的方法是避免索引创建部分并使用boolean-indexing的掩码。现在,由于它在NumPy中的行主要排序并且我们需要较低的diag元素,我们需要在输入数组的转置版本上使用上部diag掩码(上部diag元素的掩码设置为True,而rest为False)。我们将使用broadcasting有效地创建一个带有ranged array outer comparison的上层诊断掩码并转换为转置数组。因此,对于输入数组a,它将是 -

r = np.arange(len(a))
out = a.T[r[:,None] < r]

假设我们将使用小于65536 x 65536大小的数组的矩阵,我们可以使用r的较低精度,从而实现显着的性能提升 -

r = np.arange(len(a), dtype=np.uint16)

相同的想法并使用NumPy内置np.tri创建一个较低的诊断掩码,因此有一个 elegant 单行方式(如所要求的)将是 -

a.T[~np.tri(len(a), dtype=bool)]

示例运行 -

In [116]: a
Out[116]: 
array([[ 1,  2,  3,  4],
       [ 5,  6,  7,  8],
       [ 9, 10, 11, 12],
       [13, 14, 15, 16]])

In [117]: a.T[~np.tri(len(a), dtype=bool)]
Out[117]: array([ 5,  9, 13, 10, 14, 15])

基准

方法 -

# Original soln
def extract_lower_diag_org(a):
    n = len(a)
    index = np.triu_indices(n, 1)
    return a.T[index]

# Proposed soln
def extract_lower_diag_mask(a):
    r = np.arange(len(a), dtype=np.uint16)
    return a.T[r[:,None] < r]

更大阵列上的计时 -

In [142]: a = np.random.rand(5000,5000)

In [143]: %timeit extract_lower_diag_org(a)
1 loop, best of 3: 216 ms per loop

In [144]: %timeit extract_lower_diag_mask(a)
10 loops, best of 3: 50.2 ms per loop

In [145]: %timeit a.T[~np.tri(len(a), dtype=bool)]
10 loops, best of 3: 52.1 ms per loop

使用建议的基于掩码的方法查看这些大型数组的 4x+ 加速。