This is a follow-up to my previous question:

I am trying to use Numba and Dask to speed up a slow computation, similar to calculating a kernel density estimate for a huge collection of points. My plan was to write the computationally intensive logic in a `jit`ed function and then split the work among the CPU cores using `dask`. I want to use the `nogil` feature of `numba.jit` so that I can use the `dask` threading backend, to avoid unnecessary copies of the input data (which is very large).

It turns out that part of the problem was unpacking arguments in the `jit`ed function. However, even with that fix, I still see no speedup from the following code:
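The copy-avoidance point is easy to see directly: worker threads receive a reference to the same array object, whereas a process pool would have to pickle the (very large) input for every worker. A minimal illustration of the shared-memory behavior (my own sketch, not part of the benchmark code):

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

big = np.zeros(10_000_000)  # stand-in for the large input data

def worker(a):
    # threads receive a reference to the same buffer, not a copy
    return a is big

with ThreadPoolExecutor(max_workers=2) as exc:
    results = list(exc.map(worker, [big, big]))

print(results)  # → [True, True]
```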
```python
import os
import time
import numpy as np
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from numba import njit, jit

CPU_COUNT = os.cpu_count()
print("CPU_COUNT", CPU_COUNT)


def PE(pool, func, args):
    """Run func over args in the given executor pool and wait for completion."""
    with pool(max_workers=CPU_COUNT) as exc:
        fut = {exc.submit(func, *arg): i for i, arg in enumerate(args)}
        for f in as_completed(fut):
            f.result()


def render(params, mag):
    """Render gaussian peaks in small windows"""
    radius = 3
    for i in range(len(params)):
        y0 = params[i, 0] * mag
        x0 = params[i, 1] * mag
        sy = params[i, 2] * mag
        sx = params[i, 3] * mag
        # calculate the render window size
        wy = int(sy * radius * 2.0)
        wx = int(sx * radius * 2.0)
        # calculate the area in the image
        ystart = int(np.rint(y0)) - wy // 2
        yend = ystart + wy
        xstart = int(np.rint(x0)) - wx // 2
        xend = xstart + wx
        # adjust coordinates to window coordinates
        y1 = y0 - ystart
        x1 = x0 - xstart
        y = np.arange(wy)
        x = np.arange(wx)
        amp = 1 / (2 * np.pi * sy * sx)
        gy = np.exp(-((y - y1) / sy) ** 2 / 2)
        gx = np.exp(-((x - x1) / sx) ** 2 / 2)
        g = amp * np.outer(gy, gx)


jit_render = jit(render, nopython=True, nogil=True)

args = [(np.random.rand(1000000, 4) * (1, 1, 0.02, 0.02), 100) for i in range(CPU_COUNT)]

print("Single time:")
# %timeit render(*args[0])
%timeit jit_render(*args[0])
print()
print("Linear time:")
%time [jit_render(*a) for a in args]
print()
print("Threads time:")
%time PE(ThreadPoolExecutor, jit_render, args)
```
On my 8-core MacBook I get about a 2x speedup:

```
Single time:
1.6 s ± 153 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Linear time:
CPU times: user 11.8 s, sys: 43.1 ms, total: 11.9 s
Wall time: 11.9 s

Threads time:
CPU times: user 45.4 s, sys: 125 ms, total: 45.5 s
Wall time: 6.29 s
```
On my 24-core Windows machine I get only about a 1.4x speedup:

```
Single time:
1.91 s ± 105 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Linear time:
Wall time: 1min 30s

Threads time:
Wall time: 1min 4s
```
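As a sanity check on the measurement harness itself (my own addition, not part of the original benchmark), the same `PE` pattern can be applied to a standard-library call that is known to release the GIL, such as `zlib.compress` on a large buffer. If that shows a threaded wall time well below the linear time, the harness is fine and the bottleneck is inside the `jit`ed function. A minimal sketch:

```python
import os
import time
import zlib
from concurrent.futures import ThreadPoolExecutor, as_completed

CPU_COUNT = os.cpu_count()

def PE(pool, func, args):
    # same harness as above: submit one task per argument tuple and wait
    with pool(max_workers=CPU_COUNT) as exc:
        fut = {exc.submit(func, *arg): i for i, arg in enumerate(args)}
        for f in as_completed(fut):
            f.result()

# zlib.compress releases the GIL while compressing, so threads can overlap
payload = os.urandom(4_000_000)
args = [(payload, 9) for _ in range(CPU_COUNT)]

t0 = time.perf_counter()
for a in args:
    zlib.compress(*a)
linear = time.perf_counter() - t0

t0 = time.perf_counter()
PE(ThreadPoolExecutor, zlib.compress, args)
threaded = time.perf_counter() - t0

print(f"linear  {linear:.2f}s")
print(f"threads {threaded:.2f}s")
```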