如何将mpi4py应用于以下循环以加快速度?

时间:2018-12-16 21:48:13

标签: python mpi4py

过去,我不需要使用并行计算,因为我的python脚本通常不需要大量计算。最近,我编写了以下for循环,发现运行大约需要10到15分钟:

cpx_var = np.linspace(0.20,0.80,301)
horn_var = np.linspace(0.20,0.80,301)
plag_var = np.linspace(0.05,0.1,26)
mag_var = np.linspace(0.02,0.06,21)
ap_var = np.linspace(0.002,0.006,3)

poss_comb = []
count=0

for i in cpx_var:
    for j in horn_var:
        for k in plag_var:
            for l in mag_var:
                for m in ap_var:
                    count = count+1
                    if abs((i+j+k+l+m)-1.0)<0.002:
                        poss_comb.append([i,j,k,l,m])
print(count)

我一直在研究mpi4py作为加快for-loop速度的一种方法,但是我不确定如何将其应用于示例。有人有建议吗?

谢谢!

Zack Eriksen

2 个答案:

答案 0 :(得分:0)

尽管多处理是一种加快代码速度的方法,但首先应该尝试使代码效率更高。您正在使用5个嵌套循环,在Python中这很疯狂!在Python中,您应该尽可能地使用数组来实现代码,而不是for循环。

这里是一个解决方案,其中我只在最后使用一个for循环(实际上是列表理解,这是for循环的稍快版本)。我使用了np.mesh函数,然后进行了一些切片以摆脱您拥有的所有循环,而仅使用列表推导将所有结果放入所需的列表格式。

因此请试用该版本。在笔记本电脑上运行只需要20秒左右

import numpy as np

cpx_var = np.linspace(0.20,0.80,301)
horn_var = np.linspace(0.20,0.80,301)
plag_var = np.linspace(0.05,0.1,26)
mag_var = np.linspace(0.02,0.06,21)
ap_var = np.linspace(0.002,0.006,3)

cpx_vars, horn_vars, plag_vars, mag_vars, ap_vars = np.meshgrid(cpx_var, horn_var, plag_var, mag_var, ap_var)

selection = abs((cpx_vars + horn_vars + plag_vars + mag_vars + ap_vars)-1.0)<0.002

poss_comb = [list(comb) for comb in zip(cpx_vars[selection], horn_vars[selection], plag_vars[selection], mag_vars[selection], ap_vars[selection])]

答案 1 :(得分:0)

我找到了另一种方法来将您的代码再提高23倍。自从回答了这个问题以来,我一直在学习有关称为numba的Python软件包的更多信息。该软件包确实非常聪明,它实际上将您的Python函数转换为机器代码,并且可以极大地提高您的速度(具体取决于您的工作)。

我在玩这个软件包,在这里想到了您的问题,我想看看numba与使用numpy数组相比能做到多快。在您的代码,我在上面的答案中粘贴的numpy版本以及代码的numba优化版本之间进行比较。我发现numpy(无循环)版本比您的循环版本快23倍,但numba版本却快455倍!

无论如何,我认为您可能希望看到这个。 numba版本比上面我的答案快23倍。以下是我正在使用的代码的3个版本。我将代码放入名为original_loop_function的函数中,上面在函数numpy_function中发布了numpy版本,而代码的numba更快版本是函数numba_loop_function。希望这对您仍然有用

import numpy as np
from numba import jit

cpx_var = np.linspace(0.20,0.80,301)
horn_var = np.linspace(0.20,0.80,301)
plag_var = np.linspace(0.05,0.1,26)
mag_var = np.linspace(0.02,0.06,21)
ap_var = np.linspace(0.002,0.006,3)


def original_loop_function(cpx_var, horn_var, plag_var, mag_var, ap_var):
    '''Your original version of this code'''
    poss_comb = []
    count=0

    for i in cpx_var:
        for j in horn_var:
            for k in plag_var:
                for l in mag_var:
                    for m in ap_var:
                        count = count+1
                        if abs((i+j+k+l+m)-1.0)<0.002:
                            poss_comb.append([i,j,k,l,m])
    return poss_comb

def numpy_function(cpx_var, horn_var, plag_var, mag_var, ap_var):
    '''Numpy version of your code with no loops'''
    cpx_vars, horn_vars, plag_vars, mag_vars, ap_vars = np.meshgrid(cpx_var, horn_var, plag_var, mag_var, ap_var)

    selection = abs((cpx_vars + horn_vars + plag_vars + mag_vars + ap_vars)-1.0)<0.002

    poss_comb = [list(comb) for comb in zip(cpx_vars[selection], horn_vars[selection], plag_vars[selection], mag_vars[selection], ap_vars[selection])]

    return poss_comb


@jit(nopython=True)
def numba_loop_function(cpx_var, horn_var, plag_var, mag_var, ap_var):
    '''Numba optimised version of your code'''
    poss_comb = []
    count=0

    for i in cpx_var:
        for j in horn_var:
            for k in plag_var:
                for l in mag_var:
                    for m in ap_var:
                        count = count+1
                        if abs((i+j+k+l+m)-1.0)<0.002:
                            poss_comb.append([i,j,k,l,m])
    return poss_comb