I have 2 matrices (`data` and `result`): `result` is a 2D `dask.array` and `data` is a 3D `xarray.DataArray`. I need to do the following calculation:
```python
var_idx = 0  # constant value
# for every pair of grid cells (i, j):
result[i, j] = external_function(data[var_idx, :, i], data[var_idx, :, j])
```
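For reference, here is a minimal NumPy sketch of that calculation on a small in-memory array. It assumes `external_function` is the `max(a * b)` example used in the code further down, that only the strict lower triangle (`i > j`) is needed, and that every other element keeps the fill value `-1`. The function names are hypothetical. The broadcasting version builds a `(time, n, n)` temporary, so it only suits small `n` or a single block.

```python
import numpy as np

def external_function(a, b):
    # the example from the question: maximum of the element-wise product
    return np.max(np.multiply(a, b))

def lower_triangle_loop(x, fill=-1.0):
    """x has shape (time, gridcell). Element [i, j] holds
    external_function(x[:, i], x[:, j]) for i > j, and `fill` elsewhere."""
    n = x.shape[1]
    result = np.full((n, n), fill)
    for i in range(n):
        for j in range(i):
            result[i, j] = external_function(x[:, i], x[:, j])
    return result

def lower_triangle_vectorized(x, fill=-1.0):
    # (time, n, 1) * (time, 1, n) -> (time, n, n), then reduce over time
    prod_max = (x[:, :, None] * x[:, None, :]).max(axis=0)
    # boolean mask that is True only strictly below the diagonal (i > j)
    lower = np.tri(x.shape[1], k=-1, dtype=bool)
    return np.where(lower, prod_max, fill)

rng = np.random.default_rng(0)
x = rng.normal(size=(365, 50)).astype(np.float32)
loop = lower_triangle_loop(x)
vec = lower_triangle_vectorized(x)
```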
I'm new to Python and dask, but after weeks of research, the code below is the best I have managed to come up with for this computation:
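```python
import numpy as np
import dask.array as da
import xarray as xr
from numba import jit

# da_stacked_notnull (the DataArray shown below) and grid_size
# are assumed to be defined before this point.

@jit(nogil=True)
def func_external(a, b):
    # just an example; compiled once at module level instead of on every block call
    return np.max(np.multiply(a, b))

def func_block(block, block_info=None):
    # location of this block inside the full result array,
    # used to compute only the lower-triangular elements
    block_info = block_info[0]
    [(s_row, end_row), (s_col, end_col)] = block_info['array-location']
    if s_col > s_row:
        return block  # block lies entirely above the diagonal
    it = np.nditer(block, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        (r_idx, c_idx) = it.multi_index
        row = r_idx + s_row
        col = c_idx + s_col
        if row > col:
            # .values passes plain NumPy arrays to the numba function
            it[0] = float(func_external(da_stacked_notnull[0, :, row].values,
                                        da_stacked_notnull[0, :, col].values))
        it.iternext()
    return block

# dtype matches the one declared in map_blocks below
darr_zeros = da.zeros((grid_size, grid_size), chunks=(3000, 3000), dtype=np.float16)
darr_zeros = darr_zeros - 1
dask_result = darr_zeros.map_blocks(func_block, chunks=(3000, 3000), dtype=np.float16)
dask_result = xr.DataArray(dask_result.compute())
```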
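A hedged sketch of one alternative block function: the per-element `np.nditer` loop runs in the Python interpreter, so it holds the GIL for most of its run time even though the jitted helper uses `nogil=True`. This version computes each block with whole-array NumPy operations instead, keeping a running maximum over time of the outer product of two slices of the data. `make_block_func` is a hypothetical helper; it hard-codes the `max(a * b)` example rather than a general `external_function`, and, like the original, it assumes square chunks such as (3000, 3000). How well NumPy's internal GIL release scales on your machine has not been measured here.

```python
import numpy as np

def make_block_func(x):
    """x: NumPy array of shape (time, gridcell), e.g. da_stacked_notnull.values[0].
    Returns a block function with the signature dask's map_blocks expects."""
    def func_block(block, block_info=None):
        # global position of this block in the result array
        (r0, r1), (c0, c1) = block_info[0]['array-location']
        if c0 > r0:
            # same shortcut as the original: skip blocks above the diagonal
            # (only valid for square chunks)
            return block
        # running maximum over time of x[t, rows] * x[t, cols]
        acc = np.full((r1 - r0, c1 - c0), -np.inf, dtype=x.dtype)
        tmp = np.empty_like(acc)
        for t in range(x.shape[0]):
            np.multiply(x[t, r0:r1, None], x[t, None, c0:c1], out=tmp)
            np.maximum(acc, tmp, out=acc)
        # replace only the strictly lower-triangular elements (row > col)
        rows = np.arange(r0, r1)[:, None]
        cols = np.arange(c0, c1)[None, :]
        return np.where(rows > cols, acc, block).astype(block.dtype, copy=False)
    return func_block

# Call the block function by hand on a 10x10 result split into 5x5 blocks,
# with a hand-built block_info dict of the shape dask provides.
rng = np.random.default_rng(1)
x = rng.normal(size=(30, 10)).astype(np.float32)
func = make_block_func(x)
filled = np.full((10, 10), -1.0)
out = np.empty_like(filled)
for r0 in (0, 5):
    for c0 in (0, 5):
        info = {0: {'array-location': [(r0, r0 + 5), (c0, c0 + 5)]}}
        out[r0:r0 + 5, c0:c0 + 5] = func(filled[r0:r0 + 5, c0:c0 + 5].copy(), block_info=info)
```

With dask, the call would look like `darr_zeros.map_blocks(make_block_func(da_stacked_notnull.values[0]), dtype=np.float16)`. The closure reads from one plain NumPy array instead of indexing the `DataArray` for every element, and the temporaries stay chunk-sized.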
However, I have a few problems:
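1) If I use all the data I need, I get a Python `MemoryError`. I suspect this is caused by `dask_result.compute()`: if I understand correctly, `.compute()` returns a NumPy array, and I don't have enough memory to hold the whole result in one. How can I do this with dask arrays instead?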
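2) No thread uses more than about 50% of its core. I think this is due to the GIL, but shouldn't dask help with that? Can this be refactored for better performance?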
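For question 1, a minimal sketch assuming the result only needs to end up on disk rather than in RAM: instead of `.compute()`, `dask.array.store` can write each computed block into any target that supports slice assignment, here a NumPy memory-mapped `.npy` file created with `np.lib.format.open_memmap`. The helper `store_to_npy`, the file name, and the small stand-in array are hypothetical; zarr or HDF5 files are common alternative targets. The file still needs disk space for the full `grid_size × grid_size` array.

```python
import os
import tempfile

import numpy as np
import dask.array as da

def store_to_npy(darr, path):
    """Write a dask array block by block into a .npy file on disk,
    so the full result never has to be held in memory at once."""
    target = np.lib.format.open_memmap(path, mode="w+", dtype=darr.dtype, shape=darr.shape)
    da.store(darr, target)  # each block is computed, then written into the memmap
    target.flush()
    return path

# small stand-in for dask_result
demo = da.ones((6, 6), chunks=(3, 3), dtype=np.float16) * 2
path = store_to_npy(demo, os.path.join(tempfile.mkdtemp(), "result.npy"))
loaded = np.load(path, mmap_mode="r")  # reopened lazily, not read into RAM
```

Wrapping the uncomputed array, `xr.DataArray(dask_result)`, also keeps it lazy, so the `MemoryError` only appears once something forces the whole array into memory.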
Here is the data matrix (a subset, without all the grid points I actually need):
```
da_stacked_notnull
Out[1]:
<xarray.DataArray (variable: 1, time: 365, gridcell: 7230)>
array([[[-0.376704, -0.036332, ..., 27.715254, 26.863554],
        [-0.465122, -0.152866, ..., 27.227764, 26.556808],
        ...,
        [-0.724707, -0.520708, ..., 29.315022, 29.10007 ],
        [-0.835325, -0.704899, ..., 29.425072, 29.086765]]], dtype=float32)
Coordinates:
  * time      (time) datetime64[ns] 2015-01-01 2015-01-02 ... 2015-12-31
  * gridcell  (gridcell) MultiIndex
  - lon       (gridcell) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
  - lat       (gridcell) float64 -59.5 -58.5 -57.5 -56.5 ... -32.5 -31.5 -30.5
  * variable  (variable) <U3 'sst'
```
Thanks!