沿3D NumPy数组中的轴将最大值设置为1,其余设置为零

时间:2018-08-07 17:44:14

标签: python numpy

我有一个3D阵列:

volts = np.random.random((3,3,3)).round(decimals=5)
>>> volts
array([[[0.94785, 0.43955, 0.74527],
    [0.82098, 0.52509, 0.67954],
    [0.72355, 0.16252, 0.03184]],

   [[0.25782, 0.04191, 0.6689 ],
    [0.18215, 0.63108, 0.52052],
    [0.81992, 0.36301, 0.66629]],

   [[0.90585, 0.27223, 0.78807],
    [0.32251, 0.65861, 0.70398],
    [0.21687, 0.20798, 0.33868]]])

5个小数就足够我的应用了。

>>> volts[0,0,:]
array([0.94785, 0.43955, 0.74527])  
>>> volts[0,1,:]
array([0.82098, 0.52509, 0.67954])

在以上两行中,我想将0.947850.82098设置为1,并将所有元素都设置为零。不仅volts[0,0,:] and volts[0,1,:],而且还有所有其他volts[x,y,:]。所以我这样做了:

>>> volts = np.random.random((3,3,3)).round(decimals=5)
>>> volts1 = deepcopy(volts)
>>> vmaxs=volts1.max(axis=2).flatten().tolist()
>>> for items in vmaxs:
       volts1[np.where(volts1==items)]=1


>>> volts1[np.where(volts1!=1)]=0
>>> volts1
array([[[1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.]],

       [[0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.]],

       [[1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.]]])
>>> volts[0,0,:]
array([0.90763, 0.38579, 0.25768])
>>> volts1[0,0,:]
array([1., 0., 0.])
.
.
.
>>> volts[2,2,:]
array([0.33343, 0.73859, 0.43735])
>>> volts1[2,2,:]
array([0., 1., 0.])

您会看到沿轴2的最大值设置为1,其余元素设置为零。在这里,我仅迭代9个元素,但是如果我必须迭代200-300个元素呢?怎么能不那么冗长又有效呢?

2 个答案:

答案 0 :(得分:3)

使用 np.eye 。如果有多个最大值,则将选择第一个。

np.eye(volts.shape[1])[volts.argmax(2)]

array([[[1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.]],

       [[0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.]],

       [[1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.]]])

答案 1 :(得分:2)

方法1 :这是broadcasting-

mask = volts.argmax(axis=-1)[...,None] == np.arange(volts.shape[-1])
out = mask.astype(volts.dtype)

方法2 为提高性能,请使用array-assignment-

out = np.zeros(volts.shape)
idx = volts.argmax(axis=-1)
out[np.arange(volts.shape[0])[:,None],np.arange(volts.shape[1]),idx] = 1

时间-

In [90]: np.random.seed(0)

In [91]: volts = np.random.random((300,300,300))

# @user3483203's soln
In [92]: %timeit np.eye(volts.shape[1])[volts.argmax(2)]
10 loops, best of 3: 56.3 ms per loop

# Appproach #1 from this post
In [93]: %%timeit
    ...: mask = volts.argmax(axis=-1)[...,None] == np.arange(volts.shape[-1])
    ...: out = mask.astype(volts.dtype)
10 loops, best of 3: 90.9 ms per loop

# Appproach #2 from this post
In [94]: %%timeit
    ...: out = np.zeros(volts.shape)
    ...: idx = volts.argmax(axis=-1)
    ...: out[np.arange(volts.shape[0])[:,None],np.arange(volts.shape[1]),idx] = 1
10 loops, best of 3: 41.8 ms per loop