测试Numpy数组是否包含给定行

时间:2013-02-08 05:20:35

标签: python numpy

是否有一种Pythonic且有效的方法来检查Numpy数组是否至少包含给定行的一个实例?通过"高效"我的意思是它终止于找到第一个匹配的行而不是迭代整个数组,即使已经找到了结果。

使用Python数组,这可以用if row in array:非常干净地完成,但是这不像我对Numpy数组所期望的那样工作,如下所示。

使用Python数组:

>>> a = [[1,2],[10,20],[100,200]]
>>> [1,2] in a
True
>>> [1,20] in a
False

但是Numpy数组给出了不同的,而且看起来很奇怪的结果。 (__contains__的{​​{1}}方法似乎没有记录。)

ndarray

5 个答案:

答案 0 :(得分:36)

您可以使用.tolist()

>>> a = np.array([[1,2],[10,20],[100,200]])
>>> [1,2] in a.tolist()
True
>>> [1,20] in a.tolist()
False
>>> [1,20] in a.tolist()
False
>>> [1,42] in a.tolist()
False
>>> [42,1] in a.tolist()
False

或使用视图:

>>> any((a[:]==[1,2]).all(1))
True
>>> any((a[:]==[1,20]).all(1))
False

或者生成numpy列表(可能非常慢):

any(([1,2] == x).all() for x in a)     # stops on first occurrence 

或使用numpy逻辑函数:

any(np.equal(a,[1,2]).all(1))

如果你计算时间:

import numpy as np
import time

n=300000
a=np.arange(n*3).reshape(n,3)
b=a.tolist()

