numpy大数组索引崩溃了解释器

时间:2015-10-14 04:44:42

标签: python arrays numpy matrix

我想引用一个具有两个索引i和j数组的numpy矩阵数组。我在下面使用的方法工作正常但在处理非常大的数组时崩溃了解释器。我理解为什么会发生这种情况,但是我知道一种更好的方法来解决这个问题的方法太新了。

有没有办法用大型数组有效地实现以下代码?

import numpy as np
np.set_printoptions(precision=4,suppress=True)

def test(COUNT):
    M = np.random.random_sample((COUNT,4,4,)) # Many matrices
    i = np.random.randint(4, size=COUNT)
    j = np.random.randint(4, size=COUNT)

    # Debug prints
    print M # Print the source matrices for reference
    print i # Print the i indices for reference
    print j # Print the j indices, for reference

    # return the diagonal, this is where the code fails because
    # M[:,i,j] gets incredibly large. This is what i'm trying to solve
    return  M[:,i,j].diagonal() 
    #return np.einsum('ii->i', M[:,i,j])

一些例子:

# test 1 item, easy
print test(1)

[[[ 0.4158  0.2146  0.0371  0.4449]
  [ 0.8894  0.9889  0.0961  0.7343]
  [ 0.8905  0.2062  0.1663  0.04  ]
  [ 0.691   0.1203  0.6524  0.636 ]]]
[1]    
[0]
[ 0.8894]

完美,第一个(也是唯一的)矩阵的索引[1] [0]是0.884

# test 2 items
print test(2)

[[[ 0.0697  0.434   0.8456  0.592 ]
  [ 0.4413  0.8893  0.9973  0.9184]
  [ 0.7951  0.7392  0.8603  0.8069]
  [ 0.5054  0.3846  0.7708  0.0563]]

 [[ 0.7414  0.2676  0.4796  0.1424]
  [ 0.1203  0.9183  0.1341  0.074 ]
  [ 0.2375  0.3475  0.2298  0.9879]
  [ 0.7814  0.0262  0.4498  0.9864]]]
[2 3]
[1 1]
[ 0.7392  0.0262]

正如预期的那样,第一个矩阵的索引[2] [1]和第二个矩阵的[3] [1]的值是[0.7392 0.0262],一切都很好!...然而......

# too many items!
print test(1000000)

机器失速,因为M [:,i,j]对于所有扔掉的值来说太大了(我所关心的只是对角线)。

我用np.einsum稍微探讨一下它是否有帮助。但这对我来说太新了,所以现在我正在寻求一些帮助! :)

1 个答案:

答案 0 :(得分:1)

我不认为einsum会为您做任何事情 - 您只是将其用作diagonal的替代方案。但试试:

M[np.arange(COUNT),i,j]

这应该返回所需的元素,而不会收集额外的东西。

这是有效的,因为它相当于索引:

M[[0 1], [2 3], [1 1]]

即元素

M[0,2,1] and M[1,3,1]

另一个生成(COUNT,COUNT)矩阵,并从中提取对角(COUNT,)数组。