Scipy sparse matrix slicing returns IndexError

Posted: 2019-05-17 10:01:20

Tags: python numpy scipy slice sparse-matrix

If I try to slice a sparse matrix, or read the value at a given [row, column], I get an IndexError.

More precisely, I have the following scipy.sparse.csr_matrix, which I save to a file and then load back:

...
>>> A = scipy.sparse.csr_matrix((vals, (rows, cols)), shape=(output_dim, input_dim))
>>> np.save(open('test_matrix.dat', 'wb'), A)
...
>>> A = np.load('test_matrix.dat', allow_pickle=True)
>>> A
array(<831232x798208 sparse matrix of type '<class 'numpy.float32'>'
    with 109886100 stored elements in Compressed Sparse Row format>,
      dtype=object)

However, when I try to get the value at a given [row, column] pair, I get the following error:

>>> A[1,1]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: too many indices for array

Why does this happen?

To clarify, I am sure the matrix is not empty; I can see its contents when I print it:

>>> print(A)
  (0, 1)    0.24914551
  (0, 2)    0.6669922
  (1, 1)    0.75097656
  (1, 3)    0.6640625
  (2, 3)    0.3359375
  (2, 514)  0.34960938
...

1 Answer

Answer (score: 0):

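When you save a sparse matrix with np.save, NumPy first converts it with np.asanyarray. NumPy cannot turn a sparse matrix into a regular array, so it wraps it in a 0-dimensional array of dtype object whose single element is the sparse matrix itself. That wrapper is what gets pickled to disk and what np.load gives back. A 0-d array has no axes, so A[1,1] has nothing to index, hence "too many indices for array". Use scipy.sparse.save_npz (with scipy.sparse.load_npz) instead; it stores the sparse structure directly and loads it back as a sparse matrix.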
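To see the wrapping directly, here is a minimal sketch. It assumes a small 2x2 float32 CSR matrix in place of the 831232x798208 one, and an in-memory io.BytesIO buffer in place of test_matrix.dat; the variable names are illustrative only.

```python
import io

import numpy as np
import scipy.sparse as sps

# Small stand-in for the large matrix in the question (illustrative values).
A = sps.csr_matrix(np.array([[0.0, 0.25], [0.75, 0.0]], dtype=np.float32))

# np.save calls np.asanyarray on its argument; a sparse matrix ends up
# wrapped in a 0-d object array instead of being densified.
wrapped = np.asanyarray(A)
print(wrapped.shape, wrapped.dtype)

# Round-trip through an in-memory buffer instead of a real file.
buf = io.BytesIO()
np.save(buf, A, allow_pickle=True)
buf.seek(0)
B = np.load(buf, allow_pickle=True)

# Indexing the 0-d wrapper with two indices fails, as in the question.
raised = False
try:
    B[1, 1]
except IndexError as exc:
    raised = True
    print("IndexError:", exc)
```

The loaded B has shape () even though the matrix inside it is 2x2, which is exactly why two indices are "too many".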

For example:

import scipy.sparse as sps
import numpy as np

A = sps.csr_matrix((10,10))
A
<10x10 sparse matrix of type '<class 'numpy.float64'>'
    with 0 stored elements in Compressed Sparse Row format>
np.save('test_matrix.dat', A)
B = np.load('test_matrix.dat.npy', allow_pickle=True)
B
array(<10x10 sparse matrix of type '<class 'numpy.float64'>'
    with 0 stored elements in Compressed Sparse Row format>, dtype=object)
B[1,1]
IndexError                                Traceback (most recent call last)
<ipython-input-101-969f8bd5206a> in <module>
----> 1 B[1,1]

IndexError: too many indices for array
sps.save_npz('sparse_dat', A)
C = sps.load_npz('sparse_dat.npz')
C
<10x10 sparse matrix of type '<class 'numpy.float64'>'
    with 0 stored elements in Compressed Sparse Row format>
C[1,1]
0.0
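The example above uses an empty matrix, so it only shows that indexing works again. The sketch below, assuming a small matrix built from hypothetical (row, col, value) triplets loosely modeled on the question's printout and a temporary directory instead of a fixed file name, checks that stored values and row slicing survive a save_npz/load_npz round trip.

```python
import os
import tempfile

import numpy as np
import scipy.sparse as sps

# Hypothetical triplets, built the same way as in the question.
rows = np.array([0, 0, 1, 1, 2])
cols = np.array([1, 2, 1, 3, 3])
vals = np.array([0.25, 0.67, 0.75, 0.66, 0.34], dtype=np.float32)
A = sps.csr_matrix((vals, (rows, cols)), shape=(3, 4))

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "test_matrix.npz")
    sps.save_npz(path, A)   # writes the sparse components into an .npz archive
    C = sps.load_npz(path)  # reads them back as a sparse matrix, not an ndarray

print(type(C).__name__, C.shape, C.dtype)
print(C[1, 1])               # single element lookup
print(C[1:3, :].toarray())   # row slice, densified only for display
```

Because load_npz returns a real sparse matrix, both element access and slicing behave as they did before saving.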

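Keep in mind that you can still recover the sparse matrix from B: calling .tolist() on a 0-d array returns its single element, which here is the sparse matrix: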

D = B.tolist()
D
<10x10 sparse matrix of type '<class 'numpy.float64'>'
    with 0 stored elements in Compressed Sparse Row format>
D[1,1]
0.0
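If you already have files written with np.save, a minimal sketch of unwrapping them is shown below. Besides .tolist(), the standard NumPy calls ndarray.item() and empty-tuple indexing B[()] also return the single element of a 0-d array. The 4x4 identity matrix and the in-memory buffer are assumptions for illustration only.

```python
import io

import numpy as np
import scipy.sparse as sps

# Illustrative matrix; any sparse matrix behaves the same way here.
A = sps.eye(4, format="csr", dtype=np.float32)

buf = io.BytesIO()
np.save(buf, A)
buf.seek(0)
B = np.load(buf, allow_pickle=True)  # 0-d object array wrapping the matrix

# Three equivalent ways to pull the single stored object out of a 0-d array.
via_tolist = B.tolist()
via_item = B.item()
via_empty_index = B[()]

for M in (via_tolist, via_item, via_empty_index):
    print(type(M).__name__, M.shape, M[1, 1])
```

This works, but np.load needs allow_pickle=True for these files, and pickled objects can run arbitrary code when loaded, so save_npz remains the better format for new files.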