鉴于我有一个方形矩阵的索引,例如:
idxs = np.array([[1, 1],
[0, 1]])
和一个大小相同的方阵矩阵(不一定与idxs
大小相同):
mats = array([[[ 0. , 0. ],
[ 0. , 0.5]],
[[ 1. , 0.3],
[ 1. , 1. ]]])
我想将idxs
中的每个索引替换为mats
中的相应矩阵,以获取:
array([[ 1. , 0.3, 1. , 0.3],
[ 1. , 1. , 1. , 1. ],
[ 0. , 0. , 1. , 0.3],
[ 0. , 0.5, 1. , 1. ]])
mats[idxs]
给了我一个嵌套版本:
array([[[[ 1. , 0.3],
[ 1. , 1. ]],
[[ 1. , 0.3],
[ 1. , 1. ]]],
[[[ 0. , 0. ],
[ 0. , 0.5]],
[[ 1. , 0.3],
[ 1. , 1. ]]]])
所以我尝试使用reshape
,但是'twas徒劳无功! mats[idxs].reshape(4,4)
返回:
array([[ 1. , 0.3, 1. , 1. ],
[ 1. , 0.3, 1. , 1. ],
[ 0. , 0. , 0. , 0.5],
[ 1. , 0.3, 1. , 1. ]])
如果有帮助,我发现skimage.util.view_as_blocks
与我需要的完全相反(它可以将我想要的结果转换为嵌套的mats[idxs]
形式)。
是否有(希望非常)快速的方法来做到这一点?对于应用程序,我的mats
仍然只有几个小矩阵,但我的idxs
将是一个最大为2 ^ 15的方阵,在这种情况下,我将替换超过一百万用于创建订单2 ^ 16的新矩阵的索引。
非常感谢你的帮助!
答案 0 :(得分:3)
我们使用这些索引索引到输入数组的第一个轴。要获得2D
输出,我们只需要置换轴并重新整形。因此,一种方法是使用np.transpose
/ np.swapaxes
和np.reshape
,就像这样 -
mats[idxs].swapaxes(1,2).reshape(-1,mats.shape[-1]*idxs.shape[-1])
示例运行 -
In [83]: mats
Out[83]:
array([[[1, 1],
[7, 1]],
[[6, 6],
[5, 8]],
[[7, 1],
[6, 0]],
[[2, 7],
[0, 4]]])
In [84]: idxs
Out[84]:
array([[2, 3],
[0, 3],
[1, 2]])
In [85]: mats[idxs].swapaxes(1,2).reshape(-1,mats.shape[-1]*idxs.shape[-1])
Out[85]:
array([[7, 1, 2, 7],
[6, 0, 0, 4],
[1, 1, 2, 7],
[7, 1, 0, 4],
[6, 6, 7, 1],
[5, 8, 6, 0]])
np.take
对重复索引的性能提升
对于重复索引,为了提高性能,我们最好通过np.take
索引axis=0
。让我们列出这些方法,并在idxs
有许多重复索引时将其计时。
功能定义 -
def simply_indexing_based(mats, idxs):
ncols = mats.shape[-1]*idxs.shape[-1]
return mats[idxs].swapaxes(1,2).reshape(-1,ncols)
def take_based(mats, idxs):np.take(mats,idxs,axis=0)
ncols = mats.shape[-1]*idxs.shape[-1]
return np.take(mats,idxs,axis=0).swapaxes(1,2).reshape(-1,ncols)
运行时测试 -
In [156]: mats = np.random.randint(0,9,(10,2,2))
In [157]: idxs = np.random.randint(0,10,(1000,1000))
# This ensures many repeated indices
In [158]: out1 = simply_indexing_based(mats, idxs)
In [159]: out2 = take_based(mats, idxs)
In [160]: np.allclose(out1, out2)
Out[160]: True
In [161]: %timeit simply_indexing_based(mats, idxs)
10 loops, best of 3: 41.2 ms per loop
In [162]: %timeit take_based(mats, idxs)
10 loops, best of 3: 27.3 ms per loop
因此,我们看到 1.5x+
的整体改善。
为了了解np.take
的改进情况,让我们单独编制索引部分 -
In [168]: %timeit mats[idxs]
10 loops, best of 3: 22.8 ms per loop
In [169]: %timeit np.take(mats,idxs,axis=0)
100 loops, best of 3: 8.88 ms per loop
对于这些数据,其 2.5x+
。还不错!