我正在使用numpy.where来查找某些值的索引。但是,numpy.where会产生错误的索引,如下所示。有人可以解释我为什么会得到这样的错误指数吗?
感谢。
In [1]: d = np.random.rand(3,4)
In [2]: d
Out[2]:
array([[ 0.11694612, 0.95137658, 0.70099781, 0.06730629],
[ 0.59989836, 0.52586768, 0.45387929, 0.76093495],
[ 0.036541 , 0.91714289, 0.2246452 , 0.40785078]])
In [3]: np.where(d>0.9)
Out[3]: (array([0, 2]), array([1, 1]))
然而,
In [4]: d[0,2]
Out[4]: 0.70099781000000005
In[5]: d[1,1]
Out[5]: 0.52586767999999995
答案 0 :(得分:2)
问题是np.where
返回一个数组元组,其中索引位于条件所在的给定轴中。所以,也许这使得它更清晰:
>>> import numpy as np
>>> d = np.array([[ 0.11694612, 0.95137658, 0.70099781, 0.06730629],
... [ 0.59989836, 0.52586768, 0.45387929, 0.76093495],
... [ 0.036541 , 0.91714289, 0.2246452 , 0.40785078]])
>>> x, y = np.where(d > 0.9)
>>> d[x[0],y[0]]
0.95137658000000003
>>> d[x[1],y[1]]
0.91714289000000004
注意,这适用于numpy
索引的工作方式:
>>> d[x,y]
array([ 0.95137658, 0.91714289])
注意,这适用于任何维度:
>>> d.reshape(3,2,2)
array([[[ 0.11694612, 0.95137658],
[ 0.70099781, 0.06730629]],
[[ 0.59989836, 0.52586768],
[ 0.45387929, 0.76093495]],
[[ 0.036541 , 0.91714289],
[ 0.2246452 , 0.40785078]]])
>>> d = d.reshape(3,2,2)
>>> x, y, z = np.where(d > 0.9)
>>> x
array([0, 2])
>>> y
array([0, 0])
>>> z
array([1, 1])
>>> d[x,y,z]
array([ 0.95137658, 0.91714289])