我正在尝试从张量中取出某些数据,但出现奇怪的错误。在这里,我将尝试生成错误:
a=np.random.randn(5, 10, 5, 5)
a[:, [1, 6], np.triu_indices(5, 0)[0], np.triu_indices(5, 0)[1]].shape
我收到此错误
形状不匹配:索引数组无法与形状一起广播
我什至没有广播!都是切片的东西。
我想要什么?保持第零个轴不变(获取所有内容),从第一个轴获取[1]和[6],仅通过使用上三角元素将最后两个轴从[5,5]更改为[15]。
答案 0 :(得分:2)
我们需要将第二个轴索引数组扩展到2D
,以使其相对于np.triu_indices
下的索引形成外平面。因此,它为我们提供了2D
数组的mxn
网格,其中m
是第二个轴索引数组的长度,而n
是np.triu_indices
的长度那些。因此,从本质上讲,整个解决方案将简化为这样的内容-
r,c = np.triu_indices(5, 0)
out = a[:, np.array([1, 6])[:,None], r, c]
或将该扩展版本作为列表输入,即-
out = a[:, [[1],[6]], r, c]
我们还可以使用基于np.tri/np.triu
的掩码,在更大的数组上可能会更快,因为我们跳过了创建所有整数索引的操作,就像这样-
mask = ~np.tri(5, k=-1, dtype=bool)
out = a[:, np.array([1, 6])[:,None], mask]