检查numpy数组是否是另一个数组的子集

时间:2013-05-14 16:43:28

标签: python numpy set

类似的问题已经在SO上提出,但它们有更具体的限制,它们的答案不适用于我的问题。

一般来说,确定任意numpy数组是否是另一个数组的子集的最pythonic方法是什么?更具体地说,我有一个大约20000x3的数组,我需要知道完全包含在一个集合中的1x3元素的索引。更一般地说,有更多的pythonic方式编写以下内容:

master=[12,155,179,234,670,981,1054,1209,1526,1667,1853] #some indices of interest
triangles=np.random.randint(2000,size=(20000,3)) #some data
for i,x in enumerate(triangles):
  if x[0] in master and x[1] in master and x[2] in master:
    print i

对于我的用例,我可以放心地假设len(master)<< 20000.(因此,假设主人被分类是因为这很便宜也是安全的。)

5 个答案:

答案 0 :(得分:3)

您可以通过在列表推导中迭代数组来轻松完成此操作。玩具示例如下:

import numpy as np
x = np.arange(30).reshape(10,3)
searchKey = [4,5,8]
x[[0,3,7],:] = searchKey
x

给出

 array([[ 4,  5,  8],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 4,  5,  8],
        [12, 13, 14],
        [15, 16, 17],
        [18, 19, 20],
        [ 4,  5,  8],
        [24, 25, 26],
        [27, 28, 29]])

现在迭代元素:

ismember = [row==searchKey for row in x.tolist()]

结果是

[True, False, False, True, False, False, False, True, False, False]

您可以将其修改为您问题中的子集:

searchKey = [2,4,10,5,8,9]  # Add more elements for testing
setSearchKey = set(searchKey)
ismember = [setSearchKey.issuperset(row) for row in x.tolist()]

如果您需要索引,请使用

np.where(ismember)[0]

它给出了

array([0, 3, 7])

答案 1 :(得分:2)

您可以尝试以下两种方法:

1,使用套装。集合的实现与python字典非常相似,并且具有恒定的时间查找。这看起来很像你已经拥有的代码,只需从master创建一个集合:

master = [12,155,179,234,670,981,1054,1209,1526,1667,1853]
master_set = set(master)
triangles = np.random.randint(2000,size=(20000,3)) #some data
for i, x in enumerate(triangles):
  if master_set.issuperset(x):
    print i

2,使用搜索排序。这很好,因为它不要求你使用hashable类型并使用numpy内置函数。 searchsorted是master(大小)的log(N)和triangels大小的O(N)所以它也应该非常快,可能更快,具体取决于数组的大小等等。

master = [12,155,179,234,670,981,1054,1209,1526,1667,1853]
master = np.asarray(master)
triangles = np.random.randint(2000,size=(20000,3)) #some data
idx = master.searchsorted(triangles)
idx.clip(max=len(master) - 1, out=idx)
print np.where(np.all(triangles == master[idx], axis=1))

第二种情况假设master按照searchsorted暗示的顺序排序。

答案 2 :(得分:0)

对于numpy中的集合操作更自然(可能更快)的解决方案是使用numpy.lib.arraysetops中的函数。这些通常使您避免在Python的set类型之间来回转换。要检查一个数组是否是另一个数组的子集,请使用numpy.setdiff1d()并测试返回的数组的长度是否为0:

import numpy as np
a = np.arange(10)
b = np.array([1, 5, 9])
c = np.array([-5, 5, 9])
# is `a` a subset of `b`?
len(np.setdiff1d(a, b)) == 0 # gives False
# is `b` a subset of `a`?
len(np.setdiff1d(b, a)) == 0 # gives True
# is `c` a subset of `a`?
len(np.setdiff1d(c, a)) == 0 # gives False

您也可以选择设置assume_unique=True来提高速度。

实际上,numpy并没有像内置issubset()这样的功能(类似于set.issubset()),对此我感到有些惊讶。

另一种选择是使用numpy.in1d()(请参见https://stackoverflow.com/a/37262010/2020363

编辑:我刚刚意识到,在遥远的过去的某个时候,这让我很烦恼,以至于我编写了自己的简单函数:

def issubset(a, b):
    """Return whether sequence `a` is a subset of sequence `b`"""
    return len(np.setdiff1d(a, b)) == 0

答案 3 :(得分:0)

开始于:

master=[12,155,179,234,670,981,1054,1209,1526,1667,1853] #some indices of interest

triangles=np.random.randint(2000,size=(20000,3)) #some data

查找master中包含的三元组索引的最Python方式是什么?尝试将np.in1d与列表理解一起使用:

inds = [j for j in range(len(triangles)) if all(np.in1d(triangles[j], master))]

%timeit说〜0.5 s =半秒

->以更快速的方式(系数为1000!)避免python的慢循环吗?尝试将np.isinnp.sum结合使用,以获取np.arange的布尔掩码:

inds = np.where(
 np.sum(np.isin(triangles, master), axis=-1) == triangles.shape[-1])

%timeit说〜0.0005 s =半毫秒!

建议:避免尽可能多地遍历列表,因为与包含一个算术运算的python循环的单个迭代价格相同,您可以调用一个执行数千个相同算术运算的numpy函数

结论

似乎np.isin(arr1=triangles, arr2=master)是您要查找的函数,它提供了一个与arr1形状相同的布尔掩码,告诉arr1的每个元素是否也是arr2的元素;从这里开始,要求掩码行的总和为3(即,一行在三角形中的全长),即可为所需的三角形行(或使用np.arange的索引)提供一维掩码。

答案 4 :(得分:0)

还可以使用 np.isin,这可能比 @petrichor's answer 中的列表理解更有效。使用相同的设置:

import numpy as np

x = np.arange(30).reshape(10, 3)
searchKey = [4, 5, 8]
x[[0, 3, 7], :] = searchKey
array([[ 4,  5,  8],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 4,  5,  8],
       [12, 13, 14],
       [15, 16, 17],
       [18, 19, 20],
       [ 4,  5,  8],
       [24, 25, 26],
       [27, 28, 29]])

现在可以使用np.isin;默认情况下,它将按元素工作:

np.isin(x, searchKey)
array([[ True,  True,  True],
       [False,  True,  True],
       [False, False,  True],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False],
       [False, False, False],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False]])

我们现在必须过滤所有条目评估为 True 的行,我们可以使用 all

np.isin(x, searchKey).all(1)
array([ True, False, False,  True, False, False, False,  True, False,
       False])

如果现在想要相应的索引,可以使用np.where

np.where(np.isin(x, searchKey).all(1))
(array([0, 3, 7]),)