我想创建一个函数,它将数组中沿指定轴的最近值返回给定值。
要获取最接近值的索引,我使用以下代码,其中arr
是一个多维数组,value
是要查找的值:
def nearest_index( arr, value, axis=None ):
return ( np.abs( arr - value ) ).argmin( axis=axis )
但是我很难使用这个函数的结果从数组中获取值。
使用1D阵列很容易:
In [14]: arr_1 = np.random.randint( 10, 100, size=( 10, ) )
In [15]: arr_1
Out[15]: array([67, 49, 90, 29, 60, 80, 31, 55, 29, 10])
In [16]: nearest_index( arr_1, 50 )
Out[16]: 1
In [17]: arr_1[nearest_index( arr_1, 50 )]
Out[17]: 49
或使用扁平数组:
In [25]: arr_3 = np.random.randint( 10, 100, size=( 2, 3, 4, ) )
In [26]: arr_3
Out[26]:
array([[[85, 51, 74, 79],
[63, 42, 27, 75],
[89, 68, 80, 63]],
[[85, 72, 74, 16],
[85, 22, 47, 78],
[44, 70, 98, 34]]])
In [27]: idx_flat = nearest_index( arr_3, 50, axis=None )
In [28]: idx_flat
Out[28]: 1
In [29]: idx = np.unravel_index( idx_flat, arr_3.shape )
In [30]: idx
Out[30]: (0, 0, 1)
In [31]: arr_3[idx]
Out[31]: 51
如何创建一个函数,它返回沿定义轴的值?
我尝试了Question的解决方案,但我只在axis=-1
工作了。
请注意,如果数组中的多个元素同样接近预期值,则只发现结果的第一次出现,这对我来说不是问题。
答案 0 :(得分:2)
对于多维数组,我们需要使用advanced-indexing
。因此,对于通用的n-dim
数组和指定的轴,我们可以这样做 -
def argmin_values_along_axis(arr, value, axis):
argmin_idx = np.abs(arr - value).argmin(axis=axis)
shp = arr.shape
indx = list(np.ix_(*[np.arange(i) for i in shp]))
indx[axis] = np.expand_dims(argmin_idx, axis=axis)
return np.squeeze(arr[indx])
样品运行 -
In [203]: arr_3 = np.random.randint( 10, 100, size=( 2, 3, 4, ) )
In [204]: arr_3
Out[204]:
array([[[94, 55, 26, 51],
[82, 66, 80, 66],
[96, 54, 93, 57]],
[[59, 28, 95, 56],
[47, 48, 17, 77],
[15, 57, 57, 25]]])
In [205]: argmin_values_along_axis(arr_3, value=50, axis=0)
Out[205]:
array([[59, 55, 26, 51],
[47, 48, 80, 66],
[15, 54, 57, 57]])
In [206]: argmin_values_along_axis(arr_3, value=50, axis=1)
Out[206]:
array([[82, 54, 26, 51],
[47, 48, 57, 56]])
In [207]: argmin_values_along_axis(arr_3, value=50, axis=2)
Out[207]:
array([[51, 66, 54],
[56, 48, 57]])
答案 1 :(得分:0)
这对我有用。
def nearest_index(arr, value, axis=None):
return np.argmin(np.abs( arr - value ), axis=axis)
>>> X
array([[76, 94, 56, 93, 28, 0, 44, 50, 89, 93],
[80, 99, 29, 98, 39, 27, 55, 70, 19, 76],
[87, 7, 28, 78, 47, 95, 34, 97, 66, 27],
[75, 78, 82, 30, 15, 0, 2, 25, 58, 69],
[31, 2, 34, 1, 56, 7, 87, 78, 32, 77],
[89, 80, 76, 97, 49, 18, 62, 35, 94, 41],
[ 2, 44, 83, 3, 64, 4, 49, 93, 46, 8],
[51, 63, 45, 57, 77, 90, 93, 4, 26, 81],
[43, 92, 22, 98, 93, 36, 46, 25, 35, 36],
[30, 14, 42, 91, 86, 14, 78, 9, 37, 19]])
>>> X[nearest_index(X, 2, axis=0), np.arange(10)]
array([ 2, 2, 22, 1, 15, 0, 2, 4, 19, 8])
>>> X[np.arange(10), nearest_index(X, 2, axis=1)]
array([ 0, 19, 7, 2, 2, 18, 2, 4, 22, 9])