加速 Python 中的嵌套 for 循环

时间:2021-01-23 15:10:08

标签: python numpy parallel-processing scipy python-xarray

我有一个用 python 编写的嵌套循环系统,如下所示:

for yt in range(dims[1]):
  for xt in range(dims[2]):
    for yp in range(dims[1]):
       for xp in range(dims[2]):
           corr[yt,xt,yp,xp] = sp.spearmanr(prec_tar[:,yt,xt],prec_pre[:,yp,xp],axis=0)[0] 
           corr2[yt,xt,yp,xp] = sp.spearmanr(prec_tar[:,yt,xt],prec_pre2[:,yp,xp],axis=0)[0]
           corr3[yt,xt,yp,xp] = sp.spearmanr(prec_tar[:,yt,xt],prec_pre3[:,yp,xp],axis=0)[0]

其中 dims 的形状为 (1710, 69, 21) 并且 corr、corr2 和 corr3 都是 xarray Dataarray,其中包含形状为 (69,21,69,21) 的空 NumPy 数组。

现在,问题是这个脚本需要永远完成(~ 6+ 小时)。我不确定嵌套循环设置是否导致了它,或者 sp.spearmanr 是否是罪魁祸首(或者两者都有)。我正在寻找使这个运行更快的方法,具体来说,我想知道是否可以利用并行处理。也欢迎其他提示。提前致谢!

编辑:我还应该补充一点,prec_tar、prec_pre、prec_pre2 和 prec_pre3 都具有与 dims 相同的形状(即 (1710, 69, 21))。

3 个答案:

答案 0 :(得分:2)

您可以使用以下代码段使您的代码并行。

import time
import itertools
import multiprocessing

yt = range(2)
xt = range(2)
yp = range(2)
xp = range(2)

param_list = list(itertools.product(yt, xt, yp, xp))

def task(args):
    print(args)
    # task
    time.sleep(1)
    return args

pool = multiprocessing.Pool()

response = pool.map(task, param_list)
print(response)

答案 1 :(得分:0)

您可以在矢量化而不是循环代码时加快速度。

尝试使用矢量化和并行化 spearmanr 函数的 xski​​llscore。 https://xskillscore.readthedocs.io/en/stable/api/xskillscore.spearman_r.html#xskillscore.spearman_r

答案 2 :(得分:0)

这是基于@aaron.spring 建议的针对此问题的有效解决方案。我希望有一天这对某人有所帮助。

# Problem at hand: Very slow.
t1 = time.time()
for i in range(dims[1]):   #dims = ((1000, 4, 5))
    for j in range(dims[2]):
        for x in range(dims[1]):
            for y in range(dims[2]):
                acorrb[i,j,x,y] = spearmanr(a[:,i,j], b[:,x,y], dim='time')
t2 = time.time()
print(t2-t1)  # 0.3600752353668213

# Faster solution based on xarray's vectorized indexing and using  xskillscore.spearman_r instead of spearmanr from scipy.stats. 

ind_i = xr.DataArray(range(dims[1]), dims=['i'])
ind_j = xr.DataArray(range(dims[2]), dims=['j'])
ind_x = xr.DataArray(range(dims[1]), dims=['x'])
ind_y = xr.DataArray(range(dims[2]), dims=['y'])

t3 = time.time()
acorrb2[ind_i, ind_j, ind_x, ind_y]=spearmanr(a[:,ind_i,ind_j], b[:,ind_x,ind_y],dim='time')
t4 = time.time()
print(t4-t3) #0.07205533981323242

快 5 倍以上。

print((acorrb.values==acorrb2.values).all()) #True