numpy数组中看似不一致的切片行为

时间:2017-10-06 19:53:03

标签: python numpy indexing

我遇到了一些在我看起来像Numpy切片中不一致行为的东西。具体来说,请考虑以下示例:

import numpy as np
a = np.arange(9).reshape(3,3)   # a 2d numpy array
y = np.array([1,2,2])           # vector that will be used to index the array

b = a[np.arange(len(a)),y]      # a vector (what I want)
c = a[:,y]                      # a matrix ??

我想获得一个向量,使第i个元素为a[i,y[i]]。我尝试了两件事(上面bc)并且对bc不一样感到惊讶......实际上一个是矢量而另一个是矩阵!我的印象是:是"所有元素"的简写。但显然意思更微妙。

经过反复试验,我现在稍微理解了差异(b == np.diag(c)),但希望澄清为什么它们不同,究竟使用:意味着什么,以及如何了解何时使用这两种情况。

谢谢!

2 个答案:

答案 0 :(得分:2)

在不了解广播的情况下,很难理解高级索引(使用列表或数组)。

In [487]: a=np.arange(9).reshape(3,3)
In [488]: idx = np.array([1,2,2])

带有(3,)和(3,)生成形状(3,)结果的索引:

In [489]: a[np.arange(3),idx]
Out[489]: array([1, 5, 8])

带(3,1)和(3,)的索引,结果为(3,3)

In [490]: a[np.arange(3)[:,None],idx]
Out[490]: 
array([[1, 2, 2],
       [4, 5, 5],
       [7, 8, 8]])

切片:基本上做同样的事情。有细微差别,但这里也是一样。

In [491]: a[:,idx]
Out[491]: 
array([[1, 2, 2],
       [4, 5, 5],
       [7, 8, 8]])

ix_做同样的事情,转换(3,)& (3,)至(3,1)和(1,3):

In [492]: np.ix_(np.arange(3),idx)
Out[492]: 
(array([[0],
        [1],
        [2]]), array([[1, 2, 2]]))

广播的金额可能有助于形象化这两种情况:

In [495]: np.arange(3)*10+idx
Out[495]: array([ 1, 12, 22])
In [496]: np.sum(np.ix_(np.arange(3)*10,idx),axis=0)
Out[496]: 
array([[ 1,  2,  2],
       [11, 12, 12],
       [21, 22, 22]])

答案 1 :(得分:0)

传递

np.arange(len(a)), y

您可以将结果视为您传递的压缩元素的所有索引。在这种情况下,按np.arange(len(a))y

编制索引
np.arange(len(a))
# [0, 1, 2]
y
# [1, 2, 2]

有效地采用元素:(0,1),(1,2)和(2,2)。

print(a[0, 1], a[1, 2], a[2, 2])  # 0th, 1st, 2nd elements from each indexer
# 1 5 8

在第二种情况下,沿第一个维度拍摄整个切片。 (在冒号之前没有任何内容。)所以这是沿着第0轴的所有元素 。然后使用y指定您希望每行包含第1,第2和第2个元素。 (0索引。)

正如您所指出的,鉴于切片的各个元素是等价的,结果可能看起来有点不直观:

a[:] == a[np.arange(len(a))]

a[:y] == a[:y]

但是,NumPy advanced indexing关心索引时传递的数据结构的类型(元组,整数等)。事情很快就会变得毛茸茸。

背后的细节是这样的:首先考虑所有 NumPy索引为一般形式x[obj],其中obj是对你传递的任何内容的评估。 NumPy"表现如何"取决于对象obj的类型:

  

当选择对象obj为a时,将触发高级索引   非元组序列对象,ndarray(数据类型为integer或bool),   或具有至少一个序列对象或ndarray(数据类型的元组)的元组   整数或布尔)。   ...   高级索引的定义意味着x [(1,2,3),]是   从根本上不同于x [(1,2,3)]。后者相当于   x [1,2,3]将触发基本选择而前者将触发   触发高级索引。一定要明白为什么会这样。

在你的第一个案例中,obj = np.arange(len(a)),y,一个符合上面粗体法案的元组。这会触发高级索引并强制执行上述行为。

至于第二种情况,[:,y]

  

当至少有一个切片(:),省略号(...)或np.newaxis时   索引(或者数组的维度比高级更多   索引),那么行为可能会更复杂。 就像   连接每个高级索引元素的索引结果。

证明:

# Concatenate the indexing result for each advanced index element.
np.vstack((a[0, y], a[1, y], a[2, y]))