用python中的2d数组汇总ndarray

时间:2018-06-25 13:27:55

标签: python numpy multidimensional-array scipy

我想使用2d数组dat中包含的索引来总结3d数组idx

请考虑以下示例。对于dat[:, :, i]上的每个边距,我想根据某个索引idx计算中位数。所需的输出(out)是一个2d数组,其行记录索引,列记录边距。以下代码有效,但效率不高。有什么建议吗?

import numpy as np
dat = np.arange(12).reshape(2, 2, 3)
idx = np.array([[0, 0], [1, 2]])

out = np.empty((3, 3))
for i in np.unique(idx):
    out[i,] = np.median(dat[idx==i], axis = 0)
print(out)

输出:

[[ 1.5  2.5  3.5]
 [ 6.   7.   8. ]
 [ 9.  10.  11. ]]

1 个答案:

答案 0 :(得分:1)

为了更好地可视化问题,我将数组的2x2维称为行和列,将3维维称为深度。我将沿第3维的矢量称为“像素”(像素的长度为3),将沿前两个维的平面称为“通道”。

您的循环正在累积由掩码idx == i选择的一组像素,并获取该组中每个通道的中值。结果是一个Nx3数组,其中N是您拥有的不同倾斜数。

有一天,generalized ufuncs在numpy中将无处不在,而np.median就是这样的功能。那天,您将可以使用reduceat magic 1 做类似的事情

unq, ind = np.unique(idx, return_inverse=True)
np.median.reduceat(dat.reshape(-1, dat.shape[-1]), np.r_[0, np.where(np.diff(unq[ind]))[0]+1])

1 有关特定魔术类型的更多信息,请参见Applying operation to unevenly split portions of numpy array

由于当前无法实现,因此可以改用scipy.ndimage.median。该版本允许您计算数组中一组标记区域的中值,而这正是idx所具有的。此方法假定您的索引数组包含N个密集打包的值,所有这些值都在range(N)中。否则,整形操作将无法正常工作。

如果不是这种情况,请先转换idx

_, ind = np.unique(idx, return_inverse=True)
idx = ind.reshape(idx.shape)

OR

idx = np.unique(idx, return_inverse=True)[1].reshape(idx.shape)

由于实际上是为每个区域和通道计算单独的中位数,因此您将需要为每个通道设置一组标签。充实idx,以为每个频道设置一组不同的索引:

chan = dat.shape[-1]
offset = idx.max() + 1
index = np.stack([idx + i * offset for i in range(chan)], axis=-1)

现在index在每个通道中定义了一组相同的区域,您可以在scipy.ndimage.median中使用它们:

out = scipy.ndimage.median(dat, index, index=range(offset * chan)).reshape(chan, offset).T

输入标签必须从零到offset * chan密集地包装,以使index=range(offset * chan)正常工作,并且reshape操作要具有正确数量的元素。最后的转置只是标签排列方式的人工产物。

以下是完整产品以及结果的IDEOne demo

import numpy as np
from scipy.ndimage import median

dat = np.arange(12).reshape(2, 2, 3)
idx = np.array([[0, 0], [1, 2]])

def summarize(dat, idx):
    idx = np.unique(idx, return_inverse=True)[1].reshape(idx.shape)
    chan = dat.shape[-1]
    offset = idx.max() + 1
    index = np.stack([idx + i * offset for i in range(chan)], axis=-1)
    return median(dat, index, index=range(offset * chan)).reshape(chan, offset).T

print(summarize(dat, idx))