是集合中数组的元素吗?

时间:2016-06-12 13:48:45

标签: python numpy

import numpy
data = numpy.random.randint(0, 10, (6,8))
test = set(numpy.random.randint(0, 10, 5))

我想要一个表达式,其值是一个布尔数组,具有相同的data形状(或者,至少可以重新整形为相同的形状),它告诉我{{1}中是否有相应的术语在} data

例如,如果我想知道set的哪些元素严格小于data,我可以使用单个矢量化表达式,

6

计算a = data < 6 布尔ndarray。相反,当我尝试一个明显等效的布尔表达式

6x8

我得到的是一个例外:

b = data in test

附录 - 使用不同的解决方案

编辑:由于hpaulj,下面的可能性#4给出了错误的结果 和Divakar让我走上正轨。

在这里,我比较了四种不同的可能性,

  1. Divakar提出的建议,TypeError: unhashable type: 'numpy.ndarray'
  2. hpaulj的一项提案,np.in1d(data, np.hstack(test))
  3. hpaulj的另一个提议,`np.in1d(data,np.fromiter(test,int))。
  4. 在其作者删除的答案中提出的建议,其名称我不记得,np.in1d(data, np.array(list(test)))
  5. 这是Ipython会话,略微编辑以避免空行

    np.in1d(data, test)

    <击> In [1]: import numpy as np In [2]: nr, nc = 100, 100 In [3]: top = 3000 In [4]: data = np.random.randint(0, top, (nr, nc)) In [5]: test = set(np.random.randint(0, top, top//3)) In [6]: %timeit np.in1d(data, np.hstack(test)) 100 loops, best of 3: 5.65 ms per loop In [7]: %timeit np.in1d(data, np.array(list(test))) 1000 loops, best of 3: 1.4 ms per loop In [8]: %timeit np.in1d(data, np.fromiter(test, int)) 1000 loops, best of 3: 1.33 ms per loop
    In [9]: %timeit np.in1d(data, test)

    1000 loops, best of 3: 687 µs per loop

    <击> In [10]: nr, nc = 1000, 1000 In [11]: top = 300000 In [12]: data = np.random.randint(0, top, (nr, nc)) In [13]: test = set(np.random.randint(0, top, top//3)) In [14]: %timeit np.in1d(data, np.hstack(test)) 1 loop, best of 3: 706 ms per loop In [15]: %timeit np.in1d(data, np.array(list(test))) 1 loop, best of 3: 269 ms per loop In [16]: %timeit np.in1d(data, np.fromiter(test, int)) 1 loop, best of 3: 274 ms per loop
    In [17]: %timeit np.in1d(data, test)

    10 loops, best of 3: 67.9 ms per loop

    (现在)匿名海报的回答给出了更好的时间。

    事实证明,匿名海报有充分理由删除他们的答案,结果是错误的!

    正如hpaulj所评论的那样,在In [18]: 的文档中,有一个警告反对使用in1d作为第二个参数,但如果计算结果可以更好,我会更明确地失败是错的。

    也就是说,使用set的解决方案具有最佳数字......

2 个答案:

答案 0 :(得分:5)

我假设您正在寻找一个布尔数组来检测set数组中是否存在data元素。为此,您可以使用np.hstackset中提取元素,然后使用np.in1d检测来自set任何元素的存在。 strong data中的每个位置,给我们一个与data大小相同的布尔数组。由于np.in1d在处理之前使输入变平,因此作为最后一步,我们需要将np.in1d的输出重新整形为原始的2D形状。因此,最终的实施将是 -

np.in1d(data,np.hstack(test)).reshape(data.shape)

示例运行 -

In [125]: data
Out[125]: 
array([[7, 0, 1, 8, 9, 5, 9, 1],
       [9, 7, 1, 4, 4, 2, 4, 4],
       [0, 4, 9, 6, 6, 3, 5, 9],
       [2, 2, 7, 7, 6, 7, 7, 2],
       [3, 4, 8, 4, 2, 1, 9, 8],
       [9, 0, 8, 1, 6, 1, 3, 5]])

In [126]: test
Out[126]: {3, 4, 6, 7, 9}

In [127]: np.in1d(data,np.hstack(test)).reshape(data.shape)
Out[127]: 
array([[ True, False, False, False,  True, False,  True, False],
       [ True,  True, False,  True,  True, False,  True,  True],
       [False,  True,  True,  True,  True,  True, False,  True],
       [False, False,  True,  True,  True,  True,  True, False],
       [ True,  True, False,  True, False, False,  True, False],
       [ True, False, False, False,  True, False,  True, False]], dtype=bool)

答案 1 :(得分:3)

表达式a = data < 6返回一个新数组,因为<是一个值比较运算符。

  

Arithmetic, matrix multiplication, and comparison operations

     

ndarrays上的算术和比较操作定义为   元素操作,通常产生ndarray对象   结果

     

每个算术运算(+, - ,*,/,//,%,divmod(),**或   pow(),&lt;&lt;,&gt;&gt;,&amp;,^,|,〜)和比较(==,&lt;,&gt;,&lt; =,&gt; =,!=)   相当于相应的通用函数(或ufunc for   在Numpy。

请注意,in运算符不在此列表中。可能是因为它与大多数操作员的工作方向相反。

虽然a + ba.__add__(b)相同,但a in b从右到左b.__contains__(a)。在这种情况下,python尝试调用set.__contains__(),它只接受hashable / immutable类型。数组是可变的,因此它们不能成为集合的成员。

解决方法是直接使用 numpy.vectorize 而不是in,并在数组中的每个元素上调用任何python函数。

对于numpy数组来说,它是一种map()

  

numpy.vectorize

     

定义一个矢量化函数,该函数采用嵌套的对象序列   或numpy数组作为输入并返回一个numpy数组作为输出。该   向量化函数在连续的元组上计算pyfunc   输入数组,如python map函数,除了它使用   广播规则numpy。

>>> import numpy
>>> data = numpy.random.randint(0, 10, (3, 3))
>>> test = set(numpy.random.randint(0, 10, 5))
>>> numpy.vectorize(test.__contains__)(data)

array([[False, False,  True],
       [ True,  True, False],
       [ True, False,  True]], dtype=bool)

基准

当n很大时,这种方法很快,因为set.__contains__()是一个恒定时间操作。 (&#34;大&#34;表示top&gt; 13000左右)

>>> import numpy as np
>>> nr, nc = 100, 100
>>> top = 300000
>>> data = np.random.randint(0, top, (nr, nc))
>>> test = set(np.random.randint(0, top, top//3))
>>> %timeit -n10 np.in1d(data, list(test)).reshape(data.shape)
10 loops, best of 3: 26.2 ms per loop

>>> %timeit -n10 np.in1d(data, np.hstack(test)).reshape(data.shape)
10 loops, best of 3: 374 ms per loop

>>> %timeit -n10 np.vectorize(test.__contains__)(data)
10 loops, best of 3: 3.16 ms per loop

然而,当n很小时,其他解决方案明显更快。