dask.array.apply_gufunc具有多个不同形状的输出

时间:2019-04-17 12:50:28

标签: dask python-xarray

我正在尝试将ufunc应用于可广播的分块dask数组,该数组会产生几种不同形状的输出:

import dask.array as da # dask.__version__ is 1.2.0
import numpy as np

def func(A3, A2):
    return A3+A2, A2**2

A3 = da.from_array(np.random.randn(3,5,5), chunks=(3,2,2))
A2 = da.from_array(np.random.randn(  5,5), chunks=(  2,2))
ret = da.apply_gufunc(func, '(),()->(),()', A3, A2, output_dtypes=[float,float])

for r in ret:
    print(r)
    r.compute()

问题在于,假设ret中的两个输出都是形状(3,5,5),然后在.compute()上失败,而第二个输出的ValueError: axes don't match array失败了。应该是二维的。

在这种情况下如何使用apply_gufunc

注意::在这种情况下,我可能宁愿使用xarray.apply_ufunc,但不幸的是,它尚不具有多个输出(请参见here)。

1 个答案:

答案 0 :(得分:1)

以下是一种解决方法,可以帮助您:

def func(A3, A2):
    return A3+A2, (A2**2)[np.newaxis,:]