NumPy:从每一行中找到最大值,将其设置为1并保持为0

时间:2017-02-22 16:36:02

标签: python numpy

我有一个2D numpy数组,

array([[ 0.49596769,  1.15846407, -1.38944733],
       [-0.47042814, -0.07512128 , 1.90417981]], dtype=float32)

我想找到每一行的最大值并将其更改为1并将其保留为0.就像这样。

array([[ 0.,  1.,  0.],
       [ 0.,  0.,  1.]], dtype=float32)

使用numpy完成任务的最有效方法是什么?

2 个答案:

答案 0 :(得分:6)

一种方法是 -

(a == a.max(axis=1, keepdims=1)).astype(float)

示例运行 -

In [43]: a
Out[43]: 
array([[ 0.49596769,  1.15846407, -1.38944733],
       [-0.47042814, -0.07512128,  1.90417981]])

In [44]: (a == a.max(axis=1, keepdims=1)).astype(float)
Out[44]: 
array([[ 0.,  1.,  0.],
       [ 0.,  0.,  1.]])

如果一行中有多个具有相同最大值且您只想将第一个设置为1 -

idx = a.argmax(axis=1)
out = (idx[:,None] == np.arange(a.shape[1])).astype(float)

示例运行 -

In [49]: a
Out[49]: 
array([[2, 4, 4],
       [3, 4, 5]])

In [50]: (a == a.max(axis=1, keepdims=1)).astype(float)
Out[50]: 
array([[ 0.,  1.,  1.],
       [ 0.,  0.,  1.]])

In [51]: idx = a.argmax(axis=1)

In [52]: (idx[:,None] == np.arange(a.shape[1])).astype(float)
Out[52]: 
array([[ 0.,  1.,  0.],
       [ 0.,  0.,  1.]])

为了提高性能,我们可以采用基于初始化的方法 -

def initialization_based(a): 
    idx = a.argmax(axis=1)
    out = np.zeros_like(a,dtype=float)
    out[np.arange(a.shape[0]), idx] = 1
    return out

答案 1 :(得分:0)

或使用argsort:

 (a.argsort()==a.shape[1]-1).astype(int)

它将管理多个最大值。