我正在尝试使用包含3个元素的索引集在(5, 10, 2)
轴上切割一个数组,例如形状0-th
。结果我得到了一个形状(1, 3, 10, 2)
的数组。在这种情况下添加虚拟维度的原因是什么?这对我来说似乎是一个糟糕的设计,因为使用大括号语法的普通索引并不能做到这一点。 np.compress
也有正确的行为。
答案 0 :(得分:2)
在NumPy中,返回数组的形状,无论是索引还是使用np.take
,都会受到传递的索引形状的影响。因此,例如,如果您使用2D索引数组索引1D数组,您将获得2D结果:
>>> x = np.array([9, 8, 7, 6, 5])
>>> i = np.array([[1, 3], [2, 4]])
>>> x[i]
array([[8, 6],
[7, 5]])
即使对于多维数组也是如此,除了将尾随尺寸添加到索引形状:
>>> x = np.random.rand(5, 4, 3)
>>> x[i].shape
(2, 2, 4, 3)
因此,如果您有一个形状为(1, 3)
的索引数组,那么该形状将被“压印”在结果上:
>>> x = np.random.rand(5, 10, 2)
>>> i = np.array([[1, 2, 3]])
>>> x[i].shape
(1, 3, 10, 2)
这相当于轴上的take
:
>>> x.take(i, axis=0).shape
(1, 3, 10, 2)
您的问题不包含任何代码,但我怀疑当您将索引复制并粘贴到take
时,您复制了额外的一对方括号:
>>> x[[1, 2, 3]].shape
(3, 10, 2)
>>> x.take([[1, 2, 3]], axis=0).shape
(1, 3, 10, 2)
索引数组的形状很重要,在使用numpy构造更复杂的表达式时非常有用。