基本的NumPy数据比较

时间:2012-05-02 16:56:23

标签: python numpy

我有一个排列在2D数组中的N维值数组。类似的东西:

import numpy as np
data = np.array([[[1,2],[3,4]],[[5,6],[1,2]]])

我还有一个值x我希望与每个数据点进行比较,我想得到一个布尔值的二维数组,显示我的数据是否等于x

x = np.array([1,2])

如果我这样做:

data == x

我得到了

# array([[[ True,  True],
#        [False, False]],
#
#       [[False, False],
#        [ True,  True]]], dtype=bool)

我可以轻松地将这些结合起来以获得我想要的结果。但是,我不想迭代这些切片中的每一个,特别是当data.shape[2]更大时。我正在寻找的是直接的获取方式:

array([[ True,  False],
        [False, True]])

对于这项看似简单的任务的任何想法?

1 个答案:

答案 0 :(得分:2)

好吧,(data == x).all(axis=-1)可以为您提供所需内容。它仍在构建一个三维结果数组并对其进行迭代,但至少该迭代不是在Python级别,因此它应该相当快。