我试图沿一个轴边缘化一个阵列并检查1-D峰值出现在与原始2-D峰值相同的相关指数处。在什么情况下(x
的形式)下面的断言应该失败?
def check(x,axis=None):
import numpy
m = numpy.sum(x, axis=axis)
v,w = numpy.unravel_index(numpy.argmax(x), x.shape)
assert(v==numpy.argmax(m))
return
对于x=numpy.array(range(15)).reshape(5,3)
,check(x,axis=0)
会引发错误,但check(x,axis=1)
却没有。我不明白为什么AssertionError
被提出 - 我是不是很蠢?
答案 0 :(得分:1)
您正在检查拆开的索引的错误坐标。而不是
v,w=numpy.unravel_index(numpy.argmax(x),x.shape)
assert(v==numpy.argmax(m))
你可能想要
vw = numpy.unravel_index(numpy.argmax(x),x.shape)
assert vw[1 - axis] == numpy.argmax(m)
或者
v,w=numpy.unravel_index(numpy.argmax(x),x.shape)
assert (v if axis == 1 else w) == numpy.argmax(m)
答案 1 :(得分:1)
axis
参数的值非常重要。
使用x
(N, M)
数组,m=np.sum(x, axis=axis)
将为您提供
axis=None
; M
; ,则axis=0
数组
如果N
,则axis=1
数组。 因此,如果np.argmax(m)
axis=None
或0
与M
之间的整数(分别为N
),则axis=0
始终为0 axis=1
(分别为(v, w) = np.unravel_index(...)
)。
但是,您的v
始终会将N
作为0到axis=0
之间的整数。
正如您所见,使用m
时,v
的潜在值范围与axis=1
的潜在值不同,而m
的潜在值范围则相同。
因此,如果v
,则将axis=1
与w
进行比较,如果axis=0
,则将{{1}}与{{1}}进行比较(