我想找到一个numpy数组的行,它们是一个集合的成员。例如:
wanted=set([(1,2),(8,9)])
z=np.array([[1,2],[8,8],[2,3]])
结果应为[1,2]。
我可以使用列表理解:
[b for b in z if tuple(b) in wanted]
但是当z有许多行和列时,这很慢。有更快的方法吗?
谢谢!
答案 0 :(得分:2)
一种矢量化方法是 -
将wanted
转换为包含map()
和np.vstack
的Numpy数组。
使用None/np.newaxis
扩展Numpy数组版本wanted
的尺寸以形成3D数组,并与引入broadcasting
的z
进行比较。
检查所有True行和ANY True第一轴匹配,为我们提供一个掩码,可用于索引z
进行最终选择。
实施 -
wanted_arr = np.vstack((map(np.array,wanted)))
out = z[((wanted_arr[:,None] == z).all(2)).any(0)]
示例运行 -
In [64]: z
Out[64]:
array([[1, 2],
[8, 8],
[2, 3]])
In [65]: wanted
Out[65]: {(1, 2), (8, 9)}
In [66]: wanted_arr = np.vstack((map(np.array,wanted)))
In [67]: wanted_arr
Out[67]:
array([[1, 2],
[8, 9]])
In [68]: z[((wanted_arr[:,None] == z).all(2)).any(0)]
Out[68]: array([[1, 2]])