我正在尝试提取形状为(3,32)的一个numpy数组的子集并对其进行数学运算,我尝试提取的数据是形状(3,9)的子集,并且此数据的范围起源于另一个数组大小(3)中包含的索引。例如,我有一个时域中三个通道的值的数据集,并将每个通道的最大值的索引提取到数组中
a = np.random.randint(20,size = (3,32))
a
array([[18, 3, 10, 6, 12, 1, 10, 8, 4, 11, 13, 14, 9, 9, 10, 2,
9, 0, 0, 16, 14, 19, 1, 19, 14, 19, 19, 2, 14, 0, 4, 18],
[ 9, 19, 2, 12, 0, 14, 18, 7, 3, 0, 7, 3, 12, 19, 4, 2,
5, 9, 2, 11, 15, 19, 16, 17, 3, 4, 17, 5, 6, 1, 2, 17],
[ 0, 11, 18, 8, 9, 2, 9, 15, 9, 6, 0, 8, 9, 16, 9, 6,
1, 19, 1, 9, 12, 8, 0, 0, 7, 15, 3, 14, 15, 8, 10, 19]])
b = np.argmax(a,1)
b
array([21, 1, 17], dtype=int64)
我的目标是从每个指定的索引中派生一个由三个值组成的新数组。例如,我想提取出来:
[21,22,23] from a[0]
[1,2,3] from a[1]
[17,18,19] from a[2]
全部放入大小为[3,3]的新数组
我已经能够使用循环来实现这一点,但是我怀疑有一种更有效的方式来产生没有循环的结果(速度对于此应用程序来说是个问题)。通过手动填充较小的矩阵,我已经能够实现类似的结果...
c = np.asarray([[1,2,3],[2,3,4],[3,4,5]])
a[np.arange(3)[:,None],c]
array([[ 3, 10, 6],
[ 2, 12, 0],
[ 8, 9, 2]])
但是,考虑到此应用程序的动态性质,我想编写它以便可以动态缩放(范围超出根索引的9个值,等等)。我只是不知道是否有这种方法。为了分割数组,我使用了类似于以下的语法...
a[np.arange(3)[:,None],b[:]:(b[:] + 2)]
产生错误消息,其本质是...
builtins.TypeError: only integer scalar arrays can be converted to a scalar index
答案 0 :(得分:1)
由于您说过不会溢出,因此棘手的事情就少了很多。通常,由于您有了起始索引,因此可以使用基本广播使用索引创建一个(n, 3)
形状数组,并使用take_along_axis
从原始数组中提取这些元素。
np.take_along_axis(a, b[:, None] + np.arange(3), axis=1)
array([[19, 1, 19],
[19, 2, 12],
[19, 1, 9]])