我遇到了一些在我看起来像Numpy切片中不一致行为的东西。具体来说,请考虑以下示例:
import numpy as np
a = np.arange(9).reshape(3,3) # a 2d numpy array
y = np.array([1,2,2]) # vector that will be used to index the array
b = a[np.arange(len(a)),y] # a vector (what I want)
c = a[:,y] # a matrix ??
我想获得一个向量,使第i个元素为a[i,y[i]]
。我尝试了两件事(上面b
和c
)并且对b
和c
不一样感到惊讶......实际上一个是矢量而另一个是矩阵!我的印象是:
是"所有元素"的简写。但显然意思更微妙。
经过反复试验,我现在稍微理解了差异(b == np.diag(c)
),但希望澄清为什么它们不同,究竟使用:
意味着什么,以及如何了解何时使用这两种情况。
谢谢!
答案 0 :(得分:2)
在不了解广播的情况下,很难理解高级索引(使用列表或数组)。
In [487]: a=np.arange(9).reshape(3,3)
In [488]: idx = np.array([1,2,2])
带有(3,)和(3,)生成形状(3,)结果的索引:
In [489]: a[np.arange(3),idx]
Out[489]: array([1, 5, 8])
带(3,1)和(3,)的索引,结果为(3,3)
In [490]: a[np.arange(3)[:,None],idx]
Out[490]:
array([[1, 2, 2],
[4, 5, 5],
[7, 8, 8]])
切片:
基本上做同样的事情。有细微差别,但这里也是一样。
In [491]: a[:,idx]
Out[491]:
array([[1, 2, 2],
[4, 5, 5],
[7, 8, 8]])
ix_
做同样的事情,转换(3,)& (3,)至(3,1)和(1,3):
In [492]: np.ix_(np.arange(3),idx)
Out[492]:
(array([[0],
[1],
[2]]), array([[1, 2, 2]]))
广播的金额可能有助于形象化这两种情况:
In [495]: np.arange(3)*10+idx
Out[495]: array([ 1, 12, 22])
In [496]: np.sum(np.ix_(np.arange(3)*10,idx),axis=0)
Out[496]:
array([[ 1, 2, 2],
[11, 12, 12],
[21, 22, 22]])
答案 1 :(得分:0)
传递
时np.arange(len(a)), y
您可以将结果视为您传递的压缩元素的所有索引对。在这种情况下,按np.arange(len(a))
和y
np.arange(len(a))
# [0, 1, 2]
y
# [1, 2, 2]
有效地采用元素:(0,1),(1,2)和(2,2)。
print(a[0, 1], a[1, 2], a[2, 2]) # 0th, 1st, 2nd elements from each indexer
# 1 5 8
在第二种情况下,沿第一个维度拍摄整个切片。 (在冒号之前没有任何内容。)所以这是沿着第0轴的所有元素 。然后使用y
指定您希望每行包含第1,第2和第2个元素。 (0索引。)
正如您所指出的,鉴于切片的各个元素是等价的,结果可能看起来有点不直观:
a[:] == a[np.arange(len(a))]
和
a[:y] == a[:y]
但是,NumPy advanced indexing关心索引时传递的数据结构的类型(元组,整数等)。事情很快就会变得毛茸茸。
背后的细节是这样的:首先考虑所有 NumPy索引为一般形式x[obj]
,其中obj
是对你传递的任何内容的评估。 NumPy"表现如何"取决于对象obj
的类型:
当选择对象obj为a时,将触发高级索引 非元组序列对象,ndarray(数据类型为integer或bool), 或具有至少一个序列对象或ndarray(数据类型的元组)的元组 整数或布尔)。 ... 高级索引的定义意味着x [(1,2,3),]是 从根本上不同于x [(1,2,3)]。后者相当于 x [1,2,3]将触发基本选择而前者将触发 触发高级索引。一定要明白为什么会这样。
在你的第一个案例中,obj = np.arange(len(a)),y
,一个符合上面粗体法案的元组。这会触发高级索引并强制执行上述行为。
至于第二种情况,[:,y]
当至少有一个切片(:),省略号(...)或np.newaxis时 索引(或者数组的维度比高级更多 索引),那么行为可能会更复杂。 就像 连接每个高级索引元素的索引结果。
证明:
# Concatenate the indexing result for each advanced index element.
np.vstack((a[0, y], a[1, y], a[2, y]))