有关如何加快此python函数速度的建议?

时间:2020-06-18 16:39:39

标签: python python-2.7 performance cython numba

关于如何加快此功能的任何建议?

def smooth_surface(z,c):
    hph_arr_list = []
    for x in xrange(c,len(z)-(c+1)):
        new_arr = np.hstack(z[x-c:x+c])
        hph_arr_list.append(np.percentile(new_arr[((new_arr >= np.percentile(new_arr,15)) & (new_arr <= np.percentile(new_arr,85)))],99))
    return np.array(map(float,hph_arr_list))

变量z的长度约为1500万,c是窗口大小+-的值。该函数基本上是一个滑动窗口,可计算每次迭代的百分位值。任何帮助,将不胜感激! z是一个数组数组(因此np.hstack)。也许numba会对此有所帮助。如果可以,该如何实施?

1 个答案:

答案 0 :(得分:1)

计算的最慢部分似乎是行np.percentile(new_arr[((new_arr >= np.percentile(new_arr,15)) & (new_arr <= np.percentile(new_arr,85)))],99)。这是由于小型阵列上的np.percentile异常慢,以及创建了多个中间阵列。

由于new_arr实际上很小,因此对它进行排序并自己进行插值要快得多。而且,numba还可以帮助加快计算速度。

@njit #Use @njit instead of @jit to increase speed
def filter(arr):
    arr = arr.copy() # This line can be removed to modify arr in-place
    arr.sort()
    lo = int(math.ceil(len(arr)*0.15))
    hi = int(len(arr)*0.85)
    interp = 0.99 * (hi - 1 - lo)
    interp = interp - int(interp)
    assert lo <= hi-2
    return arr[hi-2]* (1.0 - interp) + arr[hi-1] * interp

在我的计算机上,使用大小为20的数组,此代码快160倍,并且应产生相同的结果。

最后,您也可以通过在numba中使用自动并行化来加快smooth_surface的速度(有关更多信息,请参见here)。这是未经测试的原型:

@jit(parallel=True)
def smooth_surface(z,c):
    hph_arr = np.zeros(len(z)-(c+1)-c)
    for x in prange(c,len(z)-(c+1)):
        hph_arr[x-c] = filter(np.hstack(z[x-c:x+c]))
    return hph_arr