我正在尝试实现一个numpy函数,它将2D数组的每一行中的max替换为1,将所有其他数字替换为零:
>>> a = np.array([[0, 1],
... [2, 3],
... [4, 5],
... [6, 7],
... [9, 8]])
>>> b = some_function(a)
>>> b
[[0. 1.]
[0. 1.]
[0. 1.]
[0. 1.]
[1. 0.]]
到目前为止我尝试了什么
def some_function(x):
a = np.zeros(x.shape)
a[:,np.argmax(x, axis=1)] = 1
return a
>>> b = some_function(a)
>>> b
[[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]]
答案 0 :(得分:16)
方法#1,调整你的:
>>> a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
>>> b = np.zeros_like(a)
>>> b[np.arange(len(a)), a.argmax(1)] = 1
>>> b
array([[0, 1],
[0, 1],
[0, 1],
[0, 1],
[1, 0]])
[实际上,range
会正常工作;我出于习惯写了arange
。]
方法#2,使用max
代替argmax
来处理多个元素达到最大值的情况:
>>> a = np.array([[0, 1], [2, 2], [4, 3]])
>>> (a == a.max(axis=1)[:,None]).astype(int)
array([[0, 1],
[1, 1],
[1, 0]])
答案 1 :(得分:2)
我更喜欢使用numpy.where:
a[np.where(a==np.max(a))] = 1
答案 2 :(得分:1)
a==np.max(a)
将来会出现错误,因此,这是经过调整的版本,可以继续正确播放。
我知道这个问题很古老,但是我认为我有一个不错的解决方案,与其他解决方案有点不同。
# get max by row and convert from (n, ) -> (n, 1) which will broadcast
row_maxes = a.max(axis=1).reshape(-1, 1)
np.where(a == row_maxes, 1, 0)
np.where(a == row_maxes).astype(int)
如果需要进行更新,则可以
a[:] = np.where(a == row_maxes, 1, 0)
答案 3 :(得分:0)
b = (a == np.max(a))
这对我有用