我有一个多维数组,我需要从最后一个维度的每一行中获取前k个元素。
>>> x = np.random.random_integers(0, 100, size=(2,1,1,5))
>>> x
array([[[[99, 39, 10, 18, 68]]],
[[[22, 3, 13, 56, 2]]]])
我正在尝试获取:
array([[[[ 99., 68.]]],
[[[ 18., 99.]]]])
我可以使用以下方法获取索引,但是我不确定如何切出值。
>>> k = 2
>>> parts = np.flip(-1 - np.arange(k), 0)
>>> indices = np.flip(
... np.argpartition(x, parts, axis=-1)[..., -k:],
... axis=-1)
>>> indices
array([[[[0, 4]]],
[[[3, 0]]]])
答案 0 :(得分:0)
这可以解决您的问题。
np.sort(x, axis=len(x.shape)-1)[...,-2:]
答案 1 :(得分:0)
np.partition(x, 2)[..., -2:]
每行返回2个最大的元素。例如,
x = np.random.random_integers(0, 100, size=(2,1,1,5))
print(x)
print(np.partition(x, 2)[..., -2:])
打印类似
[[[[79 34 90 80 56]]]
[[[78 11 24 20 42]]]]
[[[[80 90]]]
[[[78 42]]]]