索引错误:形状不匹配:索引数组无法与形状一起广播

时间:2019-08-02 03:59:46

标签: python numpy indexing

我正在尝试从张量中取出某些数据,但出现奇怪的错误。在这里,我将尝试生成错误:

a=np.random.randn(5, 10, 5, 5)
a[:, [1, 6], np.triu_indices(5, 0)[0], np.triu_indices(5, 0)[1]].shape

我收到此错误

  

形状不匹配:索引数组无法与形状一起广播

我什至没有广播!都是切片的东西。

我想要什么?保持第零个轴不变(获取所有内容),从第一个轴获取[1]和[6],仅通过使用上三角元素将最后两个轴从[5,5]更改为[15]。

1 个答案:

答案 0 :(得分:2)

我们需要将第二个轴索引数组扩展到2D,以使其相对于np.triu_indices下的索引形成外平面。因此,它为我们提供了2D数组的mxn网格,其中m是第二个轴索引数组的长度,而nnp.triu_indices的长度那些。因此,从本质上讲,整个解决方案将简化为这样的内容-

r,c = np.triu_indices(5, 0)
out = a[:, np.array([1, 6])[:,None], r, c]

或将该扩展版本作为列表输入,即-

out = a[:, [[1],[6]], r, c]

我们还可以使用基于np.tri/np.triu的掩码,在更大的数组上可能会更快,因为我们跳过了创建所有整数索引的操作,就像这样-

mask = ~np.tri(5, k=-1, dtype=bool)
out = a[:, np.array([1, 6])[:,None], mask]