我有一个numpy数组a
:
a = np.array([[-21.78878256, 97.37484004, -11.54228119],
[ -5.72592375, 99.04189958, 3.22814204],
[-19.80795922, 95.99377136, -10.64537733]])
我还有另一个数组b
:
b = np.array([[ 54.64642121, 64.5172014, 44.39991983],
[ 9.62420892, 95.14361441, 0.67014312],
[ 49.55036427, 66.25136632, 40.38778238]])
我想从数组b
中提取最小值索引。
ixs = [[2],
[2],
[2]]
然后,要使用索引a
从数组ixs
中提取元素:
预期的答案是:
result = [[-11.54228119]
[3.22814204]
[-10.64537733]]
我尝试过:
ixs = np.argmin(b, axis=1)
print ixs
[2,2,2]
result = np.take(a, ixs)
print result
不!
欢迎任何想法
答案 0 :(得分:1)
您可以使用
result = a[np.arange(a.shape[0]), ixs]
np.arange
将为每一行生成索引,ixs
将为每一列生成索引。因此,有效的结果将具有所需的结果。
答案 1 :(得分:0)
您可以尝试使用以下代码
np.take(a, ixs, axis = 1)[:,0]
初始部分将创建一个3 x 3的数组,并对第一列进行切片
>>> np.take(a, ixs, axis = 1)
array([[-11.54228119, -11.54228119, -11.54228119],
[ 3.22814204, 3.22814204, 3.22814204],
[-10.64537733, -10.64537733, -10.64537733]])