我在以下代码的上下文中理解unravel_index的输出时遇到了一些麻烦。
使用meshgrid我创建了两个代表某些坐标的数组:
import numpy as np
x_in=np.arange(-800, 0, 70)
y_in=np.arange(-3500, -2000, 70)
y, x =np.meshgrid(y_in,x_in,indexing='ij')
然后我通过其中一个网格来识别某些限制内的值:
limit=100
x_gd=x[np.logical_and(x>=-600-limit,x<=-600+limit)]
这将返回一个包含我感兴趣的值的数组 - 为了得到这些值的索引,我使用了以下函数(我在阅读this后开发):
def get_index(array, select_array):
'''
Find the index positions of values from select_array in array
'''
rows,cols=array.shape
flt = array.flatten()
sorted = np.argsort(flt)
pos = np.searchsorted(flt[sorted], select_array)
indices = sorted[pos]
y_indx, x_indx = np.unravel_index(indices, [rows, cols])
return y_indx, x_indx
xx_y_indx, xx_x_indx = get_index(x, x_gd)
xx_x_indx返回我期望的内容 - 来自x:
的值的col参考array([2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3,
4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2,
3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4], dtype=int64)
xx_y_indx然而返回:
array([15, 2, 19, 15, 2, 19, 15, 2, 19, 15, 2, 19, 15, 2, 19, 15, 2,
19, 15, 2, 19, 15, 2, 19, 15, 2, 19, 15, 2, 19, 15, 2, 19, 15,
2, 19, 15, 2, 19, 15, 2, 19, 15, 2, 19, 15, 2, 19, 15, 2, 19,
15, 2, 19, 15, 2, 19, 15, 2, 19, 15, 2, 19, 15, 2, 19], dtype=int64)
当我希望它显示所有行时,因为数组x表示的坐标每行都相同 - 而不仅仅是第15,2和19行。
对于我感兴趣的内容,我可以使用xx_x_indx的结果 - 列索引。但是,我无法解释为什么y(行)索引报告的原因。
答案 0 :(得分:1)
对searchsorted
的此次调用未找到selected_array
中 flt[sorted]
的每个出现位置;它找到第一个出现的索引。
pos = np.searchsorted(flt[sorted], select_array)
In [273]: pos
Out[273]:
array([44, 66, 88, 44, 66, 88, 44, 66, 88, 44, 66, 88, 44, 66, 88, 44, 66,
88, 44, 66, 88, 44, 66, 88, 44, 66, 88, 44, 66, 88, 44, 66, 88, 44,
66, 88, 44, 66, 88, 44, 66, 88, 44, 66, 88, 44, 66, 88, 44, 66, 88,
44, 66, 88, 44, 66, 88, 44, 66, 88, 44, 66, 88, 44, 66, 88])
注意pos
中的所有重复值。
过去这一点的所有内容可能都不符合您的意图,因为您并未真正使用select_array
或flt[sorted]
中array
值的所有位置。
您可以使用以下方法解决问题:
def get_index(array, select_array):
'''
Find the index positions of values from select_array in array
'''
mask = np.logical_or.reduce([array==val for val in np.unique(select_array)])
y_indx, x_indx = np.where(mask)
return y_indx, x_indx
或
def get_index2(array, select_array):
idx = np.in1d(array.ravel(), select_array.ravel())
y_indx, x_indx = np.where(idx.reshape(array.shape))
return y_indx, x_indx
哪个更快取决于np.unique(select_array)
中的元素数量。当它很大时,使用for-loop
会更慢,因此get_index2
更快。但如果select_array
中有很多重复,而np.unique(select_array)
很小,那么get_index
可能是更快的选择。
要演示np.unravel_index
的使用,您甚至可以使用
def get_index3(array, select_array):
idx = np.in1d(array.ravel(), select_array.ravel())
y_indx, x_indx = np.unravel_index(np.where(idx), array.shape)
return y_indx, x_indx
但我认为这比get_index2
慢reshape
,因为np.where
非常快,因此使用reshape
np.where
比使用np.unravel_index
要快{{1}}。