np.take返回一个具有额外奇异维度的数组

时间:2017-11-12 13:19:29

标签: python numpy

我正在尝试使用包含3个元素的索引集在(5, 10, 2)轴上切割一个数组,例如形状0-th。结果我得到了一个形状(1, 3, 10, 2)的数组。在这种情况下添加虚拟维度的原因是什么?这对我来说似乎是一个糟糕的设计,因为使用大括号语法的普通索引并不能做到这一点。 np.compress也有正确的行为。

1 个答案:

答案 0 :(得分:2)

在NumPy中,返回数组的形状,无论是索引还是使用np.take,都会受到传递的索引形状的影响。因此,例如,如果您使用2D索引数组索引1D数组,您将获得2D结果:

>>> x = np.array([9, 8, 7, 6, 5])
>>> i = np.array([[1, 3], [2, 4]])
>>> x[i]
array([[8, 6],
       [7, 5]])

即使对于多维数组也是如此,除了将尾随尺寸添加到索引形状:

>>> x = np.random.rand(5, 4, 3)
>>> x[i].shape
(2, 2, 4, 3)

因此,如果您有一个形状为(1, 3)的索引数组,那么该形状将被“压印”在结果上:

>>> x = np.random.rand(5, 10, 2)
>>> i = np.array([[1, 2, 3]])
>>> x[i].shape
(1, 3, 10, 2)

这相当于轴上的take

>>> x.take(i, axis=0).shape
(1, 3, 10, 2)

您的问题不包含任何代码,但我怀疑当您将索引复制并粘贴到take时,您复制了额外的一对方括号:

>>> x[[1, 2, 3]].shape
(3, 10, 2)
>>> x.take([[1, 2, 3]], axis=0).shape
(1, 3, 10, 2)

索引数组的形状很重要,在使用numpy构造更复杂的表达式时非常有用。