How can I use map_blocks to apply a function over a dask array and store the result in an xarray or dask array without going through a numpy array?

Date: 2019-03-29 19:47:45

Tags: python multithreading dask

I have two matrices (data and result): result is a 2D dask.array and data is a 3D xarray.DataArray, and I have to compute the following:

var_idx = 0 # const value
result[i,j] = external_function(data[var_idx,:,i],data[var_idx,:,j])
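Before optimizing, it can help to have a plain, unoptimized reference for the formula above to test faster versions against. A minimal NumPy sketch, assuming `data` is an ordinary in-memory array shaped `(variable, time, gridcell)` and using the max-of-product placeholder from the code below as `external_function`; like that code, it fills only the strict lower triangle (`i > j`) and leaves every other entry at -1. `pairwise_reference` is a hypothetical name.

```python
import numpy as np


def external_function(a, b):
    # Placeholder pairwise function (same example as func_external in the question)
    return np.max(np.multiply(a, b))


def pairwise_reference(data, var_idx=0):
    """Naive reference: result[i, j] = f(data[var_idx, :, i], data[var_idx, :, j]) for i > j."""
    x = data[var_idx]                 # shape (time, gridcell)
    n = x.shape[1]
    result = np.full((n, n), -1.0)    # entries outside the lower triangle stay -1
    for i in range(n):
        for j in range(i):            # strict lower triangle only
            result[i, j] = external_function(x[:, i], x[:, j])
    return result


# Tiny example: 2 variables, 3 time steps, 4 grid cells
data = np.arange(2 * 3 * 4, dtype=np.float64).reshape(2, 3, 4)
ref = pairwise_reference(data)
```

This is O(n²) Python calls, so it is only meant for checking results on small inputs.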

I'm new to Python and dask, but after weeks of research, the code below is the best I have managed to come up with for this computation:

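import numpy as np
import dask.array as da
import xarray as xr
from numba import jit

# grid_size and da_stacked_notnull are defined elsewhere in the script (not shown here)
def func_block(block,block_info=None):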

    @jit(nogil=True)     
    def func_external(a,b):
        # just an example
        return np.max(np.multiply(a, b))


    # result element location of data array
    # to compute just lower triangular elements     
    block_info = block_info[0]
    [(s_row,end_row),(s_col,end_col)]= block_info['array-location']
    if s_col > s_row:
        return block

    it = np.nditer(block, flags=['multi_index'],op_flags=['readwrite'])
    while not it.finished:
        (r_idx,c_idx) = it.multi_index
        row = r_idx+s_row
        col = c_idx+s_col
        if  row > col:
            it[0] = float(func_external(da_stacked_notnull[0,:,row],da_stacked_notnull[0,:,col]))
        it.iternext()

    return block

darr_zeros = da.zeros((grid_size,grid_size), chunks=(3000,3000))
darr_zeros = darr_zeros -1
dask_result = darr_zeros.map_blocks(func_block,chunks=(3000,3000),dtype=np.float16)
dask_result = xr.DataArray(dask_result.compute())
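One possible way to restructure `func_block`, sketched under a few assumptions: the `(time, gridcell)` slice of the input (for example `da_stacked_notnull[0].values`, 365 × 7230 float32 values, about 10.6 MB) fits in memory and is passed to every chunk explicitly instead of being read as a global; each chunk builds a new output array rather than writing into its input; and the inner loop handles one whole output row per step with NumPy instead of one element per step with `np.nditer`. `pair_block` and `lazy_pairwise` are hypothetical names, and the max-of-product is the placeholder from the question.

```python
import numpy as np
import dask.array as da


def pair_block(block, x=None, block_info=None):
    """Compute one output chunk of result[i, j] = max_t x[t, i] * x[t, j] (i > j only).

    block: chunk of a template array; only its shape and position are used.
    x: in-memory NumPy array shaped (time, gridcell), passed to every chunk.
    """
    (r0, r1), (c0, c1) = block_info[0]["array-location"]
    out = np.full(block.shape, -1.0)  # new array; the input chunk is never modified
    if c0 >= r1 - 1:
        return out                    # chunk lies on/above the diagonal: nothing to fill
    cols = x[:, c0:c1]                # (time, columns of this chunk)
    for i in range(r0, r1):
        k = min(i, c1) - c0           # number of columns j in this chunk with j < i
        if k <= 0:
            continue
        # Whole output row at once: (time, 1) * (time, k) -> max over time -> (k,)
        out[i - r0, :k] = np.max(x[:, i:i + 1] * cols[:, :k], axis=0)
    return out


def lazy_pairwise(x, chunk=3000):
    """Return the n x n result as a lazy dask array (nothing is computed here)."""
    n = x.shape[1]
    template = da.zeros((n, n), chunks=(chunk, chunk))
    return template.map_blocks(
        pair_block,
        x=x,
        dtype=np.float64,
        # giving meta avoids dask calling pair_block on an empty array to infer the output type
        meta=np.array((), dtype=np.float64),
    )


# Small usage example with random data in place of da_stacked_notnull[0].values
rng = np.random.default_rng(0)
x = rng.standard_normal((5, 7))
lazy = lazy_pairwise(x, chunk=3)
```

`lazy` stays a dask array until something consumes it, so the memory problem of calling `.compute()` on the full result is a separate step.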

However, I have a few problems:

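1) When I use all the data I need, I get a Python MemoryError. I suspect it comes from dask_result.compute(): if I understand correctly, .compute() returns a numpy array, and I don't have enough memory to hold the whole result in one. How can I do this with dask arrays instead?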
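Regarding question 1: `.compute()` does return an in-memory NumPy array, so the full n × n result must fit in RAM. Wrapping the lazy array as `xr.DataArray(dask_result)` without calling `.compute()` keeps it as a dask array, but the values still have to go somewhere eventually. One option, sketched here, is to let dask write the result chunk by chunk into a disk-backed target with `da.store`. `np.memmap` is used as the target only because it needs nothing beyond NumPy; dask's `Array.to_zarr` or xarray's `to_netcdf` are common alternatives. The `lazy` array below is a hypothetical stand-in for `dask_result`, and the file is a throwaway temp file.

```python
import os
import tempfile

import numpy as np
import dask.array as da

# Stand-in for dask_result: a lazy 2-D dask array whose [i, j] value is j.
n = 1000
lazy = da.ones((n, n), chunks=(250, 250)) * da.arange(n, chunks=250)

# Disk-backed NumPy array as the destination (throwaway temp file).
path = os.path.join(tempfile.mkdtemp(), "result.dat")
target = np.memmap(path, dtype=lazy.dtype, mode="w+", shape=lazy.shape)

# Compute chunk by chunk and write each chunk into its region of `target`,
# so the whole n x n result does not have to sit in process memory at once.
da.store(lazy, target)
target.flush()
```

The file can later be reopened with `np.memmap(path, mode="r", ...)` and read in slices, or wrapped again with `da.from_array` for further lazy processing.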

2) All threads use at most 50% of each core. I think this is caused by the GIL, but shouldn't dask help with that? Can this be refactored for better performance?
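Regarding question 2: dask runs array tasks on a thread pool by default, so pure-Python work inside a chunk is serialized by the GIL. In the question's code the `np.nditer` loop runs in the interpreter, and `func_external` is re-decorated with `@jit` on every call to `func_block`, which likely means numba compiles it again for each chunk. A sketch of an alternative, assuming numba is available as in the question: one kernel compiled once at module level with `nogil=True` that loops over an entire output chunk in compiled code, so the GIL is released for the whole chunk. `pair_kernel` is a hypothetical name; it assumes NaN-free input (as the name `da_stacked_notnull` suggests) and the max-of-product placeholder. A no-op fallback decorator lets the sketch run, without any speedup, when numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without numba (no speedup)
    def njit(*args, **kwargs):
        return lambda f: f


@njit(nogil=True)
def pair_kernel(x, r0, r1, c0, c1, out):
    """Fill out[i - r0, j - c0] = max_t x[t, i] * x[t, j] for r0 <= i < r1, c0 <= j < c1, i > j.

    Compiled once at module level; nogil=True releases the GIL while it runs,
    so several dask worker threads can execute it at the same time.
    Assumes x has no NaNs and at least one time step.
    """
    nt = x.shape[0]
    for i in range(r0, r1):
        for j in range(c0, min(i, c1)):
            m = x[0, i] * x[0, j]
            for t in range(1, nt):
                v = x[t, i] * x[t, j]
                if v > m:
                    m = v
            out[i - r0, j - c0] = m
    return out


# Usage on one small "chunk" covering the whole 5 x 5 result
rng = np.random.default_rng(1)
xs = rng.standard_normal((6, 5))
o = pair_kernel(xs, 0, 5, 0, 5, np.full((5, 5), -1.0))
```

Inside a `map_blocks` function, the kernel would be called with the chunk's `array-location` bounds and a fresh `np.full(block.shape, -1.0)` output array. If the work still cannot release the GIL, computing with `scheduler="processes"` is the other usual route, at the cost of copying data between processes.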

Here is the data matrix (a subset, without all the grid points I actually need):

da_stacked_notnull
Out[1]: 
<xarray.DataArray (variable: 1, time: 365, gridcell: 7230)>
array([[[-0.376704, -0.036332, ..., 27.715254, 26.863554],
        [-0.465122, -0.152866, ..., 27.227764, 26.556808],
        ...,
        [-0.724707, -0.520708, ..., 29.315022, 29.10007 ],
        [-0.835325, -0.704899, ..., 29.425072, 29.086765]]], dtype=float32)
Coordinates:
  * time      (time) datetime64[ns] 2015-01-01 2015-01-02 ... 2015-12-31
  * gridcell  (gridcell) MultiIndex
  - lon       (gridcell) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
  - lat       (gridcell) float64 -59.5 -58.5 -57.5 -56.5 ... -32.5 -31.5 -30.5
  * variable  (variable) <U3 'sst'
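A back-of-the-envelope check of why the dense result runs out of memory: the result is n × n for n grid cells, so it grows with the square of the cell count. The sketch below uses the 7230 cells of the sample shown above; the full data set has more cells, so the real numbers are larger. Note that the blocks produced by `da.zeros(...) - 1` in the question are float64, and `map_blocks(dtype=np.float16)` only records metadata without converting them, so the computed result is most likely float64. `.compute()` can also briefly need extra memory while it assembles the chunks into one array. `result_nbytes` is a hypothetical helper.

```python
import numpy as np


def result_nbytes(n_cells, dtype):
    # Bytes needed for a dense n_cells x n_cells result matrix of the given dtype
    return n_cells * n_cells * np.dtype(dtype).itemsize


n = 7230  # gridcell count in the sample shown above
for dt in (np.float64, np.float32, np.float16):
    print(np.dtype(dt).name, result_nbytes(n, dt) / 1e6, "MB")
```

Doubling the number of grid cells would quadruple each of these figures.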

Thanks!
