给定阈值alpha
和numpy数组a
,找到第一个索引i
有多种可能性arr[i] > alpha
;见Numpy first occurrence of value greater than existing value:
numpy.searchsorted(a, alpha)+1
numpy.argmax(a > alpha)
就我而言,alpha
可以是标量,也可以是任意形状的数组。我希望函数get_lowest
适用于两种情况:
alpha = 1.12
arr = numpy.array([0.0, 1.1, 1.2, 3.0])
get_lowest(arr, alpha) # 2
alpha = numpy.array(1.12, -0.5, 2.7])
arr = numpy.array([0.0, 1.1, 1.2, 3.0])
get_lowest(arr, alpha) # [2, 0, 3]
任何提示?
答案 0 :(得分:2)
您可以使用广播:
In [9]: arr = array([ 0. , 1.1, 1.2, 3. ])
In [10]: alpha = array([ 1.12, -0.5 , 2.7 ])
In [11]: np.argmax(arr > np.atleast_2d(alpha).T, axis=1)
Out[11]: array([2, 0, 3])
要折叠多维数组,可以使用np.squeeze
,但如果在第一种情况下需要Python浮点数,则可能需要做一些特殊的事情:
def get_lowest(arr, alpha):
b = np.argmax(arr > np.atleast_2d(alpha).T, axis=1)
b = np.squeeze(b)
if np.size(b) == 1:
return float(b)
return b
答案 1 :(得分:1)
searchsorted
实际上可以解决问题:
numpy.searchsorted(a, alpha)
axis
的{{1}}参数有助于解决问题;此
argmax
诀窍。确实
numpy.argmax(numpy.add.outer(alpha, -a) < 0, axis=-1)