numpy:根据布尔数组选择元素

时间:2019-04-11 14:52:54

标签: python arrays numpy mask

我有一个数组和一个布尔数组(作为一种热编码)

a = np.arange(12).reshape(4,3)
b = np.array([
    [1,0,0],
    [0,1,0],
    [0,0,1],
    [0,0,1],
], dtype=bool)

print(a)
print(b)
# [[ 0  1  2]
#  [ 3  4  5]
#  [ 6  7  8]
#  [ 9 10 11]]
# [[ True False False]
#  [False  True False]
#  [False False  True]
#  [False False  True]]

我想使用布尔数组选择元素

print(a[:, [True, False, False]])
# array([[0],
#        [3],
#        [6],
#        [9]])

print(a[:, [False, True, False]])
# array([[ 1],
#        [ 4],
#        [ 7],
#        [10]])

但是,此选择基于所有行的相同模板布尔值。我想按行执行此操作:

print(a[:, b])
# IndexError: too many indices for array

我应该在...中输入什么,以便得到:

print(a[:, ...])
# array([[0],
#        [4],
#        [8],
#        [11]])

编辑:这类似于臭名昭著的CS231课程中使用的内容:

dscores = a
num_examples = 4 
# They had 300
y = b
dscores[range(num_examples),y]
# equivalent to
# a{:,b]

编辑2:在CS231示例中,y是一维的,不是一个热编码的!

他们在做dscores[[rowIdx],[columnIdx]]

2 个答案:

答案 0 :(得分:3)

通过b过滤后广播

a[b][:,None]
Out[168]: 
array([[ 0],
       [ 4],
       [ 8],
       [11]])

a[b,None]
Out[174]: 
array([[ 0],
       [ 4],
       [ 8],
       [11]])

答案 1 :(得分:0)

这是执行此操作的另一种方法。请注意,与高级索引相比,这效率很低。它仅用于教学目的,并说明可以使用多种方法解决问题。

In [275]: np.add.reduce(a*b, axis=1, keepdims=True)
Out[275]: 
array([[ 0],
       [ 4],
       [ 8],
       [11]])