numpy.where是如何工作的?

时间:2017-02-01 06:25:42

标签: python numpy where

我可以理解遵循numpy行为。

>>> a
array([[ 0. ,  0. ,  0. ],
       [ 0. ,  0.7,  0. ],
       [ 0. ,  0.3,  0.5],
       [ 0.6,  0. ,  0.8],
       [ 0.7,  0. ,  0. ]])
>>> argmax_overlaps = a.argmax(axis=1)
>>> argmax_overlaps
array([0, 1, 2, 2, 0])
>>> max_overlaps = a[np.arange(5),argmax_overlaps]
>>> max_overlaps
array([ 0. ,  0.7,  0.5,  0.8,  0.7])
>>> gt_argmax_overlaps = a.argmax(axis=0)
>>> gt_argmax_overlaps
array([4, 1, 3])
>>> gt_max_overlaps = a[gt_argmax_overlaps,np.arange(a.shape[1])]
>>> gt_max_overlaps
array([ 0.7,  0.7,  0.8])
>>> gt_argmax_overlaps = np.where(a == gt_max_overlaps)
>>> gt_argmax_overlaps
(array([1, 3, 4]), array([1, 2, 0]))

我理解0.7,0.7和0.8是[1,1],[3,2]和[4,0]所以我得到了元组(array[1,3,4] and array[1,2,0]),每个数组由0和1组成这三个要素的指数。然后我尝试了其他例子,看看我的理解是否正确。

>>> np.where(a == [0.3])
(array([2]), array([1]))

0.3在[2,1]中,所以结果看起来像我预期的那样。然后我试了

>>> np.where(a == [0.3, 0.5])
(array([], dtype=int64),)

??我期望看到(array([2,2]),array([2,3]))。为什么我看到上面的输出?

>>> np.where(a == [0.7, 0.7, 0.8])
(array([1, 3, 4]), array([1, 2, 0]))
>>> np.where(a == [0.8,0.7,0.7])
(array([1]), array([1]))

我也无法理解第二个结果。有人可以向我解释一下吗?感谢。

1 个答案:

答案 0 :(得分:1)

首先要意识到的是np.where(a == [whatever])只是向您显示a == [whatever]为True的索引。因此,您可以通过查看a == [whatever]的值来获得提示。在你的情况下,"工作":

>>> a == [0.7, 0.7, 0.8]
array([[False, False, False],
       [False,  True, False],
       [False, False, False],
       [False, False,  True],
       [ True, False, False]], dtype=bool)

你没有得到你的想法。您认为这是分别要求每个元素的索引,而是它获取值在行中相同位置的位置。基本上这个比较做的是说"对于每一行,告诉我第一个元素是否为0.7,第二个元素是否为0.7,以及第三个元素是否为0.8"。然后它返回那些匹配位置的索引。换句话说,比较是在整行之间进行的,而不仅仅是单个值。最后一个例子:

>>> a == [0.8,0.7,0.7]
array([[False, False, False],
       [False,  True, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)

您现在获得了不同的结果。它没有要求" a的值为0.8"的指数,它只要求在开头有0.8 的指数行 - 在后两个位置中的任何一个都是0.7。

只有当您比较的值与a的单行形状匹配时,才能执行此类行式比较。因此,当您使用两个元素列表进行尝试时,它会返回一个空集,因为它会尝试将列表作为标量值与数组中的各个值进行比较。

结果是,您无法在值列表中使用==,并希望它只是告诉您任何值发生的位置。等式将匹配值和位置(如果您比较的值与数组的行的形状相同),或者它将尝试将整个列表作为标量进行比较(如果形状不匹配)。如果你想独立搜索这些值,你需要做一些类似于Khris在评论中建议的内容:

np.where((a==0.3)|(a==0.5))

也就是说,您需要对单独的值进行两次(或更多次)单独比较,而不是对值列表进行单次比较。