检查NumPy数组是否包含另一个数组

时间:2015-10-19 14:56:33

标签: python numpy

在Python 2.7中使用NumPy,我想创建一个n乘2的数组y。然后,我想检查此数组是否在其任何行中包含特定的1 x 2数组z

这是我到目前为止所尝试的内容,在这种情况下,n = 1:

x = np.array([1, 2]) # Create a 1-by-2 array
y = [x] # Create an n-by-2 array (n = 1), and assign the first row to x
z = np.array([1, 2]) # Create another 1-by-2 array
if z in y: # Check if y contains the row z
    print 'yes it is'

然而,这给了我以下错误:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我做错了什么?

2 个答案:

答案 0 :(得分:6)

您只需使用any((z == x).all() for x in y)即可。不过,我不知道它是否是最快的。

答案 1 :(得分:5)

你可以(y == z).all(1).any()

为了更详细一点,numpy将使用称为“广播”的东西自动地在更高维度上逐个元素进行比较。因此,如果y是您的n-by-2数组,而z是1-by-2数组,则y == z会将y的每一行与{{1}进行比较元素逐个元素。然后,您可以使用z获取所有元素匹配的行,并使用all(axis=1)查看是否有任何匹配。

所以这就是实践:

any()

这比基于循环或基于生成器的方法要快得多,因为它会对操作进行矢量化。