我正在尝试减少绘制以下函数的时间:
def cch(tau):
return np.sum(abs(-1*np.diff(cartprod)-tau)<0.001)
“ cartprod”的缩写:
cartprod = np.asarray(list(itertools.product(times1,times2)))
times1和times2是列表,其元素从大约0.0123到大约99.9948最多散布0.25 ish。每个列表也有大约5000个元素。如果您来自神经科学背景,那么这就是高峰时间。注意:此信息对于问题是多余的,但仅对好奇的人有用。
我用以下绘图材料对其进行绘制:
t = np.linspace(-0.25,0.25,1250)
vfunc = np.vectorize(cch)
y = vfunc(t)
plt.plot(t,y,'g')
绘图大约需要4分钟。我不太担心绘图时间(只要合理即可:在5-10分钟之内说出来)。我所关心的是,我将必须平均绘制这些功能中的10,000个以上的事实,而我需要能够快速地做到这一点。是否可以通过numba或任何算法增强的方法来加快对函数的每次调用?
谢谢
答案 0 :(得分:0)
我无法使用np.vectorize
来复制您的代码,因为t
与(1250,)
的形状和cartprod
与(25000000, 2)
的形状(我从长度上假定了您的列表times1
和times2
)不匹配。
不幸的是,numpy
中尚未实现加速此功能所需的numba
函数。但是仍然在numpy
中重写代码可以显着提高速度。在我的计算机上,cartprod
的计算很容易加快大约25倍。
def cartprod(arr_times1, arr_times2):
res = np.empty((arr_times1.size * arr_times2.size, 2))
res[:, 0] = np.repeat(arr_times1, arr_times2.size)
res[:, 1] = np.tile(arr_times2, arr_times1.size)
return res
def cartprod_iter(times1, times2):
return np.asarray(list(itertools.product(times1, times2)))
arr_times1 = np.random.rand(5000)
arr_times2 = np.random.rand(5000)
times1 = list(arr_times1)
times2 = list(arr_times2)
%timeit cartprod_iter(times1, times2)
%timeit cartprod(arr_times1, arr_times1)
# 12.9 s ± 954 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 521 ms ± 53.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
print(np.all(cartprod(arr_times1, arr_times1) == cartprod_iter(times1, times2)))
# Out: True
现在使用cch
函数:
np.tile
和np.repeat
无法麻木。如果您想jit
cch
,则需要手工重写它们。或者,您可以jit
函数的某些部分cch
:
import numba as nb
@nb.njit
def cch_core(cp, tau):
return np.sum(np.abs(-1 * np.diff(cp) - tau) < 0.001)
def cch_nb(arr_times1, arr_times2, tau):
cp = cartprod(arr_times1, arr_times2)
return cch_core(cp, tau)
def cch(arr_times1, arr_times2, tau):
return np.sum(np.abs(-1 * np.diff(cartprod(arr_times1, arr_times2)) - tau) < 0.001)
tau2 = np.linspace(-0.25, 0.25, 50)
%timeit cch_nb(arr_times1, arr_times2, tau2)
#2.81 s ± 144 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit cch(arr_times1, arr_times2, tau2)
#15.2 s ± 494 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
print(cch_nb(arr_times1, arr_times2, tau2) == cch(arr_times1, arr_times2, tau2))
# Out: True
这是另一个加快5.4倍的速度。我使用了tau
的简化值来使计时成为可能。此外,形状为tau
的{{1}}会在没有numba的情况下产生内存错误,但将与numba一起使用!
如果要提高速度,则需要自己在numba中实现(1250,)
。
如果此代码无法正常工作:请考虑您在问题中张贴的代码无法复制。发布一个可以正常工作的最小示例可能会有所帮助。