我想根据对角线是否小于某个截止值来选择numpy矩阵的子矩阵。例如,给定矩阵:
Test = array([[1,2,3,4,5],
[2,3,4,5,6],
[3,4,5,6,7],
[4,5,6,7,8],
[5,6,7,8,9]])
我想选择对角线值小于6的行和列。在这个例子中,对角线值被排序,这样我就可以选择Test [:3,:3],但是我想要解决的一般问题并非如此。
以下代码段有效:
def MatrixCut(M,Ecut):
D = diag(M)
indices = D<Ecut
n = sum(indices)
NewM = zeros((n,n),'d')
ii = -1
for i,ibool in enumerate(indices):
if ibool:
ii += 1
jj = -1
for j,jbool in enumerate(indices):
if jbool:
jj += 1
NewM[ii,jj] = M[i,j]
return NewM
print MatrixCut(Test,6)
[[ 1. 2. 3.]
[ 2. 3. 4.]
[ 3. 4. 5.]]
然而,这是一个非常丑陋的代码,有各种危险的东西,比如将ii / jj索引初始化为-1,如果不知何故我进入循环并取M [-1,则不会导致错误-1]。
另外,必须采用numpythonic的方式来做到这一点。对于一维数组,您可以这样做:
D = diag(A)
A[D<Ecut]
但二维阵列的类似之处并不起作用:
D = diag(Test)
Test[D<6,D<6]
array([1, 3, 5])
这样做有好办法吗?提前谢谢。
答案 0 :(得分:3)
当对角线未分类时,这也有效:
In [7]: Test = array([[1,2,3,4,5],
[2,3,4,5,6],
[3,4,5,6,7],
[4,5,6,7,8],
[5,6,7,8,9]])
In [8]: d = np.argwhere(np.diag(Test) < 6).squeeze()
In [9]: Test[d][:,d]
Out[9]:
array([[1, 2, 3],
[2, 3, 4],
[3, 4, 5]])
或者,要使用单个下标调用,您可以执行以下操作:
In [10]: d = np.argwhere(np.diag(Test) < 6)
In [11]: Test[d, d.flat]
Out[11]:
array([[1, 2, 3],
[2, 3, 4],
[3, 4, 5]])
首先,尝试Test[d, d]
可能很诱人,但这只会从数组的对角线中提取元素:
In [75]: Test[d, d]
Out[75]:
array([[1],
[3],
[5]])
问题是d
具有形状(3,1),因此如果我们在两个下标中都使用d
,则输出数组将具有与d
相同的形状。 d.flat
等同于使用d.flatten()
或d.ravel()
(flat
除外只返回迭代器而不是数组)。结果是结果具有形状(3,):
In [76]: d
Out[76]:
array([[0],
[1],
[2]])
In [77]: d.flatten()
Out[77]: array([0, 1, 2])
In [79]: print d.shape, d.flatten().shape
(3, 1) (3,)
Test[d, d.flat]
有效的原因是因为numpy的general broadcasting rules导致d
的最后一个维度(即1)被广播到{{1}的最后(且唯一)维度}(这是3)。同样,广播d.flat
以匹配d.flat
的第一维。结果是两(3,3)个索引数组,它们等同于以下数组d
和i
:
j
只是为了确保它们有效:
In [80]: dd = d.flatten()
In [81]: i = np.hstack((d, d, d)
In [82]: j = np.vstack((dd, dd, dd))
In [83]: print i
[[0 0 0]
[1 1 1]
[2 2 2]]
In [84]: print j
[[0 1 2]
[0 1 2]
[0 1 2]]
答案 1 :(得分:2)
我找到解决任务的唯一方法是有点棘手
>>> Test[[[i] for i,x in enumerate(D<6) if x], D<6]
array([[1, 2, 3],
[2, 3, 4],
[3, 4, 5]])
可能不是最好的一个。基于this回答。
或者(感谢@bogatron或提醒我argwhere
):
>>> Test[np.argwhere(D<6), D<6]
array([[1, 2, 3],
[2, 3, 4],
[3, 4, 5]])