获取所有列索引等于max并使用它们索引另一个数组:numpy vs sparse csr_matrix

时间:2017-08-15 19:43:35

标签: python numpy scipy sparse-matrix

在示例here之后,我能够找到2D numpy数组的列索引,并返回所有最大值出现的列索引数组。

但是现在我想在稀疏的csr_matrix上做同样的事情。

x = np.array([[0,0,1,0,0,0,2],[0,0,0,4,0,0,0],[0,9,1,0,0,0,2],[0,0,1,0,0,9,2]])
max_col_inds = np.argwhere(x == np.max(x))[:,1]
# array([1, 5], dtype=int64)

然后我想使用该结果得到1D数组的第1和第5个元素:

words[max_col_inds]

如果x是2D numpy数组且words是1D numpy数组,则可行。

但是现在如果我用scipy.sparse.csr.csr_matrix替换x,我会在调用np.argwhere()时得到这个:

TypeError: tuple indices must be integers, not tuple

1 个答案:

答案 0 :(得分:1)

In [804]: x = np.array([[0,0,1,0,0,0,2],[0,0,0,4,0,0,0],[0,9,1,0,0,0,2],[0,0,1,0,0,9,2]])
In [805]: np.max(x)
Out[805]: 9
In [806]: np.where(x == 9)
Out[806]: (array([2, 3], dtype=int32), array([1, 5], dtype=int32))

argwhere只是np.transpose(np.where(...));也就是说,将元组转换为2d数组并对其进行转置:

In [807]: np.argwhere(x ==9)
Out[807]: 
array([[2, 1],
       [3, 5]], dtype=int32)

用稀疏

做同样的事情
In [808]: xM = sparse.csr_matrix(x)
In [809]: xM == 9
Out[809]: 
<4x7 sparse matrix of type '<class 'numpy.bool_'>'
    with 2 stored elements in Compressed Sparse Row format>

np.wherenp.nonzero相同:

In [810]: (xM==9).nonzero()
Out[810]: (array([2, 3], dtype=int32), array([1, 5], dtype=int32))
In [811]: np.transpose((xM==9).nonzero())
Out[811]: 
array([[2, 1],
       [3, 5]], dtype=int32)

实际上,当前numpy argwhere与稀疏一起使用。那是因为np.nonzero委托给矩阵方法:

In [813]: np.argwhere(xM==9)
Out[813]: 
array([[2, 1],
       [3, 5]], dtype=int32)