从numpy数组中的最大值进行掩码,特定轴

时间:2017-12-06 15:47:03

标签: python numpy

输入示例:

我有一个numpy数组,例如

a=np.array([[0,1], [2, 1], [4, 8]])

期望的输出:

我想生成一个掩码数组,其沿给定轴具有最大值,在我的情况下为轴1,为True,其他所有为False。例如在这种情况下

mask = np.array([[False, True], [True, False], [False, True]])

尝试:

我尝试过使用np.amax的方法,但这会在展平列表中返回最大值:

>>> np.amax(a, axis=1)
array([1, 2, 8])

np.argmax类似地返回沿该轴的最大值的索引。

>>> np.argmax(a, axis=1)
array([1, 0, 1])

我可以通过某种方式迭代这个但是一旦这些数组变得更大,我希望解决方案保持在numpy本地的东西。

5 个答案:

答案 0 :(得分:7)

方法#1

使用broadcasting,我们可以使用与最大值的比较,同时保持dims以方便broadcasting -

a.max(axis=1,keepdims=1) == a

示例运行 -

In [83]: a
Out[83]: 
array([[0, 1],
       [2, 1],
       [4, 8]])

In [84]: a.max(axis=1,keepdims=1) == a
Out[84]: 
array([[False,  True],
       [ True, False],
       [False,  True]], dtype=bool)

方法#2

另外,对于列中的索引范围,还有argmaxbroadcasted-comparison个案例的索引 -

In [92]: a.argmax(axis=1)[:,None] == range(a.shape[1])
Out[92]: 
array([[False,  True],
       [ True, False],
       [False,  True]], dtype=bool)

方法#3

要完成设置,如果我们正在寻找性能,请使用初始化然后advanced-indexing -

out = np.zeros(a.shape, dtype=bool)
out[np.arange(len(a)), a.argmax(axis=1)] = 1

答案 1 :(得分:3)

创建一个单位矩阵,并使用数组中的argmax从行中选择:

include_tasks

请注意,这会忽略关系,它会与markResolved() { let reportingFailure = this.get('reportingFailure'); reportingFailure.resolvedAt = '2017-01-01'; reportingFailure.save(); } 返回的值一致。

答案 2 :(得分:2)

你已经回答了一半。沿轴计算 max 之后,您可以将其与输入数组进行比较,并且您将拥有所需的二进制掩码!

In [7]: maxx = np.amax(a, axis=1)

In [8]: maxx
Out[8]: array([1, 2, 8])

In [12]: a >= maxx[:, None]
Out[12]: 
array([[False,  True],
       [ True, False],
       [False,  True]], dtype=bool)

注意:在amaxx之间进行比较时使用NumPy broadcasting

答案 3 :(得分:0)

在线:np.equal(a.max(1)[:,None],a)np.equal(a.max(1),a.T).T

但这可能导致连续几个。

答案 4 :(得分:0)

在多维情况下,您也可以使用np.indices。假设您有一个数组:

a = np.array([[
    [0, 1, 2],
    [3, 8, 5],
    [6, 7, -1],
    [9, 5, 8]],[
    [5, 2, 8],
    [7, 6, -3],
    [-1, 2, 1],
    [3, 5, 6]]
])

您可以像这样访问为轴0计算的argmax值:

k = np.zeros((2, 4, 3), np.bool)
k[a.argmax(0), ind[0], ind[1]] = 1

输出为:

array([[[False, False, False],
        [False,  True,  True],
        [ True,  True, False],
        [ True,  True,  True]],

       [[ True,  True,  True],
        [ True, False, False],
        [False, False,  True],
        [False, False, False]]])