如何在numpy.argmin中使用out参数来检索最小索引?

时间:2016-10-26 09:26:42

标签: python numpy scipy

根据numpy.argmin文档,第三个参数 out 应该提供索引(据我所知):

  

out 数组,可选   如果提供,将插入结果   这个数组。它应该是合适的形状和类型。

但我无法使用它来收集大型numpy.ndarray( dist 在以下代码段中)的最小二十个索引。

# here using smaller size arrays, but hold RGBs of images    
a = np.random.normal(size=(10,3))
b = np.random.normal(size=(1,3))
# compute distances between two arrays
dist = scipy.spatial.distance.cdist(a,b)
index_of_minimum = np.argmin(dist)
# create an array to hold the indices: same shape and same datatype
indices = np.array(dist.shape, dtype=np.float64)
np.argmin(dist, axis=0, out=indices)

但我得到以下错误:

  

ValueError:输出数组与np.argmin的结果不匹配。

有一个类似的问题before,但我的要求有所不同。首先,是否可以使用 out 参数收集最小指数?如果是这样我怎么能这样做?

2 个答案:

答案 0 :(得分:0)

两个问题:首先在你的第6行中有一个类型。应该说

dist

而不是

list

但这并不影响最后一部分。

第二个问题是指数的形状。它对我有用

indices = np.zeros(dist[0].shape, dtype=np.int64)

对于索引数组,您需要整数类型。由于轴= 0,预期形状的尺寸比原始dist阵列少一个。你可以通过在索引的启动中明确地取出0轴来实现这一点

答案 1 :(得分:0)

使用您在评论中提供的数据:

In [1856]: data
Out[1856]: 
array([[ 1.80575419],
       [ 1.8899017 ],
       [ 2.01251239],
       [ 1.23996376],
       [ 1.71429128],
       [ 1.43630964],
       [ 1.75586222],
       [ 1.3196145 ],
       [ 0.92899127],
       [ 1.57403994]])
In [1857]: i=np.argmin(data,axis=0)
In [1858]: i
Out[1858]: array([8], dtype=int32)
In [1859]: data[i]
Out[1859]: array([[ 0.92899127]])

argmin返回一个元素的索引 - 唯一的一个元素。

对于3个最小的值:

In [1878]: i=np.argsort(data,axis=0)[:3]
In [1879]: data[i.ravel()]
Out[1879]: 
array([[ 0.92899127],
       [ 1.23996376],
       [ 1.3196145 ]])

您已经找到了找到all最小值索引的答案。但这个答案只有在整数的情况下才有意义。由于四舍五入,只有特制的花车彼此相等。

How to return all the minimum indices in numpy

我可以找到接近最小值的所有值。这看起来像.5会很有趣:

In [1885]: i=np.argmin(data,axis=0)
In [1886]: i
Out[1886]: array([8], dtype=int32)
In [1887]: m=data[i[0]]
In [1888]: m
Out[1888]: array([ 0.92899127])
In [1889]: data==m    # only 1 value equals the minimum
Out[1889]: 
array([[False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [ True],
       [False]], dtype=bool)

但接近它的人可以找到:

In [1890]: np.abs(data-m)<.5
Out[1890]: 
array([[False],
       [False],
       [False],
       [ True],
       [False],
       [False],
       [False],
       [ True],
       [ True],
       [False]], dtype=bool)
 In [1892]: data[(data-m)<.5]
 Out[1892]: array([ 1.23996376,  1.3196145 ,  0.92899127])

我不需要abs,因为m是最低要求。 np.isclose(data,m,atol=.5)也有效。