我正在使用索引数组(或列表)乱序遍历数组,但是对于我的一个计算,我需要数组(或列表)中前一个点的值。出于性能原因,我需要使用noPython模式执行此操作。
例如:
ordered_pts = [0, 1, 6, 2, 3, 7, 4, 5, 8, 9, 10, 11]
start = [True, False, False, False, False, False, False, True, False, False, False]
@jit(nogil=True, nopython=True)
def loop_in_order(ordered_pts, start):
for pt in ordered_pts:
if start[pt]:
prev_pt = None
else:
prev_pt = ordered_points[ordered_points.index(pt) - 1]
print pt, prev_pt
虽然这个函数在非jitted时有效,但是我得到一个无法将native?int64转换为Python对象的错误。
在numba中仍然使用noPython模式时,是否有一种有效的方法来获取列表中的上一项?或者我应该只为我的numba函数提供以前的索引列表以及我当前的索引吗?
此外,如果可能的话,我更愿意将数据类型保留为数组。 (我知道.index函数用于列表而不是数组)
答案 0 :(得分:1)
index
触发二次复杂度,其中存在准线性Numpy解:
def numpy_loop_in_order(ordered_pts,start):
prev_pt=np.roll(ordered_pts,1) # roll indices
prev_pt[np.argsort(ordered_pts)[np.where(start)]] =-1 # see below
return prev_pt
#[-1, 0, 1, 6, 2, -1, 7, 4, 5, 8, 9, 10]
我添加了一个元素以None
开始替换-1
,因为Numpy数组必须是同构的。 np.argsort
找到O( n ln(n) )
中所有值的索引,np.where
确定将设置哪些值。
此外,Numba只会加速np.arrays上的代码,而不是列表。并且通常仅在Numpy工具无法解决问题时才有用。但是你可以稍微改进你的代码,因为ordered_points与range(len(ordered_points))的元素相同:
@jit(nogil=True, nopython=True)
def loop_in_order(ordered_pts, start):
reverse_index = np.empty_like(ordered_pts)
prev_pt = np.empty_like(ordered_pts)
prev = -1
for i,pt in enumerate(ordered_pts):
reverse_index[pt] = i
prev_pt[i] = prev
prev = pt
for pt,i in enumerate(reverse_index):
if start[pt]:
prev_pt[i] = -1
return prev_pt
测试1000分:
op=np.arange(1000)
np.random.shuffle(op)
ordered_pts=list(op)
start=np.random.randint(0,2,1000,dtype=bool)
In [614]: %timeit loop_in_order(ordered_pts,start)
9.52 ms ± 279 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [615]: %timeit numpy_loop_in_order(op,start)
43.6 µs ± 2.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [616]: %timeit numba_loop_in_order(op,start)
4.73 µs ± 179 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
主要改进是由于放弃了index
。