Parallelizing a computation on a 3D numpy array with dask.array.core.map_blocks

Time: 2016-06-22 21:14:14

Tags: python numpy parallel-processing dask

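I have a 3D numpy array (dimensions: depth, latitude, longitude), and I am trying to run a computation in parallel that uses the data along the depth axis at each lat-lon point. So far I have had no success. I have looked at the documentation for dask.array.core.map_blocks, but it did not help. This is what I am doing: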

N2_dask = da.from_array(N2_naned, chunks=(49, 32, (12, 12, 12, 12, 12, 12)))
zN2_dask = da.from_array(-zN2_agg[t], chunks=(49, 32, (12, 12, 12, 12, 12, 12)))
lat_dask = da.from_array(lat_agg, chunks=(32))
lon_dask = da.from_array(lon_agg, chunks=((12, 12, 12, 12, 12, 12)))

for j in range(len(lat_dask)):
    for i in range(len(lon_dask)):
        f = da.core.map_blocks(baroclinic.neutral_modes_from_N2_profile(
            zN2_dask[:, j, i], N2_dask[:, j, i], gsw.earth.f(lat_dask[j, i]), **kwargs))
        zphi, Rd, vd = f.compute()

where baroclinic.neutral_modes_from_N2_profile is my own function. I get the following error:

AssertionError                            Traceback (most recent call last)
<ipython-input-61-e58a7b54c470> in <module>()
      6         print zN2_dask[:, j, i]
      7         f = da.core.map_blocks(baroclinic.neutral_modes_from_N2_profile(
----> 8                 zN2_dask[:, j, i], N2_dask[:, j, i], gsw.earth.f(lat_dask[j, i]), **kwargs))
      9         zphi, Rd, vd = f.compute()

/home/takaya/.conda/envs/oceanmodes/lib/python2.7/site-packages/dask/array/core.pyc in __getitem__(self, index)
   1023             return self
   1024 
-> 1025         dsk, chunks = slice_array(out, self.name, self.chunks, index)
   1026 
   1027         return Array(merge(self.dask, dsk), out, chunks, dtype=self._dtype)

/home/takaya/.conda/envs/oceanmodes/lib/python2.7/site-packages/dask/array/slicing.pyc in slice_array(out_name, in_name, blockdims, index)
    134 
    135     # Pass down to next function
--> 136     dsk_out, bd_out = slice_with_newaxes(out_name, in_name, blockdims, index)
    137 
    138     bd_out = tuple(map(tuple, bd_out))

/home/takaya/.conda/envs/oceanmodes/lib/python2.7/site-packages/dask/array/slicing.pyc in slice_with_newaxes(out_name, in_name, blockdims, index)
    152 
    153     # Pass down and do work
--> 154     dsk, blockdims2 = slice_wrap_lists(out_name, in_name, blockdims, index2)
    155 
    156     # Insert ",0" into the key:  ('x', 2, 3) -> ('x', 0, 2, 0, 3)

/home/takaya/.conda/envs/oceanmodes/lib/python2.7/site-packages/dask/array/slicing.pyc in slice_wrap_lists(out_name, in_name, blockdims, index)
    183     shape = tuple(map(sum, blockdims))
    184     assert all(isinstance(i, (slice, list, int, long)) for i in index)
--> 185     assert len(blockdims) == len(index)
    186     for bd, i in zip(blockdims, index):
    187         check_index(i, sum(bd))

AssertionError: 

Can anyone tell me why this gives me an AssertionError? Thanks in advance!

1 answer:

Answer 0 (score: 0)

chunks should either be a tuple of integers, which specifies a single uniform chunk shape across the whole array:

chunks=(49, 32, 12)

or a tuple of tuples, where each inner tuple defines how the corresponding dimension is chunked:

chunks=((20, 20, 9), (8, 8, 8, 8), (12, 24, 24, 12))
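As a minimal sketch of the two forms, the snippet below builds the same array with each one and inspects the resulting block layout. The array shape (49, 32, 72) and the variable names are assumptions chosen so that the block sizes above add up; they are not taken from the question.

```python
import numpy as np
import dask.array as da

# Hypothetical data: 49 depth levels, 32 latitudes, 72 longitudes.
x = np.zeros((49, 32, 72))

# Form 1: a tuple of integers gives one uniform chunk shape.
uniform = da.from_array(x, chunks=(49, 32, 12))

# Form 2: a tuple of tuples lists the block sizes along each dimension.
# Each inner tuple must sum to the length of that dimension.
explicit = da.from_array(
    x, chunks=((20, 20, 9), (8, 8, 8, 8), (12, 24, 24, 12))
)

# .chunks always reports the tuple-of-tuples form.
print(uniform.chunks)
print(explicit.chunks)
print(explicit.numblocks)
```

Whichever form is passed in, the `.chunks` attribute shows the layout as one tuple of block sizes per dimension, which makes it a convenient way to check what was actually created.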

In your case, you appear to be mixing these two forms:

chunks=(49, 32, (12, 12, 12, 12, 12, 12))

You probably meant to use one of the following:

chunks=(49, 32, 12)
chunks=((49,), (32,), (12, 12, 12, 12, 12, 12))

Given that the last dimension is chunked uniformly (all 12s), I suggest just using (49, 32, 12).
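The sketch below checks that the two suggested specifications describe the same layout, and then shows one possible way to apply a per-column calculation block by block. The shape (49, 32, 72) is an assumption (6 blocks of 12 along the last axis), and `column_sum` is a hypothetical stand-in for a function that reduces along the depth axis; it is not the asker's `baroclinic.neutral_modes_from_N2_profile`.

```python
import numpy as np
import dask.array as da

# Hypothetical data with an assumed shape of (depth, lat, lon) = (49, 32, 72).
x = np.arange(49 * 32 * 72, dtype=float).reshape(49, 32, 72)

# The two suggested chunk specifications produce the same block layout.
a = da.from_array(x, chunks=(49, 32, 12))
b = da.from_array(x, chunks=((49,), (32,), (12, 12, 12, 12, 12, 12)))
print(a.chunks == b.chunks)


def column_sum(block):
    """Stand-in reduction: collapse the depth axis (axis 0) of one block."""
    return block.sum(axis=0)


# Depth is held in a single chunk, so every block contains complete
# depth profiles. drop_axis=0 tells dask the function removes that axis.
result = a.map_blocks(column_sum, drop_axis=0, dtype=x.dtype)
print(result.shape)

# Nothing is computed until .compute() is called.
values = result.compute()
```

Note that `map_blocks` takes the function itself plus the arrays as separate arguments, and the function is then called once per block rather than once per lat-lon point.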