一个numpy.where错误?

时间:2017-03-29 20:25:30

标签: python numpy

我正在使用numpy.where来查找某些值的索引。但是,numpy.where会产生错误的索引,如下所示。有人可以解释我为什么会得到这样的错误指数吗?

感谢。

In [1]: d = np.random.rand(3,4)

In [2]: d
Out[2]: 
array([[ 0.11694612,  0.95137658,  0.70099781,  0.06730629],
       [ 0.59989836,  0.52586768,  0.45387929,  0.76093495],
       [ 0.036541  ,  0.91714289,  0.2246452 ,  0.40785078]])

In [3]: np.where(d>0.9)
Out[3]: (array([0, 2]), array([1, 1]))

然而,

In [4]: d[0,2]
Out[4]: 0.70099781000000005
In[5]: d[1,1]
Out[5]: 0.52586767999999995

1 个答案:

答案 0 :(得分:2)

问题是np.where返回一个数组元组,其中索引位于条件所在的给定轴中。所以,也许这使得它更清晰:

>>> import numpy as np
>>> d = np.array([[ 0.11694612,  0.95137658,  0.70099781,  0.06730629],
...        [ 0.59989836,  0.52586768,  0.45387929,  0.76093495],
...        [ 0.036541  ,  0.91714289,  0.2246452 ,  0.40785078]])
>>> x, y = np.where(d > 0.9)
>>> d[x[0],y[0]]
0.95137658000000003
>>> d[x[1],y[1]]
0.91714289000000004

注意,这适用于numpy索引的工作方式:

>>> d[x,y]
array([ 0.95137658,  0.91714289])

注意,这适用于任何维度:

>>> d.reshape(3,2,2)
array([[[ 0.11694612,  0.95137658],
        [ 0.70099781,  0.06730629]],

       [[ 0.59989836,  0.52586768],
        [ 0.45387929,  0.76093495]],

       [[ 0.036541  ,  0.91714289],
        [ 0.2246452 ,  0.40785078]]])
>>> d = d.reshape(3,2,2)
>>> x, y, z = np.where(d > 0.9)
>>> x
array([0, 2])
>>> y
array([0, 0])
>>> z
array([1, 1])
>>> d[x,y,z]
array([ 0.95137658,  0.91714289])