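I have Python code that rebins a NumPy array to downsample it. When the input array is large, however, the work has to be done in a loop, windowing the data into smaller chunks.
What I am looking for now is a way to speed up the rebin function using multiprocessing. Currently the function uses a single CPU. I have some familiarity with the multiprocessing module, but I would welcome suggestions on converting the code below to use multiprocessing.
Note: I appreciate suggestions for speeding up the function below in a different (possibly more pythonic) way. However, I would also very much like to know whether someone can help adapt the code below for multiprocessing.
MWE: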
import numpy as np
from tqdm.auto import tqdm


def rebin(arr, new_shape):
    # Split the last two axes into (bins, block) pairs and average over each block.
    shape = (arr.shape[0], new_shape[0], arr.shape[1] // new_shape[0],
             new_shape[1], arr.shape[2] // new_shape[1])
    return arr.reshape(shape).mean(-1).mean(2)


def rebinner(arr, window):
    # window is the number of rows rebinned per iteration; the intermediate
    # array for the full data does not fit in memory when the data is large.
    y = []
    for j in tqdm(range(0, arr.shape[0], window)):
        y.append(rebin(arr[j:j + window, :, None] * arr[j:j + window, None, :], [32, 32]))
    return np.concatenate(y, axis=0)


arr = np.random.random((1000, 2560))
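To make the reshape trick in rebin easier to follow, here is a minimal sketch on a tiny array. The 4x4 input and the (2, 2) target shape are illustrative choices, not values from the question.

```python
import numpy as np


def rebin(arr, new_shape):
    # Axes after reshape: (row, bin_i, within_bin_i, bin_j, within_bin_j).
    shape = (arr.shape[0], new_shape[0], arr.shape[1] // new_shape[0],
             new_shape[1], arr.shape[2] // new_shape[1])
    # Average over within_bin_j, then over within_bin_i.
    return arr.reshape(shape).mean(-1).mean(2)


# One 4x4 "image" holding the values 0..15 (illustrative data).
tiny = np.arange(16, dtype=float).reshape(1, 4, 4)
tiny_rebinned = rebin(tiny, (2, 2))
print(tiny_rebinned)
```

Each output element is the mean of one 2x2 block of the input, so the top-left value is (0 + 1 + 4 + 5) / 4 = 2.5.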
The following code was run in a Jupyter notebook cell to check the execution time. The same check can be done in a script using timeit.
%%time
print(rebinner(arr, window=10).shape)
Expected output:
(1000, 32, 32)
CPU times: user 8.52 s, sys: 4.58 s, total: 13.1 s
Wall time: 13.1 s
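One alternative worth checking before reaching for multiprocessing: the array being rebinned is an outer product of each row with itself, and the mean of a[i] * a[j] over a rectangular block equals (mean of a[i] over the row block) times (mean of a[j] over the column block). Under that observation the large intermediate array is not needed at all. The sketch below is an assumption-laden illustration, not the author's method; rebinner_reference and rebinner_separable are hypothetical names, and it assumes the row length is divisible by the number of bins, as in the original code.

```python
import numpy as np


def rebin(arr, new_shape):
    shape = (arr.shape[0], new_shape[0], arr.shape[1] // new_shape[0],
             new_shape[1], arr.shape[2] // new_shape[1])
    return arr.reshape(shape).mean(-1).mean(2)


def rebinner_reference(arr, window, bins=32):
    # Windowed approach from the question: build the outer product, then rebin it.
    parts = [rebin(arr[j:j + window, :, None] * arr[j:j + window, None, :], (bins, bins))
             for j in range(0, arr.shape[0], window)]
    return np.concatenate(parts, axis=0)


def rebinner_separable(arr, bins=32):
    # Block-average each row first, then take the outer product of the averages.
    n, m = arr.shape
    block_means = arr.reshape(n, bins, m // bins).mean(-1)
    return block_means[:, :, None] * block_means[:, None, :]


# Small illustrative input so the comparison runs quickly.
rng = np.random.default_rng(0)
small = rng.random((20, 256))
print(np.allclose(rebinner_reference(small, window=5), rebinner_separable(small)))
```

The two results agree up to floating-point rounding, and the separable version only ever allocates arrays of shape (n, bins) and (n, bins, bins).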
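Output using the numba library, as suggested by @john-zwinck
Based on @johnzwinck's comment, I made a small update to the code and added the numba decorator. However, the new script raises an assertion error, and I am not sure what causes it. The updated code and the corresponding error message are below.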
import numba
import numpy as np


@numba.njit(nopython=True)
def rebinner2(arr, new_shape):
    shape = (arr.shape[0], new_shape[0], arr.shape[1] // new_shape[0],
             new_shape[1], arr.shape[2] // new_shape[1])
    return arr.reshape(shape).mean(-1).mean(2)


@numba.njit(nopython=True)
def rebinner1(arr, window):
    return [rebinner2(np.random.random((window, 1, 2560))*np.random.random((window, 2560, 1)),
                      [32, 32]) for j in range(0, arr.shape[0], window)]
    # return [rebinner2(arr[j:j + window, :, None]*arr[j:j + window, None, :],
    #                   [32, 32]) for j in range(0, arr.shape[0], window)]


def rebinner(arr, window):
    return np.concatenate(rebinner1(arr, window), axis=0)


if __name__ == "__main__":
    arr = np.random.random((1000, 2560))
    print(rebinner(arr, window=10).shape)
    # print(rebinner1(arr, window=10).shape)
Output:
venv/lib/python3.6/site-packages/numba/core/decorators.py:252: RuntimeWarning: nopython is set for njit and is ignored
warnings.warn('nopython is set for njit and is ignored', RuntimeWarning)
Traceback (most recent call last):
File "numba_tester.py", line 28, in <module>
print(rebinner(arr, window=10).shape)
File "numba_tester.py", line 23, in rebinner
return np.concatenate(rebinner1(arr,window), axis=0)
File "venv/lib/python3.6/site-packages/numba/core/dispatcher.py", line 415, in _compile_for_args
error_rewrite(e, 'typing')
File "venv/lib/python3.6/site-packages/numba/core/dispatcher.py", line 358, in error_rewrite
reraise(type(e), e, None)
File "venv/lib/python3.6/site-packages/numba/core/utils.py", line 80, in reraise
raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
AssertionError()
- Resolution failure for non-literal arguments:
AssertionError()
During: resolving callee type: BoundFunction(array.mean for array(float64, 5d, C))
During: typing of call at numba_tester.py (9)
File "numba_tester.py", line 9:
def rebinner2(arr, new_shape):
<source elided>
new_shape[1], arr.shape[2] // new_shape[1])
return arr.reshape(shape).mean(-1).mean(2)
^
During: resolving callee type: type(CPUDispatcher(<function rebinner2 at 0x7f6c5743e730>))
During: typing of call at numba_tester.py (17)
During: resolving callee type: type(CPUDispatcher(<function rebinner2 at 0x7f6c5743e730>))
During: typing of call at numba_tester.py (17)
File "numba_tester.py", line 17:
def rebinner1(arr, window):
<source elided>
return [rebinner2(np.random.random((window, 1, 2560))*np.random.random((window, 2560, 1)),
[32, 32]) for j in range(0, arr.shape[0], window)]
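Note: the code above works fine without the numba decorators. I simplified the call to rebinner2 by passing random arrays instead of the actual slice, transpose and multiply operation (to check whether that operation was causing the numba problem, but it was not).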
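The modified numba code works, but performance does not improve.
Following @JohnZwinck's suggestion, I changed njit to jit, which turns off nopython mode. However, when I check the execution time, the numba version now seems to perform worse (possibly because the optimizations are lost once nopython is disabled).
With Numba: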
(5000, 32, 32)
real: 1m11.538s
user: 0m49.137s
sys : 0m22.872s
Without Numba:
(5000, 32, 32)
real: 1m2.439s
user: 0m41.721s
sys : 0m21.152s
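The traceback earlier in the question fails while resolving array.mean on a 5-D array. As far as I can tell from the Numba documentation, nopython mode supports ndarray.mean() only without an axis argument, which would explain the typing error. Assuming that is the cause, one possible workaround is to write the block averaging as explicit loops, which Numba compiles well and which never builds the large intermediate array. In this sketch, rebin_outer_loops is a hypothetical name, and the fallback decorator exists only so the code still runs (slowly) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for illustration: without Numba, run as plain Python.
    def njit(func):
        return func


@njit
def rebin_outer_loops(arr, nbins):
    # arr has shape (n, m); m is assumed to be divisible by nbins.
    n, m = arr.shape
    block = m // nbins
    out = np.empty((n, nbins, nbins))
    for k in range(n):
        for a in range(nbins):
            for b in range(nbins):
                # Mean of arr[k, i] * arr[k, j] over block (a, b).
                total = 0.0
                for i in range(a * block, (a + 1) * block):
                    for j in range(b * block, (b + 1) * block):
                        total += arr[k, i] * arr[k, j]
                out[k, a, b] = total / (block * block)
    return out


# Small illustrative input; the question uses shape (1000, 2560) with 32 bins.
demo = np.arange(24, dtype=np.float64).reshape(3, 8)
print(rebin_outer_loops(demo, 4).shape)
```

The first call includes compilation time when Numba is present, so any timing comparison should measure a second call.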
Answer 0 (score: 1)
Simply removing the unnecessary tqdm (progress bar) and adding Numba should speed this up (untested, but it should be close to what you need):
import numba
import numpy as np


@numba.njit
def rebinner2(arr, new_shape):
    shape = (arr.shape[0], new_shape[0], arr.shape[1] // new_shape[0],
             new_shape[1], arr.shape[2] // new_shape[1])
    return arr.reshape(shape).mean(-1).mean(2)


@numba.njit
def rebinner1(arr, window):
    # Rebin each window of rows separately; window is the chunk size.
    return [rebinner2(arr[j:j + window, :, None] * arr[j:j + window, None, :], [32, 32])
            for j in range(0, arr.shape[0], window)]


def rebinner(arr, window):
    # np.concatenate runs outside the Numba-compiled code.
    return np.concatenate(rebinner1(arr, window), axis=0)
I wrapped np.concatenate() in a non-Numba function because I am not sure whether Numba supports the axis argument.
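Using multiple cores for this probably makes no sense unless your data is truly huge, in which case you would simply load part of the data in each process. But Numba will make the loop far faster than multiple cores would.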
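For readers who still want the multiprocessing variant the question asks about, here is a minimal sketch using multiprocessing.Pool. It assumes each window of rows can be processed independently, which holds for the rebinner loop in the question. The names rebin_chunk and rebinner_parallel and the processes=2 default are illustrative, and the demo array is deliberately small. With the question's sizes, each worker allocates a (window, 2560, 2560) intermediate, about 0.5 GB for window=10, so memory use grows with the number of processes.

```python
import numpy as np
from multiprocessing import Pool

NEW_SHAPE = (32, 32)


def rebin(arr, new_shape):
    shape = (arr.shape[0], new_shape[0], arr.shape[1] // new_shape[0],
             new_shape[1], arr.shape[2] // new_shape[1])
    return arr.reshape(shape).mean(-1).mean(2)


def rebin_chunk(chunk):
    # Worker: build the per-row outer products for one window and rebin them.
    return rebin(chunk[:, :, None] * chunk[:, None, :], NEW_SHAPE)


def rebinner_parallel(arr, window, processes=2):
    # Only the small (window, m) slices are sent to the workers, and only
    # the (window, 32, 32) results are sent back.
    chunks = [arr[j:j + window] for j in range(0, arr.shape[0], window)]
    with Pool(processes=processes) as pool:
        parts = pool.map(rebin_chunk, chunks)  # map preserves chunk order
    return np.concatenate(parts, axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    demo = rng.random((64, 256))  # small illustrative input
    print(rebinner_parallel(demo, window=8).shape)
```

The `if __name__ == "__main__":` guard matters on platforms where worker processes are started by re-importing the main module. Whether this beats the single-process version depends on the data size and core count, so it is worth timing both.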