嗨,我有一个numpy数组,例如。
arr = np.random.rand(4,5)
array([[0.70733982, 0.1770464 , 0.55588376, 0.8810145 , 0.43711158],
[0.22056565, 0.0193138 , 0.89995761, 0.75157581, 0.21073093],
[0.22333035, 0.92795789, 0.3903581 , 0.41225472, 0.74992639],
[0.92328687, 0.20438876, 0.63975818, 0.6179422 , 0.40596821]])
我需要找到数组中的前三个最大元素。我尝试过
arr[[-arr.argsort(axis=-1)[:, :3]]]
我还在StackOverflow上引用了此question,它仅提供索引而不提供值
我能够获取前三个最大值的索引,但是如何获取其对应的值呢?
我也尝试通过转换为给定here的列表来对数组进行排序
但是没有给我所需的结果。有想法吗?
答案 0 :(得分:1)
您可以直接使用np.sort()
:
# np.sort sorts in ascending order
# --> we apply np.sort -arr
arr_sorted = -np.sort(-arr,axis=1)
top_three = arr_sorted[:,:3]
答案 1 :(得分:1)
该问题已经有一个有效的公认答案,但是我只想指出,在数组较大的情况下,使用np.partition
而不是np.sort
会更快。我们仍然使用np.sort
,但仅在组成行前三位的数组的一小部分上使用。
arr = np.random.random((10000, 10000))
top_three_fast = np.sort(np.partition(arr, -3)[:, -3:])[:, ::-1]
时间:
In [22]: %timeit top_three_fast = np.sort(np.partition(arr, -3)[:, -3:])[:, ::-1]
1.04 s ± 8.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [23]: %timeit top_three_slow = -np.sort(-arr, axis=1)[:, :3]
6.22 s ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [24]: (top_three_slow == top_three_fast).all()
Out[24]: True