python xarray索引/切片很慢

时间:2016-08-31 15:55:08

标签: python numpy netcdf python-xarray

我目前正在处理一些海洋模型输出。在每个时间步,它有42 * 1800 * 3600网格点。

我发现程序中的bottelneck是切片,并在方法中调用xarray_built来提取值。更有趣的是,相同的语法有时需要大量不同的时间。

ds = xarray.open_dataset(filename, decode_times=False)
vvel0=ds.VVEL.sel(lat=slice(-60,-20),lon=slice(0,40))/100    #in CCSM output, unit is cm/s convert to m/s
uvel0=ds.UVEL.sel(lat=slice(-60,-20),lon=slice(0,40))/100   ## why the speed is that different? now it's regional!!
temp0=ds.TEMP.sel(lat=slice(-60,-20),lon=slice(0,40)) #de

以此为例,读取VVEL和UVEL需要大约4秒,而在TEMP中读取只需要大约6ms。没有切片,VVEL和UVEL需要约1秒,而TEMP需要120纳秒。

我一直认为,当我只输入完整数组的一部分时,我需要更少的内存,因此更少的时间。事实证明,XARRAY加载了完整的数组,任何额外的切片都需要更多的时间。但是,有人可以解释为什么从同一个netcdf文件中读取不同的变量会花费不同的时间吗?

该程序旨在提取逐步截面,并计算横截面热传输,因此我需要选择UVEL或VVEL,这是TEMP沿截面的时间。因此,似乎在TEMP中加载快速好,不是吗?

不幸的是,事实并非如此。当我沿着规定的部分循环约250个网格点时......

# Calculate VT flux orthogonal to the chosen grid cells, which is the heat transport across GOODHOPE line
vtflux=[]
utflux=[]
vap = vtflux.append
uap = utflux.append
#for i in range(idx_north,idx_south+1):
for i in range(10):
    yidx=gh_yidx[i]
    xidx=gh_xidx[i]
    lon_next=ds_lon[i+1].values
    lon_current=ds_lon[i].values
    lat_next=ds_lat[i+1].values
    lat_current=ds_lat[i].values
    tt=np.squeeze(temp[:,yidx,xidx].values)  #<< calling values is slow
    if (lon_next<lon_current) and (lat_next==lat_current):   # The condition is incorrect
        dxlon=Re*np.cos(lat_current*np.pi/180.)*0.1*np.pi/180.
        vv=np.squeeze(vvel[:,yidx,xidx].values)  
        vt=vv*tt
        vtdxdz=np.dot(vt[~np.isnan(vt)],layerdp[0:len(vt[~np.isnan(vt)])])*dxlon
        vap(vtdxdz)
        #del  vtdxdz
    elif (lon_next==lon_current) and (lat_next<lat_current):
        #ut=np.array(uvel[:,gh_yidx[i],gh_xidx[i]].squeeze().values*temp[:,gh_yidx[i],gh_xidx[i]].squeeze().values) # slow
        uu=np.squeeze(uvel[:,yidx,xidx]).values  # slow
        ut=uu*tt
        utdxdz=np.dot(ut[~np.isnan(ut)],layerdp[0:len(ut[~np.isnan(ut)])])*dxlat
        uap(utdxdz) #m/s*degC*m*m ## looks fine, something wrong with the sign
        #del utdxdz
total_trans=(np.nansum(vtflux)-np.nansum(utflux))*3996*1026/1e15

特别是这一行:

tt=np.squeeze(temp[:,yidx,xidx].values)

需要〜3.65秒,但现在必须重复约250次。如果我删除.values,则此时间减少到约4毫秒。但我需要将tt计时到vt,所以我必须提取值。奇怪的是,类似的表达式vv=np.squeeze(vvel[:,yidx,xidx].values)需要更少的时间,只需约1.3毫秒。

总结我的问题:

  1. 为什么从同一个netcdf文件加载不同的变量需要不同的时间?
  2. 在多维数组中选择单个列是否有更有效的方法? (不一定是xarray结构,也是numpy.ndarray)
  3. 为什么从Xarray结构中提取值需要不同的时间,对于完全相同的语法?
  4. 谢谢!

1 个答案:

答案 0 :(得分:3)

当您索引从netCDF文件加载的变量时,xarray不会立即将其加载到内存中。相反,我们创建一个惰性数组,支持任何数量的进一步不同的索引操作。即使您未使用dask.array也是如此(通过在chunks=中设置open_dataset或使用open_mfdataset来触发)。

这解释了您观察到的惊人表现。计算temp0很快,因为它不会从磁盘加载任何数据。 vvel0很慢,因为除以100需要将数据作为numpy数组加载到内存中。

稍后,索引temp0的速度较慢,因为每个操作都从磁盘加载数据,而不是索引已经在内存中的numpy数组。

解决方法是首先将您需要的数据集部分显式加载到内存中,例如,通过编写temp0.load()。 xarray文档的netCDF section也给出了这个提示。