我花了最后一小时试图解决这个问题
假设我们有
import numpy as np
a = np.random.rand(5, 20) - 0.5
amin_index = np.argmin(np.abs(a), axis=1)
print(amin_index)
> [ 0 12 5 18 1] # or something similar
这不起作用:
a[amin_index]
因此,实质上,我需要为数组np.abs(a)找到沿某个轴的最小值,然后在这些位置从数组a中提取值。如何仅将索引应用于一个轴?
可能非常简单,但我无法理解。此外,我不能使用任何循环,因为我必须为具有数百万条目的数组执行此操作。 谢谢
答案 0 :(得分:1)
是因为argmin
返回每行(包含axis=1
)的列索引,因此您需要访问这些特定列的每一行:
a[range(a.shape[0]), amin_index]
答案 1 :(得分:1)
一种方法是传入行索引数组(例如[0,1,2,3,4]
)和每个对应行(列表amin_index
)中最小值的列索引列表。
这会返回一个数组,其中包含每行[i, amin_index[i]]
的{{1}}值:
i
这是基本的索引(而不是高级索引),因此返回的数组实际上是>>> a[np.arange(a.shape[0]), amin_index]
array([-0.0069325 , 0.04268358, -0.00128002, -0.01185333, -0.00389487])
的视图,而不是内存中的新数组。
答案 2 :(得分:0)
为什么不简单地执行np.amin(np.abs(a), axis=1)
,如果通过argmin不需要中间amin_index
数组,则更简单?
Numpy的reference page是一个很好的资源,请参阅“索引”。
编辑:时间总是有用的:
In [3]: a=np.random.rand(4000, 4000)-.5
In [4]: %timeit np.amin(np.abs(a), axis=1)
10 loops, best of 3: 128 ms per loop
In [5]: %timeit a[np.arange(a.shape[0]), np.argmin(np.abs(a), axis=1)]
10 loops, best of 3: 135 ms per loop