使用np.where时,数组的索引太多了

时间:2017-10-13 18:18:10

标签: python numpy indices

我有代码:

multipart/form-data

a=b=np.arange(9).reshape(3,3) c=np.zeros(3) for x in range(3): c[x]=np.average(b[np.where(a<x+3)]) 的输出是

c

而不是for循环,我想使用数组(向量化),然后我做了以下代码:

>>>array([ 1. ,  1.5,  2. ])

但是它显示了IndexError:数组的索引太多了

至于a=b=np.arange(9).reshape(3,3) c=np.zeros(3) i=np.arange(3) c[i]=np.average(b[np.where(a<i[:,None,None]+3)])

它正确显示

a<i[:,None,None]+3

但是当我使用array([[[ True, True, True], [False, False, False], [False, False, False]], [[ True, True, True], [ True, False, False], [False, False, False]], [[ True, True, True], [ True, True, False], [False, False, False]]], dtype=bool) 时,它再次显示了IndexError:数组的索引太多了。我无法获得b[np.where(a<i[:,None,None]+3)]的正确输出。

1 个答案:

答案 0 :(得分:1)

我感觉你在试图在这里进行矢量化,尽管没有明确提到。现在,我认为你不能以矢量化的方式进行索引。为了以矢量化的方式解决你的问题,我建议使用np.tensordotmatrix-multiplication的帮助,使用broadcasting建议一个更有效的方法来减少和...你的试验。

因此,一个解决方案是 -

from __future__ import division

i = np.arange(3)
mask = a<i[:,None,None]+3
c = np.tensordot(b,mask,axes=((0,1),(1,2)))/mask.sum((1,2))

Related post to understand tensordot

可能改善表现

  • 在馈送到np.dot之前将掩码转换为float dtype,因为基于BLAS的矩阵乘法会更快。

  • 使用np.count_nonzero代替np.sum来计算布尔值。因此,使用它来替换{​​{1}}部分。