通过2d索引访问3d张量

时间:2019-09-01 02:44:57

标签: pytorch

我正在尝试通过2d矩阵访问3d张量矩阵。返回值也应该是二维矩阵。以下是我尝试实现的目标。

dim1 = 2
dim2 = 3
dim3 = 4
source = torch.FloatTensor(dim1, dim2, dim3)
source.normal_()

check = zn > 0.0
index = torch.argmax(check, dim=0)

for i in range(dim2):
    for j in range(dim3):
        ret[i, j] = source[index[i, j], i, j]

1 个答案:

答案 0 :(得分:0)

我认为您正在寻找torch.gather,可能需要permute张量来满足`gather的要求。