例如,cat
的形状为np.array([[1,2],[3,4]])[np.triu_indices(2)]
,是上三角条目的扁平列表。但是,如果我有一批2x2矩阵:
(3,)
我想获取每个矩阵的上三角索引,尝试的天真是:
foo = np.repeat(np.array([[[1,2],[3,4]]]), 30, axis=0)
但是,该对象实际上是foo[:,np.triu_indices(2)]
的形状(与我们可能期望的(30,2,3,2)
形状相反,如果我们批量提取了上面的三角形条目。
我们如何沿批处理维度广播元组索引?
答案 0 :(得分:2)
获取元组,并使用它们来索引最后两个暗淡-
r,c = np.triu_indices(2)
out = foo[:,r,c]
或者,带有Ellipsis
的一线式同时适用于3D
和2D
数组-
foo[(Ellipsis,)+np.triu_indices(2)]
它同样适用于2D
数组-
out = foo[r,c] # foo as 2D input array
3D阵列盒
我们还可以使用基于masking
的掩码-
foo[:,~np.tri(2,k=-1, dtype=bool)]
二维阵列盒
foo[~np.tri(2,k=-1, dtype=bool)]