由于numba中没有np.einsum
和np.newaxis
,我希望有一种便捷的方法可以对多个轴执行np.expand_dims
,同时可以numba.njit
进行该功能。我无法获得njit
的两种解决方案:
def expand_dims(arr, axes):
for ax in axes:
arr = np.expand_dims(arr, ax)
return arr
def expand_dims2(arr, axes):
shape_list = list(arr.shape)
for ax in axes:
shape_list.insert(ax, 1)
return arr.reshape(tuple(shape_list))
,其中axes
是应创建的轴索引的可迭代项。
在能够njit
的情况下,有没有一种很好的方法?否则,代码会很快变得混乱。