根据对角线值选择子矩阵

时间:2013-11-20 16:47:54

标签: python arrays numpy

我想根据对角线是否小于某个截止值来选择numpy矩阵的子矩阵。例如,给定矩阵:

Test = array([[1,2,3,4,5],
              [2,3,4,5,6],
              [3,4,5,6,7],
              [4,5,6,7,8],
              [5,6,7,8,9]])

我想选择对角线值小于6的行和列。在这个例子中,对角线值被排序,这样我就可以选择Test [:3,:3],但是我想要解决的一般问题并非如此。

以下代码段有效:

def MatrixCut(M,Ecut):
    D = diag(M)
    indices = D<Ecut
    n = sum(indices)
    NewM = zeros((n,n),'d')
    ii = -1
    for i,ibool in enumerate(indices):
        if ibool:
            ii += 1
            jj = -1
            for j,jbool in enumerate(indices):
                if jbool:
                    jj += 1
                    NewM[ii,jj] = M[i,j]
    return NewM

print MatrixCut(Test,6)
[[ 1.  2.  3.]
 [ 2.  3.  4.]
 [ 3.  4.  5.]]

然而,这是一个非常丑陋的代码,有各种危险的东西,比如将ii / jj索引初始化为-1,如果不知何故我进入循环并取M [-1,则不会导致错误-1]。

另外,必须采用numpythonic的方式来做到这一点。对于一维数组,您可以这样做:

D = diag(A)
A[D<Ecut]

但二维阵列的类似之处并不起作用:

D = diag(Test)
Test[D<6,D<6]
array([1, 3, 5])

这样做有好办法吗?提前谢谢。

2 个答案:

答案 0 :(得分:3)

当对角线未分类时,这也有效:

In [7]: Test = array([[1,2,3,4,5],
              [2,3,4,5,6],
              [3,4,5,6,7],
              [4,5,6,7,8],
              [5,6,7,8,9]])

In [8]: d = np.argwhere(np.diag(Test) < 6).squeeze()

In [9]: Test[d][:,d]
Out[9]: 
array([[1, 2, 3],
       [2, 3, 4],
       [3, 4, 5]])

或者,要使用单个下标调用,您可以执行以下操作:

In [10]: d = np.argwhere(np.diag(Test) < 6)

In [11]: Test[d, d.flat]
Out[11]: 
array([[1, 2, 3],
       [2, 3, 4],
       [3, 4, 5]])

[更新]:第二种形式的说明。

首先,尝试Test[d, d]可能很诱人,但这只会从数组的对角线中提取元素:

In [75]: Test[d, d]
Out[75]: 
array([[1],
       [3],
       [5]])

问题是d具有形状(3,1),因此如果我们在两个下标中都使用d,则输出数组将具有与d相同的形状。 d.flat等同于使用d.flatten()d.ravel()flat除外只返回迭代器而不是数组)。结果是结果具有形状(3,):

In [76]: d
Out[76]: 
array([[0],
       [1],
       [2]])

In [77]: d.flatten()
Out[77]: array([0, 1, 2])

In [79]: print d.shape, d.flatten().shape
(3, 1) (3,)

Test[d, d.flat]有效的原因是因为numpy的general broadcasting rules导致d的最后一个维度(即1)被广播到{{1}的最后(且唯一)维度}(这是3)。同样,广播d.flat以匹配d.flat的第一维。结果是两(3,3)个索引数组,它们等同于以下数组di

j

只是为了确保它们有效:

In [80]: dd = d.flatten()

In [81]: i = np.hstack((d, d, d)

In [82]: j = np.vstack((dd, dd, dd))

In [83]: print i
[[0 0 0]
 [1 1 1]
 [2 2 2]]

In [84]: print j
[[0 1 2]
 [0 1 2]
 [0 1 2]]

答案 1 :(得分:2)

我找到解决任务的唯一方法是有点棘手

>>> Test[[[i] for i,x in enumerate(D<6) if x], D<6]
array([[1, 2, 3],
       [2, 3, 4],
       [3, 4, 5]])

可能不是最好的一个。基于this回答。 或者(感谢@bogatron或提醒我argwhere):

>>> Test[np.argwhere(D<6), D<6]
array([[1, 2, 3],
       [2, 3, 4],
       [3, 4, 5]])