获取numpy数组的索引,其中值与自身的偏移量为

时间:2019-09-20 23:40:09

标签: numpy

如果我有两个数组:

a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
b = np.array([0, 1, 3, 4, 7, 8, 9])
n = 3

在不使用循环的情况下,如何找到a的索引,该索引的值等于b的值,且偏移了某个数字n

我认为类似的方法会起作用,但是我收到elementwise == comparison failed警告,结果是一个空数组:

np.where(a == b + n)

这是一个for循环,可以完成我想做的事情:

for val in b:
    print(np.where(a == val + n))

它输出:

(array([3]),)
(array([4]),)
(array([6]),)
(array([7]),)
(array([], dtype=int64),)
(array([], dtype=int64),)
(array([], dtype=int64),)

1 个答案:

答案 0 :(得分:1)

这应该可以解决问题:

>>> a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> b = np.array([0, 1, 3, 4, 7, 8, 9])
>>> n = 3
>>> np.where((a == b[:, np.newaxis] + n).any(axis=0))
(array([3, 4, 6, 7]),)

这个想法是要在a上广播b,因为您想检查b的每个可能值。归根结底,这是一个O(len(a) * len(b))操作,因此您需要创建一个大小为len(a) * len(b)的2D数组以进行矢量化。

通过在b中插入另一个轴来实现广播,使其成为列向量:

>>> b
array([0, 1, 3, 4, 7, 8, 9])
>>> b[:, np.newaxis]
array([[0],
       [1],
       [3],
       [4],
       [7],
       [8],
       [9]])

因此,现在比较将返回2D数组:

>>> a == b[:, np.newaxis] + n
array([[False, False, False,  True, False, False, False, False, False, False],
       [False, False, False, False,  True, False, False, False, False, False],
       [False, False, False, False, False, False,  True, False, False, False],
       [False, False, False, False, False, False, False,  True, False, False],
       [False, False, False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False, False, False]])

存在i, j值的索引True表示b[i] == a[j]。由于我们只关心与a中的任何值匹配的b的索引,因此我们只沿列查找 any 真值:

>>> (a == b[:, np.newaxis] + n).any(axis=0)
array([False, False, False,  True,  True, False,  True,  True, False, False])

最后,所需要做的就是在此处获取与True值相对应的索引:

>>> np.where((a == b[:, np.newaxis] + n).any(axis=0))
(array([3, 4, 6, 7]),)

您的示例输入已对数组进行排序,这意味着您实际上可以在O(len(a) + len(b))中进行线性运算,但通常无法对其进行向量化。基本上,您可以在每个数组中都有一个索引,在b中一次递增一个索引,并在a中不断递增,直到达到下一个值。当然,对于numpy数组,这通常要比我上面针对“典型”(即非大规模)数组的解决方案要低。