查找作为集合

时间:2015-12-11 21:57:19

标签: python arrays performance numpy

我想找到一个numpy数组的行,它们是一个集合的成员。例如:

wanted=set([(1,2),(8,9)])

z=np.array([[1,2],[8,8],[2,3]])

结果应为[1,2]。

我可以使用列表理解:

[b for b in z if tuple(b) in wanted]

但是当z有许多行和列时,这很慢。有更快的方法吗?

谢谢!

1 个答案:

答案 0 :(得分:2)

一种矢量化方法是 -

  • wanted转换为包含map()np.vstack的Numpy数组。

  • 使用None/np.newaxis扩展Numpy数组版本wanted的尺寸以形成3D数组,并与引入broadcastingz进行比较。

    < / LI>
  • 检查所有True行和ANY True第一轴匹配,为我们提供一个掩码,可用于索引z进行最终选择。

实施 -

wanted_arr = np.vstack((map(np.array,wanted)))
out = z[((wanted_arr[:,None] == z).all(2)).any(0)]

示例运行 -

In [64]: z
Out[64]: 
array([[1, 2],
       [8, 8],
       [2, 3]])

In [65]: wanted
Out[65]: {(1, 2), (8, 9)}

In [66]: wanted_arr = np.vstack((map(np.array,wanted)))

In [67]: wanted_arr
Out[67]: 
array([[1, 2],
       [8, 9]])

In [68]: z[((wanted_arr[:,None] == z).all(2)).any(0)]
Out[68]: array([[1, 2]])