如果我有两个数组:
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
b = np.array([0, 1, 3, 4, 7, 8, 9])
n = 3
在不使用循环的情况下,如何找到a
的索引,该索引的值等于b
的值,且偏移了某个数字n
?
我认为类似的方法会起作用,但是我收到elementwise == comparison failed
警告,结果是一个空数组:
np.where(a == b + n)
这是一个for循环,可以完成我想做的事情:
for val in b:
print(np.where(a == val + n))
它输出:
(array([3]),)
(array([4]),)
(array([6]),)
(array([7]),)
(array([], dtype=int64),)
(array([], dtype=int64),)
(array([], dtype=int64),)
答案 0 :(得分:1)
这应该可以解决问题:
>>> a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> b = np.array([0, 1, 3, 4, 7, 8, 9])
>>> n = 3
>>> np.where((a == b[:, np.newaxis] + n).any(axis=0))
(array([3, 4, 6, 7]),)
这个想法是要在a
上广播b
,因为您想检查b
的每个可能值。归根结底,这是一个O(len(a) * len(b))
操作,因此您需要创建一个大小为len(a) * len(b)
的2D数组以进行矢量化。
通过在b
中插入另一个轴来实现广播,使其成为列向量:
>>> b
array([0, 1, 3, 4, 7, 8, 9])
>>> b[:, np.newaxis]
array([[0],
[1],
[3],
[4],
[7],
[8],
[9]])
因此,现在比较将返回2D数组:
>>> a == b[:, np.newaxis] + n
array([[False, False, False, True, False, False, False, False, False, False],
[False, False, False, False, True, False, False, False, False, False],
[False, False, False, False, False, False, True, False, False, False],
[False, False, False, False, False, False, False, True, False, False],
[False, False, False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False]])
存在i, j
值的索引True
表示b[i] == a[j]
。由于我们只关心与a
中的任何值匹配的b
的索引,因此我们只沿列查找 any 真值:
>>> (a == b[:, np.newaxis] + n).any(axis=0)
array([False, False, False, True, True, False, True, True, False, False])
最后,所需要做的就是在此处获取与True
值相对应的索引:
>>> np.where((a == b[:, np.newaxis] + n).any(axis=0))
(array([3, 4, 6, 7]),)
您的示例输入已对数组进行排序,这意味着您实际上可以在O(len(a) + len(b))
中进行线性运算,但通常无法对其进行向量化。基本上,您可以在每个数组中都有一个索引,在b
中一次递增一个索引,并在a
中不断递增,直到达到下一个值。当然,对于numpy数组,这通常要比我上面针对“典型”(即非大规模)数组的解决方案要低。