在PyTorch中使用张量索引多维张量

时间:2018-08-30 08:15:47

标签: pytorch tensor

我有以下代码:

a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])

我有一个多维索引b,想用它来选择a中的单个单元格。如果b不是张量,我可以做:

a[1,1,1,1]

哪个返回正确的单元格,但是:

a[b]

不起作用,因为它只选择了a[1]四次。

我该怎么做?谢谢

2 个答案:

答案 0 :(得分:4)

您可以使用chunkb分成4个,然后使用成块的b来索引所需的特定元素:

>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)]   # here's the trick!
Out[24]: tensor([[40, 80,  0]])

关于它的好处是可以轻松地将其推广到a的任何尺寸,您只需要使卡盘的数量等于a的尺寸即可。

答案 1 :(得分:4)

更优雅(更简单)的解决方案可能是将b强制转换为元组:

a[tuple(b)]
Out[10]: tensor(5.)

我很想知道它如何与“常规” numpy一起工作,并找到了相关的文章对此进行了很好的解释here