Numpy ndarray的会员资格检查

时间:2016-08-26 10:55:21

标签: python numpy indexing

我编写了一个脚本,用于评估arr中的某些条目是否在check_elements中。我的方法比较单个条目,但是arr内的整个向量。因此,脚本会检查[8, 3][4, 5],...是否在check_elements

以下是一个例子:

import numpy as np

# arr.shape -> (2, 3, 2)
arr = np.array([[[8,  3],
                 [4,  5],
                 [6,  2]],

                [[9,  0],
                 [1, 10],
                 [7, 11]]])

# check_elements.shape -> (3, 2)
# generally: (n, 2)
check_elements = np.array([[4, 5], [9, 0], [7, 11]])

# rslt.shape -> (2, 3)
rslt = np.zeros((arr.shape[0], arr.shape[1]), dtype=np.bool)

for i, j in np.ndindex((arr.shape[0], arr.shape[1])):
    if arr[i, j] in check_elements:   # <-- condition is checked against
                                      #     the whole last dimension
        rslt[i, j] = True
    else:
        rslt[i, j] = False

现在:

print(rslt)

...会打印:

[[False  True False]
 [ True False  True]]

获取我使用的索引:

print(np.transpose(np.nonzero(rslt)))

...打印以下内容:

[[0 1]    # arr[0, 1] -> [4, 5] -> is in check_elements
 [1 0]    # arr[1, 0] -> [9, 0] -> is in check_elements
 [1 2]]   # arr[1, 2] -> [7, 11] -> is in check_elements

如果我要检查单个值的条件,例如arr > 3np.where(...),那么此任务将非常简单且高效,但我对单个值感兴趣。我想检查整个最后一个维度(或它的切片)的条件。

我的问题是:是否有更快的方法来实现相同的结果?我是否正确,矢量化尝试以及np.where之类的内容可以用于我的问题,因为它们始终在单个值上运行,而不是在整个维度或该维度的切片上运行?

3 个答案:

答案 0 :(得分:2)

numpy_indexed包(免责声明:我是其作者)包含执行这类查询的功能;特别是,nd(子)数组的包含关系:

import numpy_indexed as npi
flatidx = npi.indices(arr.reshape(-1, 2), check_elements)
idx = np.unravel_index(flatidx, arr.shape[:-1])

请注意,实现是完全矢量化的。

另外,请注意,使用此方法,idx中索引的顺序与check_elements的顺序匹配; idx中的第一项是check_elements中第一项的行和列。当您使用上面发布的方法或使用其中一个替代建议答案时,此信息将丢失,这将使您的idx按照其在arr中的出现顺序排序,这通常是不合需要的。

答案 1 :(得分:2)

这是使用broadcasting的Numpythonic方法:

>>> (check_elements == arr[:,:,None]).reshape(2, 3, 6).any(axis=2)
array([[False,  True, False],
       [ True, False,  True]], dtype=bool)

答案 2 :(得分:1)

你可以使用np.in1d,即使它是为一维数组提供一个数组的一维视图,每个最后一个轴包含一个元素:

arr_view = arr.view((np.void, arr.dtype.itemsize*arr.shape[-1])).ravel()
check_view = check_elements.view((np.void,
        check_elements.dtype.itemsize*check_elements.shape[-1])).ravel()

将为您提供两个1D数组,其中包含沿最后一个轴的2个元素数组的void类型版本。现在,您可以通过执行以下操作来检查arr中的哪些元素也位于check_view中:

flatResult = np.in1d(arr_view, check_view)

这将给出一个展平的数组,然后您可以将其重新整形为arr的形状,并删除最后一个轴:

print(flatResult.reshape(arr.shape[:-1]))

将为您提供所需的结果:

array([[False,  True, False],
       [ True, False,  True]], dtype=bool)