numpy array - argmax的unravel_index v。和的argmax

时间:2012-08-29 13:08:21

标签: python numpy scipy

我试图沿一个轴边缘化一个阵列并检查1-D峰值出现在与原始2-D峰值相同的相关指数处。在什么情况下(x的形式)下面的断言应该失败?

def check(x,axis=None):
    import numpy
    m = numpy.sum(x, axis=axis)
    v,w = numpy.unravel_index(numpy.argmax(x), x.shape)
    assert(v==numpy.argmax(m))
    return

对于x=numpy.array(range(15)).reshape(5,3)check(x,axis=0)会引发错误,但check(x,axis=1)却没有。我不明白为什么AssertionError被提出 - 我是不是很蠢?

2 个答案:

答案 0 :(得分:1)

您正在检查拆开的索引的错误坐标。而不是

v,w=numpy.unravel_index(numpy.argmax(x),x.shape)
assert(v==numpy.argmax(m))
你可能想要

vw = numpy.unravel_index(numpy.argmax(x),x.shape)
assert vw[1 - axis] == numpy.argmax(m)

或者

v,w=numpy.unravel_index(numpy.argmax(x),x.shape)
assert (v if axis == 1 else w) == numpy.argmax(m)

答案 1 :(得分:1)

axis参数的值非常重要。

使用x (N, M)数组,m=np.sum(x, axis=axis)将为您提供

  • 标量axis=None;
  • 如果M; ,则
  • axis=0数组 如果N,则
  • axis=1数组。

因此,如果np.argmax(m) axis=None0M之间的整数(分别为N),则axis=0始终为0 axis=1 (分别为(v, w) = np.unravel_index(...))。

但是,您的v始终会将N作为0到axis=0之间的整数。

正如您所见,使用m时,v的潜在值范围与axis=1的潜在值不同,而m的潜在值范围则相同。

因此,如果v,则将axis=1w进行比较,如果axis=0,则将{{1}}与{{1}}进行比较(