我有一个NxMx3
numpy数组dtype=object
。我还有一个函数f(a,b,c)
,它接受这个数组的最后一个轴中的三个元素并返回np.int32
。我的问题是如何将f
应用于NxMx3
数组以生成NxM
数组dtype=np.int32
?
我目前的解决方案是使用
newarr = np.fromfunction(lambda i,j: f(arr[i,j,0], arr[i,j,1], arr[i,j,2]),
arr.shape[:2], dtype=np.int)
虽然这比我希望的更加冗长。
答案 0 :(得分:3)
您可以使用vectorize
:
np.vectorize(f, otypes=[np.int32])(arr[:, :, 0], arr[:, :, 1], arr[:, :, 2])
这可以通过轴滚动和迭代来简化:
np.vectorize(f, otypes=[np.int32])(*np.rollaxis(arr, 2, 0))
或者,您可以使用dsplit
显式拆分数组:
np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3))[..., 0]
或
np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3)).reshape(arr.shape[:-1])
或
np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3)).squeeze()
但是,apply_along_axis
可能更简单:
np.apply_along_axis(lambda x: f(*x), 2, arr)