用数组索引火炬张量

时间:2020-04-19 21:12:16

标签: python indexing pytorch tensor torch

我有以下火炬张量:

tensor([[-0.2,  0.3],
    [-0.5,  0.1],
    [-0.4,  0.2]])

和以下numpy数组:(如有必要,我可以将其转换为其他内容)

[1 0 1]

我想得到以下张量:

tensor([0.3, -0.5, 0.2])

即我希望numpy数组索引我的张量的每个子元素。最好不使用循环。

预先感谢

2 个答案:

答案 0 :(得分:2)

只需简单地将范围(len(index))用于第一维。

import torch

a = torch.tensor([[-0.2,  0.3],
    [-0.5,  0.1],
    [-0.4,  0.2]])

c = [1, 0, 1]


b = a[range(3),c]

print(b)

答案 1 :(得分:1)

您可能要使用torch.gather-“沿由dim指定的轴收集值。”

t = torch.tensor([[-0.2,  0.3],
    [-0.5,  0.1],
    [-0.4,  0.2]])
idxs = np.array([1,0,1])

idxs = torch.from_numpy(idxs).long().unsqueeze(1)  
# or   torch.from_numpy(idxs).long().view(-1,1)
t.gather(1, idxs)
tensor([[ 0.3000],
        [-0.5000],
        [ 0.2000]])

在这里,您的索引是numpy数组,因此您必须将其转换为LongTensor。