numba中的多维numpy.expand_dims

时间:2019-11-05 10:42:04

标签: python numpy jit numba

由于numba中没有np.einsumnp.newaxis,我希望有一种便捷的方法可以对多个轴执行np.expand_dims,同时可以numba.njit进行该功能。我无法获得njit的两种解决方案:

def expand_dims(arr, axes):
    for ax in axes:
        arr = np.expand_dims(arr, ax)
    return arr

def expand_dims2(arr, axes):
    shape_list = list(arr.shape)
    for ax in axes:
        shape_list.insert(ax, 1)
    return arr.reshape(tuple(shape_list))

,其中axes是应创建的轴索引的可迭代项。

在能够njit的情况下,有没有一种很好的方法?否则,代码会很快变得混乱。

0 个答案:

没有答案