假设我有以下功能:
def f(x,y):
return x*y
如何使用多处理模块将功能应用于NxM 2D numpy数组中的每个元素?使用串行迭代,代码可能如下所示:
import numpy as np
N = 10
M = 12
results = np.zeros(shape=(N,M))
for x in range(N):
for y in range(M):
results[x,y] = f(x,y)
答案 0 :(得分:4)
以下是如何使用multiprocesssing
并行化示例函数的方法。我还包括一个几乎相同的纯Python函数,它使用非并行for
循环,以及一个实现相同结果的numpy单行程:
import numpy as np
from multiprocessing import Pool
def f(x,y):
return x * y
# this helper function is needed because map() can only be used for functions
# that take a single argument (see http://stackoverflow.com/q/5442910/1461210)
def splat_f(args):
return f(*args)
# a pool of 8 worker processes
pool = Pool(8)
def parallel(M, N):
results = pool.map(splat_f, ((i, j) for i in range(M) for j in range(N)))
return np.array(results).reshape(M, N)
def nonparallel(M, N):
out = np.zeros((M, N), np.int)
for i in range(M):
for j in range(N):
out[i, j] = f(i, j)
return out
def broadcast(M, N):
return np.prod(np.ogrid[:M, :N])
现在让我们来看看表现:
%timeit parallel(1000, 1000)
# 1 loops, best of 3: 1.67 s per loop
%timeit nonparallel(1000, 1000)
# 1 loops, best of 3: 395 ms per loop
%timeit broadcast(1000, 1000)
# 100 loops, best of 3: 2 ms per loop
非并行纯Python版本比并行化版本大约4倍,使用numpy数组广播的版本绝对粉碎其他两个。
问题是启动和停止Python子进程会带来相当多的开销,而且你的测试函数是如此微不足道,以至于每个工作线程只花费其生命周期的一小部分来完成有用的工作。如果每个线程在被杀死之前都有大量的工作要做,那么多处理才有意义。例如,您可以为每个工作人员提供更大的输出数组以进行计算(尝试将chunksize=
参数弄乱到pool.map()
),但是有了这样一个简单的例子我怀疑你'我会看到一个很大的进步。
我不知道你的实际代码是什么样的 - 也许你的功能很大且价格昂贵,足以保证使用多处理。但是,我敢打赌,很多有更好的方法来改善其性能。
答案 1 :(得分:0)
在您的情况下,不确定是否需要多处理。在上面的简单示例中,您可以执行
X, Y = numpy.meshgrid(numpy.arange(10), numpy.arange(12))
result = X*Y