缩短CCH图的执行时间

时间:2018-10-24 19:43:18

标签: python-3.x performance numpy matplotlib numba

我正在尝试减少绘制以下函数的时间:

def cch(tau):
    return np.sum(abs(-1*np.diff(cartprod)-tau)<0.001)

“ cartprod”的缩写:

cartprod = np.asarray(list(itertools.product(times1,times2)))

times1和times2是列表,其元素从大约0.0123到大约99.9948最多散布0.25 ish。每个列表也有大约5000个元素。如果您来自神经科学背景,那么这就是高峰时间。注意:此信息对于问题是多余的,但仅对好奇的人有用。

我用以下绘图材料对其进行绘制:

t = np.linspace(-0.25,0.25,1250) 
vfunc = np.vectorize(cch)
y = vfunc(t)
plt.plot(t,y,'g')

绘图大约需要4分钟。我不太担心绘图时间(只要合理即可:在5-10分钟之内说出来)。我所关心的是,我将必须平均绘制这些功能中的10,000个以上的事实,而我需要能够快速地做到这一点。是否可以通过numba或任何算法增强的方法来加快对函数的每次调用?

谢谢

1 个答案:

答案 0 :(得分:0)

我无法使用np.vectorize来复制您的代码,因为t(1250,)的形状和cartprod(25000000, 2)的形状(我从长度上假定了您的列表times1times2)不匹配。

不幸的是,numpy中尚未实现加速此功能所需的numba函数。但是仍然在numpy中重写代码可以显着提高速度。在我的计算机上,cartprod的计算很容易加快大约25倍。

def cartprod(arr_times1, arr_times2):
    res = np.empty((arr_times1.size * arr_times2.size, 2))
    res[:, 0] = np.repeat(arr_times1, arr_times2.size)
    res[:, 1] = np.tile(arr_times2, arr_times1.size)
    return res
def cartprod_iter(times1, times2):
    return np.asarray(list(itertools.product(times1, times2)))

arr_times1 = np.random.rand(5000)
arr_times2 = np.random.rand(5000)
times1 = list(arr_times1)
times2 = list(arr_times2)

%timeit cartprod_iter(times1, times2)
%timeit cartprod(arr_times1, arr_times1)
# 12.9 s ± 954 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 521 ms ± 53.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
print(np.all(cartprod(arr_times1, arr_times1) == cartprod_iter(times1, times2)))
# Out: True

现在使用cch函数:
np.tilenp.repeat无法麻木。如果您想jit cch,则需要手工重写它们。或者,您可以jit函数的某些部分cch

import numba as nb
@nb.njit
def cch_core(cp, tau):
    return np.sum(np.abs(-1 * np.diff(cp) - tau) < 0.001)

def cch_nb(arr_times1, arr_times2, tau):
    cp = cartprod(arr_times1, arr_times2)
    return cch_core(cp, tau)

def cch(arr_times1, arr_times2, tau):
    return np.sum(np.abs(-1 * np.diff(cartprod(arr_times1, arr_times2)) - tau) < 0.001)

tau2 = np.linspace(-0.25, 0.25, 50)

%timeit cch_nb(arr_times1, arr_times2, tau2)
#2.81 s ± 144 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit cch(arr_times1, arr_times2, tau2)
#15.2 s ± 494 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
print(cch_nb(arr_times1, arr_times2, tau2) == cch(arr_times1, arr_times2, tau2))
# Out: True

这是另一个加快5.4倍的速度。我使用了tau的简化值来使计时成为可能。此外,形状为tau的{​​{1}}会在没有numba的情况下产生内存错误,但将与numba一起使用!

如果要提高速度,则需要自己在numba中实现(1250,)。 如果此代码无法正常工作:请考虑您在问题中张贴的代码无法复制。发布一个可以正常工作的最小示例可能会有所帮助。