我正在处理像(110,80,817)这样的大型3D数组,并希望在某些单元测试中比较两个数组。但是,numpy.assert_almost_equal
的默认输出无法帮助我轻松追踪错误。例如:
> raise AssertionError(msg)
E AssertionError:
E Arrays are not almost equal to 7 decimals
E
E (mismatch 0.0314621119395%)
E x: array([[[ 0., 0., 0., ..., 0., 0., 0.],
E [ 0., 0., 0., ..., 0., 0., 0.],
E [ 0., 0., 0., ..., 0., 0., 0.],...
E y: array([[[ 0., 0., 0., ..., 0., 0., 0.],
E [ 0., 0., 0., ..., 0., 0., 0.],
E [ 0., 0., 0., ..., 0., 0., 0.],...
有没有办法轻松查看哪些3D索引未通过此断言?
答案 0 :(得分:2)
您可以将np.isclose
与np.where
结合使用
idx = zip(*np.where(~np.isclose(a, b, atol=0, rtol=1e-7)))
现在idx
将是断言失败的所有索引(x,y,z)
的列表。