我编写了一个脚本,用于评估arr
中的某些条目是否在check_elements
中。我的方法不比较单个条目,但是arr
内的整个向量。因此,脚本会检查[8, 3]
,[4, 5]
,...是否在check_elements
。
以下是一个例子:
import numpy as np
# arr.shape -> (2, 3, 2)
arr = np.array([[[8, 3],
[4, 5],
[6, 2]],
[[9, 0],
[1, 10],
[7, 11]]])
# check_elements.shape -> (3, 2)
# generally: (n, 2)
check_elements = np.array([[4, 5], [9, 0], [7, 11]])
# rslt.shape -> (2, 3)
rslt = np.zeros((arr.shape[0], arr.shape[1]), dtype=np.bool)
for i, j in np.ndindex((arr.shape[0], arr.shape[1])):
if arr[i, j] in check_elements: # <-- condition is checked against
# the whole last dimension
rslt[i, j] = True
else:
rslt[i, j] = False
现在:
print(rslt)
...会打印:
[[False True False]
[ True False True]]
获取我使用的索引:
print(np.transpose(np.nonzero(rslt)))
...打印以下内容:
[[0 1] # arr[0, 1] -> [4, 5] -> is in check_elements
[1 0] # arr[1, 0] -> [9, 0] -> is in check_elements
[1 2]] # arr[1, 2] -> [7, 11] -> is in check_elements
如果我要检查单个值的条件,例如arr > 3
或np.where(...)
,那么此任务将非常简单且高效,但我不对单个值感兴趣。我想检查整个最后一个维度(或它的切片)的条件。
我的问题是:是否有更快的方法来实现相同的结果?我是否正确,矢量化尝试以及np.where
之类的内容可以不用于我的问题,因为它们始终在单个值上运行,而不是在整个维度或该维度的切片上运行?
答案 0 :(得分:2)
numpy_indexed包(免责声明:我是其作者)包含执行这类查询的功能;特别是,nd(子)数组的包含关系:
import numpy_indexed as npi
flatidx = npi.indices(arr.reshape(-1, 2), check_elements)
idx = np.unravel_index(flatidx, arr.shape[:-1])
请注意,实现是完全矢量化的。
另外,请注意,使用此方法,idx中索引的顺序与check_elements的顺序匹配; idx中的第一项是check_elements中第一项的行和列。当您使用上面发布的方法或使用其中一个替代建议答案时,此信息将丢失,这将使您的idx按照其在arr中的出现顺序排序,这通常是不合需要的。
答案 1 :(得分:2)
这是使用broadcasting的Numpythonic方法:
>>> (check_elements == arr[:,:,None]).reshape(2, 3, 6).any(axis=2)
array([[False, True, False],
[ True, False, True]], dtype=bool)
答案 2 :(得分:1)
你可以使用np.in1d
,即使它是为一维数组提供一个数组的一维视图,每个最后一个轴包含一个元素:
arr_view = arr.view((np.void, arr.dtype.itemsize*arr.shape[-1])).ravel()
check_view = check_elements.view((np.void,
check_elements.dtype.itemsize*check_elements.shape[-1])).ravel()
将为您提供两个1D数组,其中包含沿最后一个轴的2个元素数组的void
类型版本。现在,您可以通过执行以下操作来检查arr
中的哪些元素也位于check_view
中:
flatResult = np.in1d(arr_view, check_view)
这将给出一个展平的数组,然后您可以将其重新整形为arr
的形状,并删除最后一个轴:
print(flatResult.reshape(arr.shape[:-1]))
将为您提供所需的结果:
array([[False, True, False],
[ True, False, True]], dtype=bool)