numpy 3D数组中最高索引的返回值

时间:2018-12-12 18:52:08

标签: python arrays numpy

我在numpy中有一个3D数组,其中包含nans。我需要返回沿0轴具有最大索引位置的值。答案将减少为2D阵列。

关于找到沿轴(How to get the index of a maximum element in a numpy array along one axis)的最大值的索引位置有很多问题,但这与我需要的有所不同。

示例3D数组:

>>> import numpy as np
>>> foo = np.asarray([[[7,4,6],[4,2,11], [7,8,9], [4,8,2]],[[1,2,3],[np.nan,5,8], [np.nan,np.nan,10], [np.nan,np.nan,7]]])
>>> foo
array([[[  7.,   4.,   6.],
        [  4.,   2.,  11.],
        [  7.,   8.,   9.],
        [  4.,   8.,   2.]],

       [[  1.,   2.,   3.],
        [ nan,   5.,   8.],
        [ nan,  nan,  10.],
        [ nan,  nan,   7.]]])

我以为我已经接近使用np.where了,但是它返回了所有不是nan的元素。并不是我所需要的,因为我想要一个(4,3)数组。

>>> zoo = foo[np.where(~np.isnan(foo))]
>>> zoo
array([  7.,   4.,   6.,   4.,   2.,  11.,   7.,   8.,   9.,   4.,   8.,
     2.,   1.,   2.,   3.,   5.,   8.,  10.,   7.])

我需要的答案是:

>>> ans = np.asarray([[1,2,3], [4,5,8], [7,8,10], [4,8,7]])
>>> ans
array([[ 1,  2,  3],
       [ 4,  5,  8],
       [ 7,  8, 10],
       [ 4,  8,  7]])

编辑:我编辑了foo示例数组以使问题更清楚。

2 个答案:

答案 0 :(得分:1)

您可以使用np.nanmax

>>> np.nanmax(foo, axis=0)
array([[ 7.,  4.,  6.],
       [ 4.,  5., 11.],
       [ 7.,  8., 10.],
       [ 4.,  8.,  7.]])

np.nanmax函数返回数组的最大值或沿轴的最大值,而忽略所有NaN。

编辑

正如您在注释中正确指出的那样,您需要最大索引处的值,而上面的代码不会返回该值。

您可以使用apply-along-axis

>>> def highest_index(a):
...     return a[~np.isnan(a)][-1] # return non-nan value at highest index

>>> np.apply_along_axis(highest_index, 0, foo)
array([[ 1.  2.  3.]
       [ 4.  5.  8.]
       [ 7.  8. 10.]
       [ 4.  8.  7.]])

答案 1 :(得分:0)

向量解决方案,仅具有索引:

def last_non_nan(foo):
    i = np.isnan(foo)[::-1].argmin(0)
    j,k = np.indices(foo[0].shape)
    return foo[-i-1,j,k]

i包含反向“行”中第一个非nan数字的索引。 因此-i-1是它在直线上的索引。

>>> last_non_nan(foo):
  [[  1.,   2.,   3.],
   [  4.,   5.,   8.],
   [  7.,   8.,  10.],
   [  4.,   8.,   7.]]

highest_index快:

In [5]%timeit last_non_nan(foo)
133 µs ± 29.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [6]: %timeit np.apply_along_axis(highest_index,0,foo)
667 µs ± 90 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
对于具有90%nans的(10,400,400)阵列,

最多可快150倍(40 ms vs 6 s)。

主要是因为last_non_nan在计算索引并提取所有非nan值时,仅获取每行中的最后一个非nan值。