我正在尝试从三维矩阵中提取一个二维矩阵,最后一个维度具有来自三维矩阵的最后一维的值。例如,如果 尺寸P [2,2,3] =
[
[[5, 1, 5], [9, 9, 4]],
[[0, 9, 8], [8, 6, 8]]
]
什么是索引矩阵以获得out矩阵
[[1, 9],[0, 8]]
其中1是第一行第一列的第二个元素,9是第一行第二列的第一个元素,0是第二行第一列的第一个元素,8是第三行的第3个元素第二排第二列?
我的想法是,对于每一列,我有不同的分数。我想为每列检索一个我知道索引的不同分数。
我对Numpy中的高级索引感到有些困惑,我不是自己想出来的。谢谢!
答案 0 :(得分:2)
我假设有一个索引数组可以索引到最后一个轴。我们称之为idx
。对于问题中给定文本的给定样本,它将是 -
idx = np.array([[1,0],[0,2]])
具体来说,这是从引用文本中提取的:
1是第一行第一列的第二个元素,9是 第一行第二列的第一个元素,0是第一列 第二行第一列的元素,第8列是第三元素 第二行第二列
要解决这个问题,我们将使用带np.ogrid
的开放网格来索引输入数组的前两个轴 -
m,n = idx.shape
I,J = np.ogrid[:m,:n]
out = A[I,J,idx]
示例运行 -
In [57]: A
Out[57]:
array([[[5, 1, 5],
[9, 9, 4]],
[[0, 9, 8],
[8, 6, 8]]])
In [59]: idx = np.array([[1,0],[0,2]])
In [60]: m,n = idx.shape
In [61]: I,J = np.ogrid[:m,:n]
In [62]: A[I,J,idx]
Out[62]:
array([[1, 9],
[0, 8]])