Numpy:用同一行中最多的其他元素替换行中的每个元素

时间:2018-01-23 02:49:25

标签: python numpy

假设我们有一个像这样的二维数组:

>>> a
array([[1, 1, 2],
   [0, 2, 2],
   [2, 2, 0],
   [0, 2, 0]])

对于每一行,我想用同一行中另外两个元素的最大值替换每个元素。

我已经找到了如何使用numpy.amax和一个标识数组分别对每个列执行此操作,如下所示:

>>> np.amax(a*(1-np.eye(3)[0]), axis=1)
array([ 2.,  2.,  2.,  2.])
>>> np.amax(a*(1-np.eye(3)[1]), axis=1)
array([ 2.,  2.,  2.,  0.])
>>> np.amax(a*(1-np.eye(3)[2]), axis=1)
array([ 1.,  2.,  2.,  2.])

但是我想知道是否有办法避免for循环并直接获得结果,在这种情况下应该如下所示:

>>> numpy_magic(a)
array([[2, 2, 1],
   [2, 2, 2],
   [2, 2, 2],
   [2, 0, 2]])

编辑:在控制台玩了几个小时之后,我终于找到了我想要的解决方案。准备好让一些人想起一行代码:

np.amax(a[[range(a.shape[0])]*a.shape[1],:][(np.eye(a.shape[1]) == 0)[:,[range(a.shape[1])*a.shape[0]]].reshape(a.shape[1],a.shape[0],a.shape[1])].reshape((a.shape[1],a.shape[0],a.shape[1]-1)),axis=2).transpose()
array([[2, 2, 1],
   [2, 2, 2],
   [2, 2, 2],
   [2, 0, 2]])

Edit2:Paul提出了一个更具可读性和更快速的替代方案:

np.max(a[:, np.where(~np.identity(a.shape[1], dtype=bool))[1].reshape(a.shape[1], -1)], axis=-1)

在对这3个替代方案进行计时之后,两个Paul的解决方案在每种情况下的速度都快了4倍(我已经基准测试了2,3列和4列200行)。祝贺这些惊人的代码!

最后编辑(对不起):用更快的np.eye替换np.identity后,我们现在拥有最快,最简洁的解决方案:

np.max(a[:, np.where(~np.eye(a.shape[1], dtype=bool))[1].reshape(a.shape[1], -1)], axis=-1)

3 个答案:

答案 0 :(得分:4)

以下是两个解决方案,一个专门针对max设计,另一个针对其他操作也适用。

使用除了每行中可能只有一个最大值之外的所有行都是整行的最大值的事实,我们可以使用argpartition来便宜地找到最大的两个元素的索引。然后在最大的位置我们把第二大的值和其他地方的值放在最大值。也适用于超过3列。

>>> a
array([[6, 0, 8, 8, 0, 4, 4, 5],
       [3, 1, 5, 0, 9, 0, 3, 6],
       [1, 6, 8, 3, 4, 7, 3, 7],
       [2, 1, 6, 2, 9, 1, 8, 9],
       [7, 3, 9, 5, 3, 7, 4, 3],
       [3, 4, 3, 5, 8, 2, 2, 4],
       [4, 1, 7, 9, 2, 5, 9, 6],
       [5, 6, 8, 5, 5, 3, 3, 3]])
>>> 
>>> M, N = a.shape
>>> result = np.empty_like(a)
>>> largest_two = np.argpartition(a, N-2, axis=-1)
>>> rng = np.arange(M)
>>> result[...] = a[rng, largest_two[:, -1], None]
>>> result[rng, largest_two[:, -1]] = a[rng, largest_two[:, -2]]>>> 
>>> result
array([[8, 8, 8, 8, 8, 8, 8, 8],
       [9, 9, 9, 9, 6, 9, 9, 9],
       [8, 8, 7, 8, 8, 8, 8, 8],
       [9, 9, 9, 9, 9, 9, 9, 9],
       [9, 9, 7, 9, 9, 9, 9, 9],
       [8, 8, 8, 8, 5, 8, 8, 8],
       [9, 9, 9, 9, 9, 9, 9, 9],
       [8, 8, 6, 8, 8, 8, 8, 8]])

此解决方案取决于最大的特定属性

更通用的解决方案,例如也适用于sum而不是max。将a的两个副本粘贴在一起(并排,不在彼此之上)。所以行类似于a0 a1 a2 a3 a0 a1 a2 a3。对于索引x,我们可以通过切片ax获得除[x+1:x+4]之外的所有内容。要执行此向量化,我们使用stride_tricks

>>> a
array([[2, 6, 0],
       [5, 0, 0],
       [5, 0, 9],
       [6, 4, 4],
       [5, 0, 8],
       [1, 7, 5],
       [9, 7, 7],
       [4, 4, 3]])
>>> M, N = a.shape
>>> aa = np.c_[a, a]
>>> ast = np.lib.stride_tricks.as_strided(aa, (M, N+1, N-1), aa.strides + aa.strides[1:])
>>> result = np.max(ast[:, 1:, :], axis=-1)
>>> result
array([[6, 2, 6],
       [0, 5, 5],
       [9, 9, 5],
       [4, 6, 6],
       [8, 8, 5],
       [7, 5, 7],
       [7, 9, 9],
       [4, 4, 4]])

# use sum instead of max
>>> result = np.sum(ast[:, 1:, :], axis=-1)
>>> result
array([[ 6,  2,  8],
       [ 0,  5,  5],
       [ 9, 14,  5],
       [ 8, 10, 10],
       [ 8, 13,  5],
       [12,  6,  8],
       [14, 16, 16],
       [ 7,  7,  8]])

答案 1 :(得分:1)

列表理解解决方案。

np.array([np.amax(a * (1 - np.eye(3)[j]), axis=1) for j in range(a.shape[1])]).T

答案 2 :(得分:1)

与@ Ethan的答案类似,但np.delete()np.max()np.dstack()

np.dstack([np.max(np.delete(a, i, 1), axis=1) for i in range(a.shape[1])])

array([[2, 2, 1],
       [2, 2, 2],
       [2, 2, 2],
       [2, 0, 2]])
  • delete()"过滤器"依次列出每一栏;
  • max()找到剩余两列的行方式最大值
  • dstack()堆叠生成的1d数组

如果您有超过3列,请注意这将找到"所有其他"的最大值。列而不是" 2-great"每行的列数。例如:

a2 = np.arange(25).reshape(5,5)
np.dstack([np.max(np.delete(a2, i, 1), axis=1) for i in range(a2.shape[1])])

array([[[ 4,  4,  4,  4,  3],
        [ 9,  9,  9,  9,  8],
        [14, 14, 14, 14, 13],
        [19, 19, 19, 19, 18],
        [24, 24, 24, 24, 23]]])