I am trying to use Python's `multiprocessing` to speed up the computation of bootstrap estimates, but I am not seeing any speedup. The data I am bootstrapping over is gridded (time, latitude, longitude) and is read from netCDF with xarray.
Here is what I did:
import multiprocessing as mp
import numpy as np
import xarray as xr

def boot_mean(idata):
    return idata.mean(dim='boot_ax')

def process_boot(nsample, bfun, idata):
    # draw a bootstrap sample (with replacement) along boot_ax and reduce it
    ind_boot = np.random.choice(len(idata['boot_ax']), nsample)
    return bfun(idata.isel(boot_ax=ind_boot))

#geo bootstrap
def geo_bootstrap(idata_raw, bfun, nsample=0, nboot=1000, scoord='time',
                  np=1):
    idata = idata_raw.rename({scoord: 'boot_ax'})
    if nsample == 0:
        nsample = len(idata['boot_ax'])
    if np == 1:
        # serial version
        C = xr.concat([process_boot(nsample, bfun, idata)
                       for x in range(nboot)], dim='boot')
    else:
        # parallel version: one task per bootstrap replicate
        pool = mp.Pool(processes=np)
        results = [pool.apply_async(process_boot, args=(nsample, bfun, idata))
                   for x in range(nboot)]
        C = xr.concat([p.get() for p in results], dim='boot')
        pool.close()
    return C
With 1 CPU the code takes about 7 s:
%%time
O = geo_bootstrap(INPUT,boot_mean,nboot=20)
CPU times: user 6.79 s, sys: 268 ms, total: 7.06 s
Wall time: 7.1 s
With 4 CPUs the code takes even longer, which I don't fully understand:
%%time
O = geo_bootstrap(INPUT,boot_mean,nboot=20,np=4)
CPU times: user 2.14 s, sys: 4.34 s, total: 6.49 s
Wall time: 8.44 s
The machine I am running on has plenty of memory. This is my first attempt at using multiprocessing, and I am not sure where the bottleneck is.
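My best guess is that `multiprocessing` pickles every argument of `apply_async` and ships it to the worker, so each of the `nboot` tasks serializes the full dataset. A rough way to check that (just a sketch, assuming INPUT is fully loaded into memory rather than lazily backed by dask):

import pickle

# each apply_async call pickles its arguments before sending them to a worker
payload = pickle.dumps(INPUT)
print('per-task payload: %.1f MB' % (len(payload) / 1e6))

# with nboot tasks, roughly nboot * len(payload) bytes cross the
# parent -> worker pipes, which could easily dominate the ~7 s of compute
nboot = 20
print('total shipped for nboot=%d: %.1f MB' % (nboot, nboot * len(payload) / 1e6))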
INPUT is an xarray Dataset:
<xarray.Dataset>
Dimensions: (bnds: 2, time: 15, xh: 720, yh: 576)
Coordinates:
* time (time) object 1990-07-02 12:00:00 ... 1994-07-02 12:00:00
* xh (xh) float64 -299.8 -299.2 -298.8 ... 58.75 59.25 59.75
* yh (yh) float64 -77.91 -77.72 -77.54 ... 89.47 89.68 89.89
x (yh, xh) float64 -299.8 -299.2 -298.8 ... 59.99 59.99 60.0
y (yh, xh) float64 -77.91 -77.91 -77.91 ... 65.18 64.97
Dimensions without coordinates: bnds
Data variables:
time_bnds (time, bnds) object 1990-01-01 00:00:00 ... 1995-01-01 00:00:00
dep_n (time, yh, xh) float32 nan nan nan nan ... nan nan nan nan
tot_fsn (time, yh, xh) float32 nan nan nan nan ... nan nan nan nan
epc100 (time, yh, xh) float32 nan nan nan nan ... nan nan nan nan
nh4_stf (time, yh, xh) float32 nan nan nan nan ... nan nan nan nan
wc_vert_int_nfix (time, yh, xh) float32 nan nan nan nan ... nan nan nan nan
no3os (time, yh, xh) float32 nan nan nan nan ... nan nan nan nan
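In case it helps to reproduce the timings without my netCDF file, a synthetic stand-in of roughly the same shape can be built like this (a sketch; the single random variable is just a placeholder for the real fields):

import numpy as np
import xarray as xr

# dummy dataset with the same (time, yh, xh) dimensions as the real INPUT
nt, ny, nx = 15, 576, 720
INPUT = xr.Dataset(
    {'dep_n': (('time', 'yh', 'xh'),
               np.random.rand(nt, ny, nx).astype('float32'))},
    coords={'time': np.arange(nt),
            'yh': np.linspace(-77.91, 89.89, ny),
            'xh': np.linspace(-299.8, 59.75, nx)},
)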
Additional information
I think the problem may be that the dataset is being passed to each child process. If I redefine the functions as follows:
def process_boot(nsample):
    # uses the global INPUT directly instead of taking the data as an argument
    ind_boot = np.random.choice(len(INPUT['time']), nsample)
    return INPUT.isel(time=ind_boot).mean(dim='time')

#geo bootstrap
def geo_bootstrap(idata_raw, bfun, nsample=0, nboot=1000, scoord='time',
                  np=1):
    '''bootstrap estimates of a (time, lat, lon) dataset
    '''
    idata = idata_raw.rename({scoord: 'boot_ax'})
    if nsample == 0:
        nsample = len(idata['boot_ax'])
    if np == 1:
        C = xr.concat([process_boot(nsample) for x in range(nboot)], dim='boot')
    else:
        pool = mp.Pool(processes=np)
        results = [pool.apply_async(process_boot, args=(nsample,))
                   for x in range(nboot)]
        C = xr.concat([p.get() for p in results], dim='boot')
        pool.close()
    return C
then I get a good speedup:
%%time
O = geo_bootstrap(historical_diff_ts,boot_mean,nboot=20,np=5)
CPU times: user 350 ms, sys: 585 ms, total: 935 ms
Wall time: 2.41 s
But then the code is not as modular as I would like.
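One idea I have been playing with (only a sketch, not tested on the real data) is to keep the modular interface but pay the serialization cost once per worker instead of once per task, by handing the dataset to the pool through an initializer that stores it in a module-level global. The names _init_worker and _worker_data are made up, and I renamed the np= keyword to nproc so it no longer shadows numpy:

import multiprocessing as mp
import numpy as np
import xarray as xr

_worker_data = None  # set once per worker process by the pool initializer

def _init_worker(idata):
    # runs once when each worker starts: the dataset is pickled once per
    # worker here instead of once per bootstrap task
    global _worker_data
    _worker_data = idata
    # re-seed numpy so forked workers do not share identical RNG state
    np.random.seed()

def process_boot(nsample, bfun):
    ind_boot = np.random.choice(len(_worker_data['boot_ax']), nsample)
    return bfun(_worker_data.isel(boot_ax=ind_boot))

def geo_bootstrap(idata_raw, bfun, nsample=0, nboot=1000, scoord='time',
                  nproc=1):
    idata = idata_raw.rename({scoord: 'boot_ax'})
    if nsample == 0:
        nsample = len(idata['boot_ax'])
    if nproc == 1:
        _init_worker(idata)  # reuse the same code path serially
        C = xr.concat([process_boot(nsample, bfun) for _ in range(nboot)],
                      dim='boot')
    else:
        with mp.Pool(processes=nproc, initializer=_init_worker,
                     initargs=(idata,)) as pool:
            results = [pool.apply_async(process_boot, args=(nsample, bfun))
                       for _ in range(nboot)]
            C = xr.concat([p.get() for p in results], dim='boot')
    return C

I have not benchmarked this against the hard-coded INPUT version above, though.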