t1,t2,t3=a[n//100][0],a[n//2][0],a[-10][0]

tests=[ ('early hit',[t1, t1+1, t1+2]),
        ('middle hit',[t2,t2+1,t2+2]),
        ('late hit', [t3,t3+1,t3+2]),
        ('miss',[0,2,0])]

fmt='\t{:20}{:.5f} seconds and is {}'     

for test, tgt in tests:
    print('\n{}: {} in {:,} elements:'.format(test,tgt,n))

    name='view'
    t1=time.time()
    result=(a[...]==tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='python list'
    t1=time.time()
    result = True if tgt in b else False
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='gen over numpy'
    t1=time.time()
    result=any((tgt == x).all() for x in a)
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='logic equal'
    t1=time.time()
    np.equal(a,tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

你可以看到命中或未命中,numpy例程与搜索数组的速度相同。对于早期命中,Python in运算符可能的速度要快得多,如果必须一直通过数组,则生成器只是坏消息。

以下是300,000 x 3元素数组的结果:

early hit: [9000, 9001, 9002] in 300,000 elements:
    view                0.01002 seconds and is True
    python list         0.00305 seconds and is True
    gen over numpy      0.06470 seconds and is True
    logic equal         0.00909 seconds and is True

middle hit: [450000, 450001, 450002] in 300,000 elements:
    view                0.00915 seconds and is True
    python list         0.15458 seconds and is True
    gen over numpy      3.24386 seconds and is True
    logic equal         0.00937 seconds and is True

late hit: [899970, 899971, 899972] in 300,000 elements:
    view                0.00936 seconds and is True
    python list         0.30604 seconds and is True
    gen over numpy      6.47660 seconds and is True
    logic equal         0.00965 seconds and is True

miss: [0, 2, 0] in 300,000 elements:
    view                0.00936 seconds and is False
    python list         0.01287 seconds and is False
    gen over numpy      6.49190 seconds and is False
    logic equal         0.00965 seconds and is False

对于3,000,000 x 3阵列:

early hit: [90000, 90001, 90002] in 3,000,000 elements:
    view                0.10128 seconds and is True
    python list         0.02982 seconds and is True
    gen over numpy      0.66057 seconds and is True
    logic equal         0.09128 seconds and is True

middle hit: [4500000, 4500001, 4500002] in 3,000,000 elements:
    view                0.09331 seconds and is True
    python list         1.48180 seconds and is True
    gen over numpy      32.69874 seconds and is True
    logic equal         0.09438 seconds and is True

late hit: [8999970, 8999971, 8999972] in 3,000,000 elements:
    view                0.09868 seconds and is True
    python list         3.01236 seconds and is True
    gen over numpy      65.15087 seconds and is True
    logic equal         0.09591 seconds and is True

miss: [0, 2, 0] in 3,000,000 elements:
    view                0.09588 seconds and is False
    python list         0.12904 seconds and is False
    gen over numpy      64.46789 seconds and is False
    logic equal         0.09671 seconds and is False

这似乎表明np.equal是执行此操作的最快的纯粹方式......

答案 1 :(得分:18)

Numpys __contains__ is, at the time of writing this, (a == b).any()如果b是一个标量(它有点毛茸茸,但我相信 - 只在1.7或更高版本中这样做 - 这可能是唯一正确的 - 这个将是正确的通用方法(a == b).all(np.arange(a.ndim - b.ndim, a.ndim)).any(),这对于ab维度的所有组合都有意义... ...

编辑:为了清楚起见,当涉及广播时,这 不一定是预期的结果。也有人可能会争辩说它应该a分别处理np.in1d中的项目。我不确定它应该有一个明确的方法。

现在你想要numpy在找到第一个匹配项时停止。此AFAIK目前不存在。这很难,因为numpy主要基于ufuncs,它在整个数组中做同样的事情。 Numpy确实优化了这种减少,但实际上只有在减少的数组已经是一个布尔数组(即np.ones(10, dtype=bool).any())时才有效。

否则它将需要__contains__的特殊功能,该功能不存在。这可能看起来很奇怪,但你必须记住numpy支持许多数据类型,并且有更大的机制来选择正确的数据并选择正确的函数来处理它。换句话说,ufunc机制无法做到这一点,并且由于数据类型的原因,实现__contains__或其他特殊实际上并不是那么微不足道。

你当然可以在python中编写它,或者因为你可能知道你的数据类型,所以在Cython / C中自己编写它非常简单。


那就是说。对于这些事情,使用基于排序的方法通常会好得多。这有点单调乏味,searchsorted没有lexsort这样的事情,但它有效(如果你愿意,你也可以滥用scipy.spatial.cKDTree)。这假设您只想沿最后一个轴进行比较:

# Unfortunatly you need to use structured arrays:
sorted = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()

# Actually at this point, you can also use np.in1d, if you already have many b
# then that is even better.

sorted.sort()

b_comp = np.ascontiguousarray(b).view(sorted.dtype)
ind = sorted.searchsorted(b_comp)

result = sorted[ind] == b_comp

这也适用于数组b,如果你保持排序的数组,如果你同时在b中为单个值(行)执行它也会好得多,当时a保持不变(否则我只会np.in1d将其视为重新排列)。 重要提示:为安全起见,您必须执行np.ascontiguousarray。它通常什么都不做,但如果确实如此,那将是一个很大的潜在错误。

答案 2 :(得分:8)

我认为

equal([1,2], a).all(axis=1)   # also,  ([1,2]==a).all(axis=1)
# array([ True, False, False], dtype=bool)

将列出匹配的行。正如Jamie所指出的,要知道是否存在至少一个这样的行,请使用any

equal([1,2], a).all(axis=1).any()
# True

除此之外:
我怀疑in(和__contains__)与上面一样,但使用的是any而不是all

答案 3 :(得分:1)

如果你真的想在第一次出现时停下来,你可以写一个循环,比如:

import numpy as np

needle = np.array([10, 20])
haystack = np.array([[1,2],[10,20],[100,200]])
found = False
for row in haystack:
    if np.all(row == needle):
        found = True
        break
print("Found: ", found)

但是,我强烈怀疑,它会比使用numpy例程为整个数组执行它的其他建议慢得多。

答案 4 :(得分:1)

我将建议的解决方案与 perfplot 进行了比较,发现如果您要在长长的未排序列表中寻找 2 元组,

np.any(np.all(a == b, axis=1))

是最快的解决方案。如果在前几行中找到匹配项,则显式短路循环总是更快。

enter image description here

重现情节的代码:

import numpy as np
import perfplot


target = [6, 23]


def setup(n):
    return np.random.randint(0, 100, (n, 2))


def any_all(data):
    return np.any(np.all(target == data, axis=1))


def tolist(data):
    return target in data.tolist()

def loop(data):
    for row in data:
        if np.all(row == target):
            return True
    return False


def searchsorted(a):
    s = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()
    s.sort()
    t = np.ascontiguousarray(target).view(s.dtype)
    ind = s.searchsorted(t)
    return (s[ind] == t)[0]


perfplot.save(
    "out02.png",
    setup=setup,
    kernels=[any_all, tolist, loop, searchsorted],
    n_range=[2 ** k for k in range(2, 20)],
    xlabel="len(array)",
)
相关问题