用numpy减少轴

时间:2012-10-12 22:17:35

标签: python numpy

我有一个NxMx3 numpy数组dtype=object。我还有一个函数f(a,b,c),它接受​​这个数组的最后一个轴中的三个元素并返回np.int32。我的问题是如何将f应用于NxMx3数组以生成NxM数组dtype=np.int32

我目前的解决方案是使用

newarr = np.fromfunction(lambda i,j: f(arr[i,j,0], arr[i,j,1], arr[i,j,2]),
                          arr.shape[:2], dtype=np.int)

虽然这比我希望的更加冗长。

1 个答案:

答案 0 :(得分:3)

您可以使用vectorize

np.vectorize(f, otypes=[np.int32])(arr[:, :, 0], arr[:, :, 1], arr[:, :, 2])

这可以通过轴滚动和迭代来简化:

np.vectorize(f, otypes=[np.int32])(*np.rollaxis(arr, 2, 0))

或者,您可以使用dsplit显式拆分数组:

np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3))[..., 0]

np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3)).reshape(arr.shape[:-1])

np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3)).squeeze()

但是,apply_along_axis可能更简单:

np.apply_along_axis(lambda x: f(*x), 2, arr)