Numba循环中的先前值

时间:2017-11-15 09:19:47

标签: python arrays loops numba

我正在使用索引数组(或列表)乱序遍历数组,但是对于我的一个计算,我需要数组(或列表)中前一个点的值。出于性能原因,我需要使用noPython模式执行此操作。

例如:

ordered_pts = [0, 1, 6, 2, 3, 7, 4, 5, 8, 9, 10, 11]
start = [True, False, False, False, False, False, False, True, False, False, False]


@jit(nogil=True, nopython=True)
def loop_in_order(ordered_pts, start):
    for pt in ordered_pts:
        if start[pt]:
            prev_pt = None
        else:
            prev_pt = ordered_points[ordered_points.index(pt) - 1]
        print pt, prev_pt

虽然这个函数在非jitted时有效,但是我得到一个无法将native?int64转换为Python对象的错误。

在numba中仍然使用noPython模式时,是否有一种有效的方法来获取列表中的上一项?或者我应该只为我的numba函数提供以前的索引列表以及我当前的索引吗?

此外,如果可能的话,我更愿意将数据类型保留为数组。 (我知道.index函数用于列表而不是数组)

1 个答案:

答案 0 :(得分:1)

index触发二次复杂度,其中存在准线性Numpy解:

def numpy_loop_in_order(ordered_pts,start):        
    prev_pt=np.roll(ordered_pts,1)               # roll indices
    prev_pt[np.argsort(ordered_pts)[np.where(start)]] =-1  # see below
    return prev_pt
#[-1,  0,  1,  6,  2, -1,  7,  4,  5,  8,  9, 10]

我添加了一个元素以None开始替换-1,因为Numpy数组必须是同构的。 np.argsort找到O( n ln(n) )中所有值的索引,np.where确定将设置哪些值。

此外,Numba只会加速np.arrays上的代码,而不是列表。并且通常仅在Numpy工具无法解决问题时才有用。但是你可以稍微改进你的代码,因为ordered_points与range(len(ordered_points))的元素相同:

@jit(nogil=True, nopython=True)
def loop_in_order(ordered_pts, start):
    reverse_index = np.empty_like(ordered_pts)
    prev_pt = np.empty_like(ordered_pts)
    prev = -1
    for i,pt in enumerate(ordered_pts):
        reverse_index[pt] = i
        prev_pt[i] = prev
        prev = pt
    for pt,i in enumerate(reverse_index):
        if start[pt]:
            prev_pt[i] = -1
    return prev_pt

测试1000分:

op=np.arange(1000)
np.random.shuffle(op)
ordered_pts=list(op)
start=np.random.randint(0,2,1000,dtype=bool)


In [614]: %timeit loop_in_order(ordered_pts,start)
9.52 ms ± 279 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [615]: %timeit numpy_loop_in_order(op,start)
43.6 µs ± 2.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [616]: %timeit numba_loop_in_order(op,start)
4.73 µs ± 179 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

主要改进是由于放弃了index