如何检查numpy数组列表是否包含给定的测试数组?

时间:2018-08-01 10:39:14

标签: python numpy

我有一个2018-06-07T12:22:00+0200 2018-06-07T12:53:00+0200 2018-06-07T13:22:00+0200 数组的列表,例如,

numpy

我有一个测试数组,说

a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)]

我想检查b = np.random.rand(3, 3) 是否包含a。但是

b

引发以下错误:

  

ValueError:具有多个元素的数组的真值不明确。使用a.any()或a.all()

我想要什么的正确方法是什么?

6 个答案:

答案 0 :(得分:2)

您可以在(3, 3, 3)中制作一个形状为a的阵列:

a = np.asarray(a)

然后将其与b(在这里比较浮点数,因此我们应该使用isclose()

np.all(np.isclose(a, b), axis=(1, 2))

例如:

a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)]
a = np.asarray(a)
b = a[1, ...]       # set b to some value we know will yield True

np.all(np.isclose(a, b), axis=(1, 2))
# array([False,  True, False])

答案 1 :(得分:0)

好吧,in无效,因为它确实有效

def in_(obj, iterable):
    for elem in iterable:
        if obj == elem:
            return True
    return False

现在,问题在于对于两个ndarray aba == b是一个数组(请尝试),而不是布尔值,因此if a == b失败。解决方案是定义一个新功能

def array_in(arr, list_of_arr):
     for elem in list_of_arr:
        if (arr == elem).all():
            return True
     return False

a = [np.arange(5)] * 3
b = np.ones(5)

array_in(b, a) # --> False

答案 2 :(得分:0)

此错误是因为如果abnumpy arrays,则a == b不会返回TrueFalse,而是{{在arrayboolean元素进行比较之后,得到a个值中的1}}。

您可以尝试以下操作:

b
  • np.any([np.all(a_s == b) for a_s in a]) 在这里,您将创建[np.all(a_s == b) for a_s in a]个值的列表,迭代boolean的元素,并检查a中的所有元素以及{{ 1}}一样。

  • 使用b,您可以检查数组中的任何元素是否为a

答案 3 :(得分:0)

this answer中指出,documentation指出:

  

对于诸如list,tuple,set,frozenset,dict或collections.deque之类的容器类型,y中的表达式x等于任意(x为e或x == e,y中的e)。

不过,

a[0]==b是一个数组,包含a[0]b的逐元素比较。该数组的整体真值显然是不明确的。如果所有元素都匹配,或者如果至少一个匹配,则大多数匹配,它们是否相同?因此,numpy迫使您明确表达自己的意思。您想知道的是测试所有元素是否相同。您可以使用numpy的{​​{3}}方法来做到这一点:

any((b is e) or (b == e).all() for e in a)

或放入一个函数

def numpy_in(arrayToTest, listOfArrays):
    return any((arrayToTest is e) or (arrayToTest == e).all()
               for e in listOfArrays)

答案 4 :(得分:0)

使用numpy中的array_equal

    import numpy as np
    a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)]
    b = np.random.rand(3,3)

    for i in a:
        if np.array_equal(b,i):
            print("yes")

答案 5 :(得分:0)

如@jotasi所强调的,由于数组中元素之间的比较,真值是不明确的。 该问题here以前有一个答案。总体而言,您的任务可以通过多种方式完成:

  1. 列表到列表:

您可以通过将列表转换为(3,3,3)形状的数组来使用“ in”运算符,如下所示:

    >>> a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)]
    >>> a= np.asarray(a)
    >>> b= a[1].copy()
    >>> b in a
    True
  1. np.all:

    >>> any(np.all((b==a),axis=(1,2)))
    True
    
  2. 列表理解: 这是通过遍历每个数组来完成的:

    >>> any([(b == a_s).all() for a_s in a])
    True
    

下面是上面三种方法的速度比较:

Speed Comparison

import numpy as np
import perfplot

perfplot.show(
    setup=lambda n: np.asarray([np.random.rand(3*3).reshape(3,3) for i in range(n)]),
    kernels=[
        lambda a: a[-1] in a,
        lambda a: any(np.all((a[-1]==a),axis=(1,2))),
        lambda a: any([(a[-1] == a_s).all() for a_s in a])
        ],
    labels=[
        'in', 'np.all', 'list_comperhension'
        ],
    n_range=[2**k for k in range(1,20)],
    xlabel='Array size',
    logx=True,
    logy=True,
    )