沿numpy数组的给定轴匹配并标记向量元素的位置

时间:2012-06-05 10:06:35

标签: arrays numpy

我想在数组的给定轴上标记所有最大值(哪个形状可能是n维),这在第一个上可以正常工作,但对于其余的我无法弄明白。我不想在轴上进行迭代,因为它们可能存在许多问题。

>>> A = range(5)*3
>>> A = array(a).reshape([3,5], order='F')
>>> A
array([[0, 3, 1, 4, 2],
       [1, 4, 2, 0, 3],
       [2, 0, 3, 1, 4]])
>>> B = amax(A, axis= 0)
>>> C = amax(A, axis= 1)
>>> B == A
array([[False, False, False,  True, False],
       [False,  True, False, False, False],
       [ True, False,  True, False,  True]], dtype=bool)

这就是我想要它做的事情:

>>> C == A
False

但(当然)它没有。

如何使这个工作?

3 个答案:

答案 0 :(得分:1)

在回答您的直接示例时,请执行以下操作:

>>>A == C
False

它“无效”,因为numpy不了解如何广播操作以提供您想要的输出。 使用转置两次,您可以获得比您提议的更简单的解决方案:

>>>C = amax(A, axis=1)
>>>transpose(C == transpose(A))
array([[False, False, False,  True, False],
       [False,  True, False, False, False],
       [False, False, False, False,  True]], dtype=bool)

答案 1 :(得分:1)

我终于想到了:

def tiletuple(t,axis):
  m = [1]*t.ndim
  m[axis] = t.shape[axis]
  return m

ax = 1
tile(expand_dims(amax(A, axis=ax), axis=ax), tiletuple(A, ax)) == A

答案 2 :(得分:1)

晚会,但这个怎么样:

rollaxis(amax(A, ax)[...,newaxis], -1, ax) == A

它基本上是再次插入轴,由于amax而从阵列中掉出来。然后,广播再次起作用。 或者,等效地:

a = list(A.shape)
a[ax] = 1
amax(A, ax).reshape(a) == A