火炬索引

时间:2019-03-03 00:43:47

标签: python pytorch

我有张量[[1,2],[4,5],[7,8]]和张量为[0,1,0]的张量。

我想将它们应用于第二维,以便它返回:[1,5,8]。

我应该怎么做?

谢谢!

2 个答案:

答案 0 :(得分:2)

 export default cache as typeof MyCache;

答案 1 :(得分:1)

假设您的输出是[1、5、7]:

一种解决方案是将维度0的所有索引和维度1的所需索引的张量组合在一起。

tensor = torch.tensor([[1,2],[4,5],[7,8]])
indices = torch.tensor([0,1,0])
output = tensor[torch.arange(0, tensor.size[0]), indices]

输出:

tensor([1, 5, 7])