如何沿批处理维度广播numpy索引?

时间:2019-09-25 14:06:11

标签: python numpy multidimensional-array numpy-broadcasting matrix-indexing

例如,cat的形状为np.array([[1,2],[3,4]])[np.triu_indices(2)],是上三角条目的扁平列表。但是,如果我有一批2x2矩阵:

(3,)

我想获取每个矩阵的上三角索引,尝试的天真是:

foo = np.repeat(np.array([[[1,2],[3,4]]]), 30, axis=0)

但是,该对象实际上是foo[:,np.triu_indices(2)] 的形状(与我们可能期望的(30,2,3,2)形状相反,如果我们批量提取了上面的三角形条目。

我们如何沿批处理维度广播元组索引?

1 个答案:

答案 0 :(得分:2)

获取元组,并使用它们来索引最后两个暗淡-

r,c = np.triu_indices(2)
out = foo[:,r,c]

或者,带有Ellipsis的一线式同时适用于3D2D数组-

foo[(Ellipsis,)+np.triu_indices(2)]

它同样适用于2D数组-

out = foo[r,c] # foo as 2D input array

掩盖方式

3D阵列盒

我们还可以使用基于masking的掩码-

foo[:,~np.tri(2,k=-1, dtype=bool)]

二维阵列盒

foo[~np.tri(2,k=-1, dtype=bool)]